summaryrefslogtreecommitdiffstats
path: root/net/mctp
diff options
context:
space:
mode:
Diffstat (limited to 'net/mctp')
-rw-r--r--net/mctp/af_mctp.c14
-rw-r--r--net/mctp/route.c118
2 files changed, 94 insertions, 38 deletions
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index a9526ac29dff..2767d548736b 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -263,21 +263,21 @@ static void mctp_sk_unhash(struct sock *sk)
/* remove tag allocations */
spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
- hlist_del_rcu(&key->sklist);
- hlist_del_rcu(&key->hlist);
+ hlist_del(&key->sklist);
+ hlist_del(&key->hlist);
- spin_lock(&key->reasm_lock);
+ spin_lock(&key->lock);
if (key->reasm_head)
kfree_skb(key->reasm_head);
key->reasm_head = NULL;
key->reasm_dead = true;
- spin_unlock(&key->reasm_lock);
+ key->valid = false;
+ spin_unlock(&key->lock);
- kfree_rcu(key, rcu);
+ /* key is no longer on the lookup lists, unref */
+ mctp_key_unref(key);
}
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
-
- synchronize_rcu();
}
static struct proto mctp_proto = {
diff --git a/net/mctp/route.c b/net/mctp/route.c
index 224fd25b3678..b2243b150e71 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -83,25 +83,43 @@ static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
return true;
}
+/* returns a key (with key->lock held, and refcounted), or NULL if no such
+ * key exists.
+ */
static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
- mctp_eid_t peer)
+ mctp_eid_t peer,
+ unsigned long *irqflags)
+ __acquires(&key->lock)
{
struct mctp_sk_key *key, *ret;
+ unsigned long flags;
struct mctp_hdr *mh;
u8 tag;
- WARN_ON(!rcu_read_lock_held());
-
mh = mctp_hdr(skb);
tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
ret = NULL;
+ spin_lock_irqsave(&net->mctp.keys_lock, flags);
- hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) {
- if (mctp_key_match(key, mh->dest, peer, tag)) {
+ hlist_for_each_entry(key, &net->mctp.keys, hlist) {
+ if (!mctp_key_match(key, mh->dest, peer, tag))
+ continue;
+
+ spin_lock(&key->lock);
+ if (key->valid) {
+ refcount_inc(&key->refs);
ret = key;
break;
}
+ spin_unlock(&key->lock);
+ }
+
+ if (ret) {
+ spin_unlock(&net->mctp.keys_lock);
+ *irqflags = flags;
+ } else {
+ spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
}
return ret;
@@ -121,11 +139,19 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
key->local_addr = local;
key->tag = tag;
key->sk = &msk->sk;
- spin_lock_init(&key->reasm_lock);
+ key->valid = true;
+ spin_lock_init(&key->lock);
+ refcount_set(&key->refs, 1);
return key;
}
+void mctp_key_unref(struct mctp_sk_key *key)
+{
+ if (refcount_dec_and_test(&key->refs))
+ kfree(key);
+}
+
static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
{
struct net *net = sock_net(&msk->sk);
@@ -138,12 +164,17 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
key->tag)) {
- rc = -EEXIST;
- break;
+ spin_lock(&tmp->lock);
+ if (tmp->valid)
+ rc = -EEXIST;
+ spin_unlock(&tmp->lock);
+ if (rc)
+ break;
}
}
if (!rc) {
+ refcount_inc(&key->refs);
hlist_add_head(&key->hlist, &net->mctp.keys);
hlist_add_head(&key->sklist, &msk->keys);
}
@@ -153,28 +184,35 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
return rc;
}
-/* Must be called with key->reasm_lock, which it will release. Will schedule
- * the key for an RCU free.
+/* We're done with the key; unset valid and remove from lists. There may still
+ * be outstanding refs on the key though...
*/
static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
unsigned long flags)
- __releases(&key->reasm_lock)
+ __releases(&key->lock)
{
struct sk_buff *skb;
skb = key->reasm_head;
key->reasm_head = NULL;
key->reasm_dead = true;
- spin_unlock_irqrestore(&key->reasm_lock, flags);
+ key->valid = false;
+ spin_unlock_irqrestore(&key->lock, flags);
spin_lock_irqsave(&net->mctp.keys_lock, flags);
- hlist_del_rcu(&key->hlist);
- hlist_del_rcu(&key->sklist);
+ hlist_del(&key->hlist);
+ hlist_del(&key->sklist);
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
- kfree_rcu(key, rcu);
+
+ /* one unref for the lists */
+ mctp_key_unref(key);
+
+ /* and one for the local reference */
+ mctp_key_unref(key);
if (skb)
kfree_skb(skb);
+
}
static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
@@ -248,8 +286,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
rcu_read_lock();
- /* lookup socket / reasm context, exactly matching (src,dest,tag) */
- key = mctp_lookup_key(net, skb, mh->src);
+ /* lookup socket / reasm context, exactly matching (src,dest,tag).
+ * we hold a ref on the key, and key->lock held.
+ */
+ key = mctp_lookup_key(net, skb, mh->src, &f);
if (flags & MCTP_HDR_FLAG_SOM) {
if (key) {
@@ -260,10 +300,12 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* key for reassembly - we'll create a more specific
* one for future packets if required (ie, !EOM).
*/
- key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
+ key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
if (key) {
msk = container_of(key->sk,
struct mctp_sock, sk);
+ spin_unlock_irqrestore(&key->lock, f);
+ mctp_key_unref(key);
key = NULL;
}
}
@@ -282,11 +324,11 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
if (flags & MCTP_HDR_FLAG_EOM) {
sock_queue_rcv_skb(&msk->sk, skb);
if (key) {
- spin_lock_irqsave(&key->reasm_lock, f);
/* we've hit a pending reassembly; not much we
* can do but drop it
*/
__mctp_key_unlock_drop(key, net, f);
+ key = NULL;
}
rc = 0;
goto out_unlock;
@@ -303,7 +345,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
goto out_unlock;
}
- /* we can queue without the reasm lock here, as the
+ /* we can queue without the key lock here, as the
* key isn't observable yet
*/
mctp_frag_queue(key, skb);
@@ -318,17 +360,17 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
if (rc)
kfree(key);
- } else {
- /* existing key: start reassembly */
- spin_lock_irqsave(&key->reasm_lock, f);
+ /* we don't need to release key->lock on exit */
+ key = NULL;
+ } else {
if (key->reasm_head || key->reasm_dead) {
/* duplicate start? drop everything */
__mctp_key_unlock_drop(key, net, f);
rc = -EEXIST;
+ key = NULL;
} else {
rc = mctp_frag_queue(key, skb);
- spin_unlock_irqrestore(&key->reasm_lock, f);
}
}
@@ -337,8 +379,6 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* using the message-specific key
*/
- spin_lock_irqsave(&key->reasm_lock, f);
-
/* we need to be continuing an existing reassembly... */
if (!key->reasm_head)
rc = -EINVAL;
@@ -352,8 +392,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
sock_queue_rcv_skb(key->sk, key->reasm_head);
key->reasm_head = NULL;
__mctp_key_unlock_drop(key, net, f);
- } else {
- spin_unlock_irqrestore(&key->reasm_lock, f);
+ key = NULL;
}
} else {
@@ -363,6 +402,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
out_unlock:
rcu_read_unlock();
+ if (key) {
+ spin_unlock_irqrestore(&key->lock, f);
+ mctp_key_unref(key);
+ }
out:
if (rc)
kfree_skb(skb);
@@ -459,6 +502,7 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
*/
hlist_add_head_rcu(&key->hlist, &mns->keys);
hlist_add_head_rcu(&key->sklist, &msk->keys);
+ refcount_inc(&key->refs);
}
/* Allocate a locally-owned tag value for (saddr, daddr), and reserve
@@ -492,14 +536,26 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk,
* tags. If we find a conflict, clear that bit from tagbits
*/
hlist_for_each_entry(tmp, &mns->keys, hlist) {
+ /* We can check the lookup fields (*_addr, tag) without the
+ * lock held, they don't change over the lifetime of the key.
+ */
+
/* if we don't own the tag, it can't conflict */
if (tmp->tag & MCTP_HDR_FLAG_TO)
continue;
- if ((tmp->peer_addr == daddr ||
- tmp->peer_addr == MCTP_ADDR_ANY) &&
- tmp->local_addr == saddr)
+ if (!((tmp->peer_addr == daddr ||
+ tmp->peer_addr == MCTP_ADDR_ANY) &&
+ tmp->local_addr == saddr))
+ continue;
+
+ spin_lock(&tmp->lock);
+ /* key must still be valid. If we find a match, clear the
+ * potential tag value
+ */
+ if (tmp->valid)
tagbits &= ~(1 << tmp->tag);
+ spin_unlock(&tmp->lock);
if (!tagbits)
break;