diff options
Diffstat (limited to 'net/rxrpc/ar-call.c')
-rw-r--r-- | net/rxrpc/ar-call.c | 213 |
1 files changed, 205 insertions, 8 deletions
diff --git a/net/rxrpc/ar-call.c b/net/rxrpc/ar-call.c index a3bbb360a3f9..a9e05db0f5d5 100644 --- a/net/rxrpc/ar-call.c +++ b/net/rxrpc/ar-call.c @@ -12,10 +12,22 @@ #include <linux/slab.h> #include <linux/module.h> #include <linux/circ_buf.h> +#include <linux/hashtable.h> +#include <linux/spinlock_types.h> #include <net/sock.h> #include <net/af_rxrpc.h> #include "ar-internal.h" +/* + * Maximum lifetime of a call (in jiffies). + */ +unsigned rxrpc_max_call_lifetime = 60 * HZ; + +/* + * Time till dead call expires after last use (in jiffies). + */ +unsigned rxrpc_dead_call_expiry = 2 * HZ; + const char *const rxrpc_call_states[] = { [RXRPC_CALL_CLIENT_SEND_REQUEST] = "ClSndReq", [RXRPC_CALL_CLIENT_AWAIT_REPLY] = "ClAwtRpl", @@ -38,8 +50,6 @@ const char *const rxrpc_call_states[] = { struct kmem_cache *rxrpc_call_jar; LIST_HEAD(rxrpc_calls); DEFINE_RWLOCK(rxrpc_call_lock); -static unsigned int rxrpc_call_max_lifetime = 60; -static unsigned int rxrpc_dead_call_timeout = 2; static void rxrpc_destroy_call(struct work_struct *work); static void rxrpc_call_life_expired(unsigned long _call); @@ -47,6 +57,145 @@ static void rxrpc_dead_call_expired(unsigned long _call); static void rxrpc_ack_time_expired(unsigned long _call); static void rxrpc_resend_time_expired(unsigned long _call); +static DEFINE_SPINLOCK(rxrpc_call_hash_lock); +static DEFINE_HASHTABLE(rxrpc_call_hash, 10); + +/* + * Hash function for rxrpc_call_hash + */ +static unsigned long rxrpc_call_hashfunc( + u8 clientflag, + __be32 cid, + __be32 call_id, + __be32 epoch, + __be16 service_id, + sa_family_t proto, + void *localptr, + unsigned int addr_size, + const u8 *peer_addr) +{ + const u16 *p; + unsigned int i; + unsigned long key; + u32 hcid = ntohl(cid); + + _enter(""); + + key = (unsigned long)localptr; + /* We just want to add up the __be32 values, so forcing the + * cast should be okay. + */ + key += (__force u32)epoch; + key += (__force u16)service_id; + key += (__force u32)call_id; + key += (hcid & RXRPC_CIDMASK) >> RXRPC_CIDSHIFT; + key += hcid & RXRPC_CHANNELMASK; + key += clientflag; + key += proto; + /* Step through the peer address in 16-bit portions for speed */ + for (i = 0, p = (const u16 *)peer_addr; i < addr_size >> 1; i++, p++) + key += *p; + _leave(" key = 0x%lx", key); + return key; +} + +/* + * Add a call to the hashtable + */ +static void rxrpc_call_hash_add(struct rxrpc_call *call) +{ + unsigned long key; + unsigned int addr_size = 0; + + _enter(""); + switch (call->proto) { + case AF_INET: + addr_size = sizeof(call->peer_ip.ipv4_addr); + break; + case AF_INET6: + addr_size = sizeof(call->peer_ip.ipv6_addr); + break; + default: + break; + } + key = rxrpc_call_hashfunc(call->in_clientflag, call->cid, + call->call_id, call->epoch, + call->service_id, call->proto, + call->conn->trans->local, addr_size, + call->peer_ip.ipv6_addr); + /* Store the full key in the call */ + call->hash_key = key; + spin_lock(&rxrpc_call_hash_lock); + hash_add_rcu(rxrpc_call_hash, &call->hash_node, key); + spin_unlock(&rxrpc_call_hash_lock); + _leave(""); +} + +/* + * Remove a call from the hashtable + */ +static void rxrpc_call_hash_del(struct rxrpc_call *call) +{ + _enter(""); + spin_lock(&rxrpc_call_hash_lock); + hash_del_rcu(&call->hash_node); + spin_unlock(&rxrpc_call_hash_lock); + _leave(""); +} + +/* + * Find a call in the hashtable and return it, or NULL if it + * isn't there. + */ +struct rxrpc_call *rxrpc_find_call_hash( + u8 clientflag, + __be32 cid, + __be32 call_id, + __be32 epoch, + __be16 service_id, + void *localptr, + sa_family_t proto, + const u8 *peer_addr) +{ + unsigned long key; + unsigned int addr_size = 0; + struct rxrpc_call *call = NULL; + struct rxrpc_call *ret = NULL; + + _enter(""); + switch (proto) { + case AF_INET: + addr_size = sizeof(call->peer_ip.ipv4_addr); + break; + case AF_INET6: + addr_size = sizeof(call->peer_ip.ipv6_addr); + break; + default: + break; + } + + key = rxrpc_call_hashfunc(clientflag, cid, call_id, epoch, + service_id, proto, localptr, addr_size, + peer_addr); + hash_for_each_possible_rcu(rxrpc_call_hash, call, hash_node, key) { + if (call->hash_key == key && + call->call_id == call_id && + call->cid == cid && + call->in_clientflag == clientflag && + call->service_id == service_id && + call->proto == proto && + call->local == localptr && + memcmp(call->peer_ip.ipv6_addr, peer_addr, + addr_size) == 0 && + call->epoch == epoch) { + ret = call; + break; + } + } + _leave(" = %p", ret); + return ret; +} + /* * allocate a new call */ @@ -91,7 +240,7 @@ static struct rxrpc_call *rxrpc_alloc_call(gfp_t gfp) call->rx_data_expect = 1; call->rx_data_eaten = 0; call->rx_first_oos = 0; - call->ackr_win_top = call->rx_data_eaten + 1 + RXRPC_MAXACKS; + call->ackr_win_top = call->rx_data_eaten + 1 + rxrpc_rx_window_size; call->creation_jif = jiffies; return call; } @@ -128,11 +277,31 @@ static struct rxrpc_call *rxrpc_alloc_client_call( return ERR_PTR(ret); } + /* Record copies of information for hashtable lookup */ + call->proto = rx->proto; + call->local = trans->local; + switch (call->proto) { + case AF_INET: + call->peer_ip.ipv4_addr = + trans->peer->srx.transport.sin.sin_addr.s_addr; + break; + case AF_INET6: + memcpy(call->peer_ip.ipv6_addr, + trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8, + sizeof(call->peer_ip.ipv6_addr)); + break; + } + call->epoch = call->conn->epoch; + call->service_id = call->conn->service_id; + call->in_clientflag = call->conn->in_clientflag; + /* Add the new call to the hashtable */ + rxrpc_call_hash_add(call); + spin_lock(&call->conn->trans->peer->lock); list_add(&call->error_link, &call->conn->trans->peer->error_targets); spin_unlock(&call->conn->trans->peer->lock); - call->lifetimer.expires = jiffies + rxrpc_call_max_lifetime * HZ; + call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime; add_timer(&call->lifetimer); _leave(" = %p", call); @@ -320,9 +489,12 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx, parent = *p; call = rb_entry(parent, struct rxrpc_call, conn_node); - if (call_id < call->call_id) + /* The tree is sorted in order of the __be32 value without + * turning it into host order. + */ + if ((__force u32)call_id < (__force u32)call->call_id) p = &(*p)->rb_left; - else if (call_id > call->call_id) + else if ((__force u32)call_id > (__force u32)call->call_id) p = &(*p)->rb_right; else goto old_call; @@ -347,9 +519,31 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx, list_add_tail(&call->link, &rxrpc_calls); write_unlock_bh(&rxrpc_call_lock); + /* Record copies of information for hashtable lookup */ + call->proto = rx->proto; + call->local = conn->trans->local; + switch (call->proto) { + case AF_INET: + call->peer_ip.ipv4_addr = + conn->trans->peer->srx.transport.sin.sin_addr.s_addr; + break; + case AF_INET6: + memcpy(call->peer_ip.ipv6_addr, + conn->trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8, + sizeof(call->peer_ip.ipv6_addr)); + break; + default: + break; + } + call->epoch = conn->epoch; + call->service_id = conn->service_id; + call->in_clientflag = conn->in_clientflag; + /* Add the new call to the hashtable */ + rxrpc_call_hash_add(call); + _net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id); - call->lifetimer.expires = jiffies + rxrpc_call_max_lifetime * HZ; + call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime; add_timer(&call->lifetimer); _leave(" = %p {%d} [new]", call, call->debug_id); return call; @@ -533,7 +727,7 @@ void rxrpc_release_call(struct rxrpc_call *call) del_timer_sync(&call->resend_timer); del_timer_sync(&call->ack_timer); del_timer_sync(&call->lifetimer); - call->deadspan.expires = jiffies + rxrpc_dead_call_timeout * HZ; + call->deadspan.expires = jiffies + rxrpc_dead_call_expiry; add_timer(&call->deadspan); _leave(""); @@ -665,6 +859,9 @@ static void rxrpc_cleanup_call(struct rxrpc_call *call) rxrpc_put_connection(call->conn); } + /* Remove the call from the hash */ + rxrpc_call_hash_del(call); + if (call->acks_window) { _debug("kill Tx window %d", CIRC_CNT(call->acks_head, call->acks_tail, |