summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorJoanne Koong <joannelkoong@gmail.com>2022-05-19 17:18:33 -0700
committerJakub Kicinski <kuba@kernel.org>2022-05-20 18:16:24 -0700
commitd5a42de8bdbe25081f07b801d8b35f4d75a791f4 (patch)
tree7bc53e45ea3628a75c25931f9f1d5c7f64d63cd5 /net
parenteac67d83bf2553f98cfddd42cfbcfd6f9ccfc287 (diff)
downloadlinux-d5a42de8bdbe25081f07b801d8b35f4d75a791f4.tar.bz2
net: Add a second bind table hashed by port and address
We currently have one tcp bind table (bhash) which hashes by port number only. In the socket bind path, we check for bind conflicts by traversing the specified port's inet_bind2_bucket while holding the bucket's spinlock (see inet_csk_get_port() and inet_csk_bind_conflict()). In instances where there are tons of sockets hashed to the same port at different addresses, checking for a bind conflict is time-intensive and can cause softirq cpu lockups, as well as stops new tcp connections since __inet_inherit_port() also contests for the spinlock. This patch proposes adding a second bind table, bhash2, that hashes by port and ip address. Searching the bhash2 table leads to significantly faster conflict resolution and less time holding the spinlock. Signed-off-by: Joanne Koong <joannelkoong@gmail.com> Reviewed-by: Eric Dumazet <edumazet@google.com> Acked-by: Kuniyuki Iwashima <kuniyu@amazon.co.jp> Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Diffstat (limited to 'net')
-rw-r--r--net/dccp/proto.c33
-rw-r--r--net/ipv4/inet_connection_sock.c247
-rw-r--r--net/ipv4/inet_hashtables.c193
-rw-r--r--net/ipv4/tcp.c14
4 files changed, 405 insertions, 82 deletions
diff --git a/net/dccp/proto.c b/net/dccp/proto.c
index eb8e128e43e8..2e78458900f2 100644
--- a/net/dccp/proto.c
+++ b/net/dccp/proto.c
@@ -1120,6 +1120,12 @@ static int __init dccp_init(void)
SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, NULL);
if (!dccp_hashinfo.bind_bucket_cachep)
goto out_free_hashinfo2;
+ dccp_hashinfo.bind2_bucket_cachep =
+ kmem_cache_create("dccp_bind2_bucket",
+ sizeof(struct inet_bind2_bucket), 0,
+ SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, NULL);
+ if (!dccp_hashinfo.bind2_bucket_cachep)
+ goto out_free_bind_bucket_cachep;
/*
* Size and allocate the main established and bind bucket
@@ -1150,7 +1156,7 @@ static int __init dccp_init(void)
if (!dccp_hashinfo.ehash) {
DCCP_CRIT("Failed to allocate DCCP established hash table");
- goto out_free_bind_bucket_cachep;
+ goto out_free_bind2_bucket_cachep;
}
for (i = 0; i <= dccp_hashinfo.ehash_mask; i++)
@@ -1176,14 +1182,23 @@ static int __init dccp_init(void)
goto out_free_dccp_locks;
}
+ dccp_hashinfo.bhash2 = (struct inet_bind2_hashbucket *)
+ __get_free_pages(GFP_ATOMIC | __GFP_NOWARN, bhash_order);
+
+ if (!dccp_hashinfo.bhash2) {
+ DCCP_CRIT("Failed to allocate DCCP bind2 hash table");
+ goto out_free_dccp_bhash;
+ }
+
for (i = 0; i < dccp_hashinfo.bhash_size; i++) {
spin_lock_init(&dccp_hashinfo.bhash[i].lock);
INIT_HLIST_HEAD(&dccp_hashinfo.bhash[i].chain);
+ INIT_HLIST_HEAD(&dccp_hashinfo.bhash2[i].chain);
}
rc = dccp_mib_init();
if (rc)
- goto out_free_dccp_bhash;
+ goto out_free_dccp_bhash2;
rc = dccp_ackvec_init();
if (rc)
@@ -1207,30 +1222,38 @@ out_ackvec_exit:
dccp_ackvec_exit();
out_free_dccp_mib:
dccp_mib_exit();
+out_free_dccp_bhash2:
+ free_pages((unsigned long)dccp_hashinfo.bhash2, bhash_order);
out_free_dccp_bhash:
free_pages((unsigned long)dccp_hashinfo.bhash, bhash_order);
out_free_dccp_locks:
inet_ehash_locks_free(&dccp_hashinfo);
out_free_dccp_ehash:
free_pages((unsigned long)dccp_hashinfo.ehash, ehash_order);
+out_free_bind2_bucket_cachep:
+ kmem_cache_destroy(dccp_hashinfo.bind2_bucket_cachep);
out_free_bind_bucket_cachep:
kmem_cache_destroy(dccp_hashinfo.bind_bucket_cachep);
out_free_hashinfo2:
inet_hashinfo2_free_mod(&dccp_hashinfo);
out_fail:
dccp_hashinfo.bhash = NULL;
+ dccp_hashinfo.bhash2 = NULL;
dccp_hashinfo.ehash = NULL;
dccp_hashinfo.bind_bucket_cachep = NULL;
+ dccp_hashinfo.bind2_bucket_cachep = NULL;
return rc;
}
static void __exit dccp_fini(void)
{
+ int bhash_order = get_order(dccp_hashinfo.bhash_size *
+ sizeof(struct inet_bind_hashbucket));
+
ccid_cleanup_builtins();
dccp_mib_exit();
- free_pages((unsigned long)dccp_hashinfo.bhash,
- get_order(dccp_hashinfo.bhash_size *
- sizeof(struct inet_bind_hashbucket)));
+ free_pages((unsigned long)dccp_hashinfo.bhash, bhash_order);
+ free_pages((unsigned long)dccp_hashinfo.bhash2, bhash_order);
free_pages((unsigned long)dccp_hashinfo.ehash,
get_order((dccp_hashinfo.ehash_mask + 1) *
sizeof(struct inet_ehash_bucket)));
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index 53f5f956d948..c0b7e6c21360 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -117,6 +117,32 @@ bool inet_rcv_saddr_any(const struct sock *sk)
return !sk->sk_rcv_saddr;
}
+static bool use_bhash2_on_bind(const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ int addr_type;
+
+ if (sk->sk_family == AF_INET6) {
+ addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
+ return addr_type != IPV6_ADDR_ANY &&
+ addr_type != IPV6_ADDR_MAPPED;
+ }
+#endif
+ return sk->sk_rcv_saddr != htonl(INADDR_ANY);
+}
+
+static u32 get_bhash2_nulladdr_hash(const struct sock *sk, struct net *net,
+ int port)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ struct in6_addr nulladdr = {};
+
+ if (sk->sk_family == AF_INET6)
+ return ipv6_portaddr_hash(net, &nulladdr, port);
+#endif
+ return ipv4_portaddr_hash(net, 0, port);
+}
+
void inet_get_local_port_range(struct net *net, int *low, int *high)
{
unsigned int seq;
@@ -130,16 +156,71 @@ void inet_get_local_port_range(struct net *net, int *low, int *high)
}
EXPORT_SYMBOL(inet_get_local_port_range);
-static int inet_csk_bind_conflict(const struct sock *sk,
- const struct inet_bind_bucket *tb,
- bool relax, bool reuseport_ok)
+static bool bind_conflict_exist(const struct sock *sk, struct sock *sk2,
+ kuid_t sk_uid, bool relax,
+ bool reuseport_cb_ok, bool reuseport_ok)
+{
+ int bound_dev_if2;
+
+ if (sk == sk2)
+ return false;
+
+ bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
+
+ if (!sk->sk_bound_dev_if || !bound_dev_if2 ||
+ sk->sk_bound_dev_if == bound_dev_if2) {
+ if (sk->sk_reuse && sk2->sk_reuse &&
+ sk2->sk_state != TCP_LISTEN) {
+ if (!relax || (!reuseport_ok && sk->sk_reuseport &&
+ sk2->sk_reuseport && reuseport_cb_ok &&
+ (sk2->sk_state == TCP_TIME_WAIT ||
+ uid_eq(sk_uid, sock_i_uid(sk2)))))
+ return true;
+ } else if (!reuseport_ok || !sk->sk_reuseport ||
+ !sk2->sk_reuseport || !reuseport_cb_ok ||
+ (sk2->sk_state != TCP_TIME_WAIT &&
+ !uid_eq(sk_uid, sock_i_uid(sk2)))) {
+ return true;
+ }
+ }
+ return false;
+}
+
+static bool check_bhash2_conflict(const struct sock *sk,
+ struct inet_bind2_bucket *tb2, kuid_t sk_uid,
+ bool relax, bool reuseport_cb_ok,
+ bool reuseport_ok)
{
struct sock *sk2;
- bool reuseport_cb_ok;
- bool reuse = sk->sk_reuse;
- bool reuseport = !!sk->sk_reuseport;
- struct sock_reuseport *reuseport_cb;
+
+ sk_for_each_bound_bhash2(sk2, &tb2->owners) {
+ if (sk->sk_family == AF_INET && ipv6_only_sock(sk2))
+ continue;
+
+ if (bind_conflict_exist(sk, sk2, sk_uid, relax,
+ reuseport_cb_ok, reuseport_ok))
+ return true;
+ }
+ return false;
+}
+
+/* This should be called only when the corresponding inet_bind_bucket spinlock
+ * is held
+ */
+static int inet_csk_bind_conflict(const struct sock *sk, int port,
+ struct inet_bind_bucket *tb,
+ struct inet_bind2_bucket *tb2, /* may be null */
+ bool relax, bool reuseport_ok)
+{
+ struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
kuid_t uid = sock_i_uid((struct sock *)sk);
+ struct sock_reuseport *reuseport_cb;
+ struct inet_bind2_hashbucket *head2;
+ bool reuseport_cb_ok;
+ struct sock *sk2;
+ struct net *net;
+ int l3mdev;
+ u32 hash;
rcu_read_lock();
reuseport_cb = rcu_dereference(sk->sk_reuseport_cb);
@@ -150,40 +231,42 @@ static int inet_csk_bind_conflict(const struct sock *sk,
/*
* Unlike other sk lookup places we do not check
* for sk_net here, since _all_ the socks listed
- * in tb->owners list belong to the same net - the
- * one this bucket belongs to.
+ * in tb->owners and tb2->owners list belong
+ * to the same net
*/
- sk_for_each_bound(sk2, &tb->owners) {
- int bound_dev_if2;
+ if (!use_bhash2_on_bind(sk)) {
+ sk_for_each_bound(sk2, &tb->owners)
+ if (bind_conflict_exist(sk, sk2, uid, relax,
+ reuseport_cb_ok, reuseport_ok) &&
+ inet_rcv_saddr_equal(sk, sk2, true))
+ return true;
- if (sk == sk2)
- continue;
- bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
- if ((!sk->sk_bound_dev_if ||
- !bound_dev_if2 ||
- sk->sk_bound_dev_if == bound_dev_if2)) {
- if (reuse && sk2->sk_reuse &&
- sk2->sk_state != TCP_LISTEN) {
- if ((!relax ||
- (!reuseport_ok &&
- reuseport && sk2->sk_reuseport &&
- reuseport_cb_ok &&
- (sk2->sk_state == TCP_TIME_WAIT ||
- uid_eq(uid, sock_i_uid(sk2))))) &&
- inet_rcv_saddr_equal(sk, sk2, true))
- break;
- } else if (!reuseport_ok ||
- !reuseport || !sk2->sk_reuseport ||
- !reuseport_cb_ok ||
- (sk2->sk_state != TCP_TIME_WAIT &&
- !uid_eq(uid, sock_i_uid(sk2)))) {
- if (inet_rcv_saddr_equal(sk, sk2, true))
- break;
- }
- }
+ return false;
}
- return sk2 != NULL;
+
+ if (tb2 && check_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+ reuseport_ok))
+ return true;
+
+ net = sock_net(sk);
+
+ /* check there's no conflict with an existing IPV6_ADDR_ANY (if ipv6) or
+ * INADDR_ANY (if ipv4) socket.
+ */
+ hash = get_bhash2_nulladdr_hash(sk, net, port);
+ head2 = &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
+
+ l3mdev = inet_sk_bound_l3mdev(sk);
+ inet_bind_bucket_for_each(tb2, &head2->chain)
+ if (check_bind2_bucket_match_nulladdr(tb2, net, port, l3mdev, sk))
+ break;
+
+ if (tb2 && check_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+ reuseport_ok))
+ return true;
+
+ return false;
}
/*
@@ -191,16 +274,20 @@ static int inet_csk_bind_conflict(const struct sock *sk,
* inet_bind_hashbucket lock held.
*/
static struct inet_bind_hashbucket *
-inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *port_ret)
+inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret,
+ struct inet_bind2_bucket **tb2_ret,
+ struct inet_bind2_hashbucket **head2_ret, int *port_ret)
{
struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
- int port = 0;
+ struct inet_bind2_hashbucket *head2;
struct inet_bind_hashbucket *head;
struct net *net = sock_net(sk);
- bool relax = false;
int i, low, high, attempt_half;
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
u32 remaining, offset;
+ bool relax = false;
+ int port = 0;
int l3mdev;
l3mdev = inet_sk_bound_l3mdev(sk);
@@ -239,10 +326,12 @@ other_parity_scan:
head = &hinfo->bhash[inet_bhashfn(net, port,
hinfo->bhash_size)];
spin_lock_bh(&head->lock);
+ tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
+ &head2);
inet_bind_bucket_for_each(tb, &head->chain)
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port) {
- if (!inet_csk_bind_conflict(sk, tb, relax, false))
+ if (check_bind_bucket_match(tb, net, port, l3mdev)) {
+ if (!inet_csk_bind_conflict(sk, port, tb, tb2,
+ relax, false))
goto success;
goto next_port;
}
@@ -272,6 +361,8 @@ next_port:
success:
*port_ret = port;
*tb_ret = tb;
+ *tb2_ret = tb2;
+ *head2_ret = head2;
return head;
}
@@ -367,54 +458,81 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
{
bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN;
struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
- int ret = 1, port = snum;
+ bool bhash_created = false, bhash2_created = false;
+ struct inet_bind2_bucket *tb2 = NULL;
+ struct inet_bind2_hashbucket *head2;
+ struct inet_bind_bucket *tb = NULL;
struct inet_bind_hashbucket *head;
struct net *net = sock_net(sk);
- struct inet_bind_bucket *tb = NULL;
+ int ret = 1, port = snum;
+ bool found_port = false;
int l3mdev;
l3mdev = inet_sk_bound_l3mdev(sk);
if (!port) {
- head = inet_csk_find_open_port(sk, &tb, &port);
+ head = inet_csk_find_open_port(sk, &tb, &tb2, &head2, &port);
if (!head)
return ret;
+ if (tb && tb2)
+ goto success;
+ found_port = true;
+ } else {
+ head = &hinfo->bhash[inet_bhashfn(net, port,
+ hinfo->bhash_size)];
+ spin_lock_bh(&head->lock);
+ inet_bind_bucket_for_each(tb, &head->chain)
+ if (check_bind_bucket_match(tb, net, port, l3mdev))
+ break;
+
+ tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk,
+ &head2);
+ }
+
+ if (!tb) {
+ tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, net,
+ head, port, l3mdev);
if (!tb)
- goto tb_not_found;
- goto success;
+ goto fail_unlock;
+ bhash_created = true;
}
- head = &hinfo->bhash[inet_bhashfn(net, port,
- hinfo->bhash_size)];
- spin_lock_bh(&head->lock);
- inet_bind_bucket_for_each(tb, &head->chain)
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port)
- goto tb_found;
-tb_not_found:
- tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
- net, head, port, l3mdev);
- if (!tb)
- goto fail_unlock;
-tb_found:
- if (!hlist_empty(&tb->owners)) {
+
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
+ net, head2, port, l3mdev, sk);
+ if (!tb2)
+ goto fail_unlock;
+ bhash2_created = true;
+ }
+
+ /* If we had to find an open port, we already checked for conflicts */
+ if (!found_port && !hlist_empty(&tb->owners)) {
if (sk->sk_reuse == SK_FORCE_REUSE)
goto success;
if ((tb->fastreuse > 0 && reuse) ||
sk_reuseport_match(tb, sk))
goto success;
- if (inet_csk_bind_conflict(sk, tb, true, true))
+ if (inet_csk_bind_conflict(sk, port, tb, tb2, true, true))
goto fail_unlock;
}
success:
inet_csk_update_fastreuse(tb, sk);
if (!inet_csk(sk)->icsk_bind_hash)
- inet_bind_hash(sk, tb, port);
+ inet_bind_hash(sk, tb, tb2, port);
WARN_ON(inet_csk(sk)->icsk_bind_hash != tb);
+ WARN_ON(inet_csk(sk)->icsk_bind2_hash != tb2);
ret = 0;
fail_unlock:
+ if (ret) {
+ if (bhash_created)
+ inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+ if (bhash2_created)
+ inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+ tb2);
+ }
spin_unlock_bh(&head->lock);
return ret;
}
@@ -961,6 +1079,7 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
inet_sk_set_state(newsk, TCP_SYN_RECV);
newicsk->icsk_bind_hash = NULL;
+ newicsk->icsk_bind2_hash = NULL;
inet_sk(newsk)->inet_dport = inet_rsk(req)->ir_rmt_port;
inet_sk(newsk)->inet_num = inet_rsk(req)->ir_num;
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 87354e20009a..e8de5e699b3f 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -81,6 +81,41 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
return tb;
}
+struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
+ struct net *net,
+ struct inet_bind2_hashbucket *head,
+ const unsigned short port,
+ int l3mdev,
+ const struct sock *sk)
+{
+ struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
+
+ if (tb) {
+ write_pnet(&tb->ib_net, net);
+ tb->l3mdev = l3mdev;
+ tb->port = port;
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+ else
+#endif
+ tb->rcv_saddr = sk->sk_rcv_saddr;
+ INIT_HLIST_HEAD(&tb->owners);
+ hlist_add_head(&tb->node, &head->chain);
+ }
+ return tb;
+}
+
+static bool bind2_bucket_addr_match(struct inet_bind2_bucket *tb2, struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ return ipv6_addr_equal(&tb2->v6_rcv_saddr,
+ &sk->sk_v6_rcv_saddr);
+#endif
+ return tb2->rcv_saddr == sk->sk_rcv_saddr;
+}
+
/*
* Caller must hold hashbucket lock for this tb with local BH disabled
*/
@@ -92,12 +127,25 @@ void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket
}
}
+/* Caller must hold the lock for the corresponding hashbucket in the bhash table
+ * with local BH disabled
+ */
+void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
+{
+ if (hlist_empty(&tb->owners)) {
+ __hlist_del(&tb->node);
+ kmem_cache_free(cachep, tb);
+ }
+}
+
void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
- const unsigned short snum)
+ struct inet_bind2_bucket *tb2, const unsigned short snum)
{
inet_sk(sk)->inet_num = snum;
sk_add_bind_node(sk, &tb->owners);
inet_csk(sk)->icsk_bind_hash = tb;
+ sk_add_bind2_node(sk, &tb2->owners);
+ inet_csk(sk)->icsk_bind2_hash = tb2;
}
/*
@@ -109,6 +157,7 @@ static void __inet_put_port(struct sock *sk)
const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
hashinfo->bhash_size);
struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
spin_lock(&head->lock);
@@ -117,6 +166,13 @@ static void __inet_put_port(struct sock *sk)
inet_csk(sk)->icsk_bind_hash = NULL;
inet_sk(sk)->inet_num = 0;
inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
+
+ if (inet_csk(sk)->icsk_bind2_hash) {
+ tb2 = inet_csk(sk)->icsk_bind2_hash;
+ __sk_del_bind2_node(sk);
+ inet_csk(sk)->icsk_bind2_hash = NULL;
+ inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
+ }
spin_unlock(&head->lock);
}
@@ -133,14 +189,19 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
struct inet_hashinfo *table = sk->sk_prot->h.hashinfo;
unsigned short port = inet_sk(child)->inet_num;
const int bhash = inet_bhashfn(sock_net(sk), port,
- table->bhash_size);
+ table->bhash_size);
struct inet_bind_hashbucket *head = &table->bhash[bhash];
+ struct inet_bind2_hashbucket *head_bhash2;
+ bool created_inet_bind_bucket = false;
+ struct net *net = sock_net(sk);
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
int l3mdev;
spin_lock(&head->lock);
tb = inet_csk(sk)->icsk_bind_hash;
- if (unlikely(!tb)) {
+ tb2 = inet_csk(sk)->icsk_bind2_hash;
+ if (unlikely(!tb || !tb2)) {
spin_unlock(&head->lock);
return -ENOENT;
}
@@ -153,25 +214,45 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
* as that of the child socket. We have to look up or
* create a new bind bucket for the child here. */
inet_bind_bucket_for_each(tb, &head->chain) {
- if (net_eq(ib_net(tb), sock_net(sk)) &&
- tb->l3mdev == l3mdev && tb->port == port)
+ if (check_bind_bucket_match(tb, net, port, l3mdev))
break;
}
if (!tb) {
tb = inet_bind_bucket_create(table->bind_bucket_cachep,
- sock_net(sk), head, port,
- l3mdev);
+ net, head, port, l3mdev);
if (!tb) {
spin_unlock(&head->lock);
return -ENOMEM;
}
+ created_inet_bind_bucket = true;
}
inet_csk_update_fastreuse(tb, child);
+
+ goto bhash2_find;
+ } else if (!bind2_bucket_addr_match(tb2, child)) {
+ l3mdev = inet_sk_bound_l3mdev(sk);
+
+bhash2_find:
+ tb2 = inet_bind2_bucket_find(table, net, port, l3mdev, child,
+ &head_bhash2);
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
+ net, head_bhash2, port,
+ l3mdev, child);
+ if (!tb2)
+ goto error;
+ }
}
- inet_bind_hash(child, tb, port);
+ inet_bind_hash(child, tb, tb2, port);
spin_unlock(&head->lock);
return 0;
+
+error:
+ if (created_inet_bind_bucket)
+ inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
+ spin_unlock(&head->lock);
+ return -ENOMEM;
}
EXPORT_SYMBOL_GPL(__inet_inherit_port);
@@ -675,6 +756,76 @@ void inet_unhash(struct sock *sk)
}
EXPORT_SYMBOL_GPL(inet_unhash);
+static bool check_bind2_bucket_match(struct inet_bind2_bucket *tb,
+ struct net *net, unsigned short port,
+ int l3mdev, struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev &&
+ ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr);
+ else
+#endif
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
+}
+
+bool check_bind2_bucket_match_nulladdr(struct inet_bind2_bucket *tb,
+ struct net *net, const unsigned short port,
+ int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ struct in6_addr nulladdr = {};
+
+ if (sk->sk_family == AF_INET6)
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev &&
+ ipv6_addr_equal(&tb->v6_rcv_saddr, &nulladdr);
+ else
+#endif
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
+}
+
+static struct inet_bind2_hashbucket *
+inet_bhashfn_portaddr(struct inet_hashinfo *hinfo, const struct sock *sk,
+ const struct net *net, unsigned short port)
+{
+ u32 hash;
+
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ hash = ipv6_portaddr_hash(net, &sk->sk_v6_rcv_saddr, port);
+ else
+#endif
+ hash = ipv4_portaddr_hash(net, sk->sk_rcv_saddr, port);
+ return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
+}
+
+/* This should only be called when the spinlock for the socket's corresponding
+ * bind_hashbucket is held
+ */
+struct inet_bind2_bucket *
+inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
+ const unsigned short port, int l3mdev, struct sock *sk,
+ struct inet_bind2_hashbucket **head)
+{
+ struct inet_bind2_bucket *bhash2 = NULL;
+ struct inet_bind2_hashbucket *h;
+
+ h = inet_bhashfn_portaddr(hinfo, sk, net, port);
+ inet_bind_bucket_for_each(bhash2, &h->chain) {
+ if (check_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
+ break;
+ }
+
+ if (head)
+ *head = h;
+
+ return bhash2;
+}
+
/* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm
* Note that we use 32bit integers (vs RFC 'short integers')
* because 2^16 is not a multiple of num_ephemeral and this
@@ -695,10 +846,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
{
struct inet_hashinfo *hinfo = death_row->hashinfo;
struct inet_timewait_sock *tw = NULL;
+ struct inet_bind2_hashbucket *head2;
struct inet_bind_hashbucket *head;
int port = inet_sk(sk)->inet_num;
struct net *net = sock_net(sk);
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
+ bool tb_created = false;
u32 remaining, offset;
int ret, i, low, high;
int l3mdev;
@@ -755,8 +909,7 @@ other_parity_scan:
* the established check is already unique enough.
*/
inet_bind_bucket_for_each(tb, &head->chain) {
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port) {
+ if (check_bind_bucket_match(tb, net, port, l3mdev)) {
if (tb->fastreuse >= 0 ||
tb->fastreuseport >= 0)
goto next_port;
@@ -774,6 +927,7 @@ other_parity_scan:
spin_unlock_bh(&head->lock);
return -ENOMEM;
}
+ tb_created = true;
tb->fastreuse = -1;
tb->fastreuseport = -1;
goto ok;
@@ -789,6 +943,17 @@ next_port:
return -EADDRNOTAVAIL;
ok:
+ /* Find the corresponding tb2 bucket since we need to
+ * add the socket to the bhash2 table as well
+ */
+ tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk, &head2);
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
+ head2, port, l3mdev, sk);
+ if (!tb2)
+ goto error;
+ }
+
/* Here we want to add a little bit of randomness to the next source
* port that will be chosen. We use a max() with a random here so that
* on low contention the randomness is maximal and on high contention
@@ -798,7 +963,7 @@ ok:
WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2);
/* Head lock still held and bh's disabled */
- inet_bind_hash(sk, tb, port);
+ inet_bind_hash(sk, tb, tb2, port);
if (sk_unhashed(sk)) {
inet_sk(sk)->inet_sport = htons(port);
inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
@@ -810,6 +975,12 @@ ok:
inet_twsk_deschedule_put(tw);
local_bh_enable();
return 0;
+
+error:
+ if (tb_created)
+ inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+ spin_unlock_bh(&head->lock);
+ return -ENOMEM;
}
/*
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 028513d3e2a2..9984d23a7f3e 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -4604,6 +4604,12 @@ void __init tcp_init(void)
SLAB_HWCACHE_ALIGN | SLAB_PANIC |
SLAB_ACCOUNT,
NULL);
+ tcp_hashinfo.bind2_bucket_cachep =
+ kmem_cache_create("tcp_bind2_bucket",
+ sizeof(struct inet_bind2_bucket), 0,
+ SLAB_HWCACHE_ALIGN | SLAB_PANIC |
+ SLAB_ACCOUNT,
+ NULL);
/* Size and allocate the main established and bind bucket
* hash tables.
@@ -4626,8 +4632,9 @@ void __init tcp_init(void)
if (inet_ehash_locks_alloc(&tcp_hashinfo))
panic("TCP: failed to alloc ehash_locks");
tcp_hashinfo.bhash =
- alloc_large_system_hash("TCP bind",
- sizeof(struct inet_bind_hashbucket),
+ alloc_large_system_hash("TCP bind bhash tables",
+ sizeof(struct inet_bind_hashbucket) +
+ sizeof(struct inet_bind2_hashbucket),
tcp_hashinfo.ehash_mask + 1,
17, /* one slot per 128 KB of memory */
0,
@@ -4636,9 +4643,12 @@ void __init tcp_init(void)
0,
64 * 1024);
tcp_hashinfo.bhash_size = 1U << tcp_hashinfo.bhash_size;
+ tcp_hashinfo.bhash2 =
+ (struct inet_bind2_hashbucket *)(tcp_hashinfo.bhash + tcp_hashinfo.bhash_size);
for (i = 0; i < tcp_hashinfo.bhash_size; i++) {
spin_lock_init(&tcp_hashinfo.bhash[i].lock);
INIT_HLIST_HEAD(&tcp_hashinfo.bhash[i].chain);
+ INIT_HLIST_HEAD(&tcp_hashinfo.bhash2[i].chain);
}