summaryrefslogtreecommitdiffstats
path: root/net/ipv4
diff options
context:
space:
mode:
authorJakub Kicinski <kuba@kernel.org>2022-08-24 19:30:19 -0700
committerJakub Kicinski <kuba@kernel.org>2022-08-24 19:30:19 -0700
commitc21e1bf4d8104fb77bbd99f3d1276c0d7c9b28d9 (patch)
tree73ff51aa3bc284c611e346b14f4ebe8b90f251b8 /net/ipv4
parent0bf73255d3a3cf3b0416e95f2c9f7c53095c2e1a (diff)
parent1be9ac87a75a4fc0e2cc254e412d2d67a58a7191 (diff)
downloadlinux-c21e1bf4d8104fb77bbd99f3d1276c0d7c9b28d9.tar.bz2
Merge branch 'add-a-second-bind-table-hashed-by-port-and-address'
Joanne Koong says: ==================== Add a second bind table hashed by port and address Currently, there is one bind hashtable (bhash) that hashes by port only. This patchset adds a second bind table (bhash2) that hashes by port and address. The motivation for adding bhash2 is to expedite bind requests in situations where the port has many sockets in its bhash table entry (eg a large number of sockets bound to different addresses on the same port), which makes checking bind conflicts costly especially given that we acquire the table entry spinlock while doing so, which can cause softirq cpu lockups and can prevent new tcp connections. We ran into this problem at Meta where the traffic team binds a large number of IPs to port 443 and the bind() call took a significant amount of time which led to cpu softirq lockups, which caused packet drops and other failures on the machine. When experimentally testing this on a local server for ~24k sockets bound to the port, the results seen were: ipv4: before - 0.002317 seconds with bhash2 - 0.000020 seconds ipv6: before - 0.002431 seconds with bhash2 - 0.000021 seconds The additions to the initial bhash2 submission [0] are: * Updating bhash2 in the cases where a socket's rcv saddr changes after it has * been bound * Adding locks for bhash2 hashbuckets [0] https://lore.kernel.org/netdev/20220520001834.2247810-1-kuba@kernel.org/ v3: https://lore.kernel.org/netdev/20220722195406.1304948-2-joannelkoong@gmail.com/ v2: https://lore.kernel.org/netdev/20220712235310.1935121-1-joannelkoong@gmail.com/ v1: https://lore.kernel.org/netdev/20220623234242.2083895-2-joannelkoong@gmail.com/ ==================== Link: https://lore.kernel.org/r/20220822181023.3979645-1-joannelkoong@gmail.com Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Diffstat (limited to 'net/ipv4')
-rw-r--r--net/ipv4/af_inet.c26
-rw-r--r--net/ipv4/inet_connection_sock.c275
-rw-r--r--net/ipv4/inet_hashtables.c268
-rw-r--r--net/ipv4/tcp.c11
-rw-r--r--net/ipv4/tcp_ipv4.c23
5 files changed, 518 insertions, 85 deletions
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 3ca0cc467886..2e6a3cb06815 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -1219,6 +1219,7 @@ EXPORT_SYMBOL(inet_unregister_protosw);
static int inet_sk_reselect_saddr(struct sock *sk)
{
+ struct inet_bind_hashbucket *prev_addr_hashbucket;
struct inet_sock *inet = inet_sk(sk);
__be32 old_saddr = inet->inet_saddr;
__be32 daddr = inet->inet_daddr;
@@ -1226,6 +1227,7 @@ static int inet_sk_reselect_saddr(struct sock *sk)
struct rtable *rt;
__be32 new_saddr;
struct ip_options_rcu *inet_opt;
+ int err;
inet_opt = rcu_dereference_protected(inet->inet_opt,
lockdep_sock_is_held(sk));
@@ -1240,20 +1242,34 @@ static int inet_sk_reselect_saddr(struct sock *sk)
if (IS_ERR(rt))
return PTR_ERR(rt);
- sk_setup_caps(sk, &rt->dst);
-
new_saddr = fl4->saddr;
- if (new_saddr == old_saddr)
+ if (new_saddr == old_saddr) {
+ sk_setup_caps(sk, &rt->dst);
return 0;
+ }
+
+ prev_addr_hashbucket =
+ inet_bhashfn_portaddr(sk->sk_prot->h.hashinfo, sk,
+ sock_net(sk), inet->inet_num);
+
+ inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
+
+ err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+ if (err) {
+ inet->inet_saddr = old_saddr;
+ inet->inet_rcv_saddr = old_saddr;
+ ip_rt_put(rt);
+ return err;
+ }
+
+ sk_setup_caps(sk, &rt->dst);
if (READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_dynaddr) > 1) {
pr_info("%s(): shifting inet->saddr from %pI4 to %pI4\n",
__func__, &old_saddr, &new_saddr);
}
- inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
-
/*
* XXX The only one ugly spot where we need to
* XXX really change the sockets identity after
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index eb31c7158b39..f0038043b661 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -130,14 +130,75 @@ void inet_get_local_port_range(struct net *net, int *low, int *high)
}
EXPORT_SYMBOL(inet_get_local_port_range);
+static bool inet_use_bhash2_on_bind(const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6) {
+ int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
+
+ return addr_type != IPV6_ADDR_ANY &&
+ addr_type != IPV6_ADDR_MAPPED;
+ }
+#endif
+ return sk->sk_rcv_saddr != htonl(INADDR_ANY);
+}
+
+static bool inet_bind_conflict(const struct sock *sk, struct sock *sk2,
+ kuid_t sk_uid, bool relax,
+ bool reuseport_cb_ok, bool reuseport_ok)
+{
+ int bound_dev_if2;
+
+ if (sk == sk2)
+ return false;
+
+ bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
+
+ if (!sk->sk_bound_dev_if || !bound_dev_if2 ||
+ sk->sk_bound_dev_if == bound_dev_if2) {
+ if (sk->sk_reuse && sk2->sk_reuse &&
+ sk2->sk_state != TCP_LISTEN) {
+ if (!relax || (!reuseport_ok && sk->sk_reuseport &&
+ sk2->sk_reuseport && reuseport_cb_ok &&
+ (sk2->sk_state == TCP_TIME_WAIT ||
+ uid_eq(sk_uid, sock_i_uid(sk2)))))
+ return true;
+ } else if (!reuseport_ok || !sk->sk_reuseport ||
+ !sk2->sk_reuseport || !reuseport_cb_ok ||
+ (sk2->sk_state != TCP_TIME_WAIT &&
+ !uid_eq(sk_uid, sock_i_uid(sk2)))) {
+ return true;
+ }
+ }
+ return false;
+}
+
+static bool inet_bhash2_conflict(const struct sock *sk,
+ const struct inet_bind2_bucket *tb2,
+ kuid_t sk_uid,
+ bool relax, bool reuseport_cb_ok,
+ bool reuseport_ok)
+{
+ struct sock *sk2;
+
+ sk_for_each_bound_bhash2(sk2, &tb2->owners) {
+ if (sk->sk_family == AF_INET && ipv6_only_sock(sk2))
+ continue;
+
+ if (inet_bind_conflict(sk, sk2, sk_uid, relax,
+ reuseport_cb_ok, reuseport_ok))
+ return true;
+ }
+ return false;
+}
+
+/* This should be called only when the tb and tb2 hashbuckets' locks are held */
static int inet_csk_bind_conflict(const struct sock *sk,
const struct inet_bind_bucket *tb,
+ const struct inet_bind2_bucket *tb2, /* may be null */
bool relax, bool reuseport_ok)
{
- struct sock *sk2;
bool reuseport_cb_ok;
- bool reuse = sk->sk_reuse;
- bool reuseport = !!sk->sk_reuseport;
struct sock_reuseport *reuseport_cb;
kuid_t uid = sock_i_uid((struct sock *)sk);
@@ -150,55 +211,87 @@ static int inet_csk_bind_conflict(const struct sock *sk,
/*
* Unlike other sk lookup places we do not check
* for sk_net here, since _all_ the socks listed
- * in tb->owners list belong to the same net - the
- * one this bucket belongs to.
+ * in tb->owners and tb2->owners list belong
+ * to the same net - the one this bucket belongs to.
*/
- sk_for_each_bound(sk2, &tb->owners) {
- int bound_dev_if2;
+ if (!inet_use_bhash2_on_bind(sk)) {
+ struct sock *sk2;
- if (sk == sk2)
- continue;
- bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
- if ((!sk->sk_bound_dev_if ||
- !bound_dev_if2 ||
- sk->sk_bound_dev_if == bound_dev_if2)) {
- if (reuse && sk2->sk_reuse &&
- sk2->sk_state != TCP_LISTEN) {
- if ((!relax ||
- (!reuseport_ok &&
- reuseport && sk2->sk_reuseport &&
- reuseport_cb_ok &&
- (sk2->sk_state == TCP_TIME_WAIT ||
- uid_eq(uid, sock_i_uid(sk2))))) &&
- inet_rcv_saddr_equal(sk, sk2, true))
- break;
- } else if (!reuseport_ok ||
- !reuseport || !sk2->sk_reuseport ||
- !reuseport_cb_ok ||
- (sk2->sk_state != TCP_TIME_WAIT &&
- !uid_eq(uid, sock_i_uid(sk2)))) {
- if (inet_rcv_saddr_equal(sk, sk2, true))
- break;
- }
- }
+ sk_for_each_bound(sk2, &tb->owners)
+ if (inet_bind_conflict(sk, sk2, uid, relax,
+ reuseport_cb_ok, reuseport_ok) &&
+ inet_rcv_saddr_equal(sk, sk2, true))
+ return true;
+
+ return false;
+ }
+
+ /* Conflicts with an existing IPV6_ADDR_ANY (if ipv6) or INADDR_ANY (if
+ * ipv4) should have been checked already. We need to do these two
+ * checks separately because their spinlocks have to be acquired/released
+ * independently of each other, to prevent possible deadlocks
+ */
+ return tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+ reuseport_ok);
+}
+
+/* Determine if there is a bind conflict with an existing IPV6_ADDR_ANY (if ipv6) or
+ * INADDR_ANY (if ipv4) socket.
+ *
+ * Caller must hold bhash hashbucket lock with local bh disabled, to protect
+ * against concurrent binds on the port for addr any
+ */
+static bool inet_bhash2_addr_any_conflict(const struct sock *sk, int port, int l3mdev,
+ bool relax, bool reuseport_ok)
+{
+ kuid_t uid = sock_i_uid((struct sock *)sk);
+ const struct net *net = sock_net(sk);
+ struct sock_reuseport *reuseport_cb;
+ struct inet_bind_hashbucket *head2;
+ struct inet_bind2_bucket *tb2;
+ bool reuseport_cb_ok;
+
+ rcu_read_lock();
+ reuseport_cb = rcu_dereference(sk->sk_reuseport_cb);
+ /* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */
+ reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks);
+ rcu_read_unlock();
+
+ head2 = inet_bhash2_addr_any_hashbucket(sk, net, port);
+
+ spin_lock(&head2->lock);
+
+ inet_bind_bucket_for_each(tb2, &head2->chain)
+ if (inet_bind2_bucket_match_addr_any(tb2, net, port, l3mdev, sk))
+ break;
+
+ if (tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+ reuseport_ok)) {
+ spin_unlock(&head2->lock);
+ return true;
}
- return sk2 != NULL;
+
+ spin_unlock(&head2->lock);
+ return false;
}
/*
* Find an open port number for the socket. Returns with the
- * inet_bind_hashbucket lock held.
+ * inet_bind_hashbucket locks held if successful.
*/
static struct inet_bind_hashbucket *
-inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *port_ret)
+inet_csk_find_open_port(const struct sock *sk, struct inet_bind_bucket **tb_ret,
+ struct inet_bind2_bucket **tb2_ret,
+ struct inet_bind_hashbucket **head2_ret, int *port_ret)
{
struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
int port = 0;
- struct inet_bind_hashbucket *head;
+ struct inet_bind_hashbucket *head, *head2;
struct net *net = sock_net(sk);
bool relax = false;
int i, low, high, attempt_half;
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
u32 remaining, offset;
int l3mdev;
@@ -239,11 +332,20 @@ other_parity_scan:
head = &hinfo->bhash[inet_bhashfn(net, port,
hinfo->bhash_size)];
spin_lock_bh(&head->lock);
+ if (inet_use_bhash2_on_bind(sk)) {
+ if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, relax, false))
+ goto next_port;
+ }
+
+ head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+ spin_lock(&head2->lock);
+ tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
inet_bind_bucket_for_each(tb, &head->chain)
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port) {
- if (!inet_csk_bind_conflict(sk, tb, relax, false))
+ if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
+ if (!inet_csk_bind_conflict(sk, tb, tb2,
+ relax, false))
goto success;
+ spin_unlock(&head2->lock);
goto next_port;
}
tb = NULL;
@@ -272,6 +374,8 @@ next_port:
success:
*port_ret = port;
*tb_ret = tb;
+ *tb2_ret = tb2;
+ *head2_ret = head2;
return head;
}
@@ -368,53 +472,95 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN;
struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
int ret = 1, port = snum;
- struct inet_bind_hashbucket *head;
struct net *net = sock_net(sk);
+ bool found_port = false, check_bind_conflict = true;
+ bool bhash_created = false, bhash2_created = false;
+ struct inet_bind_hashbucket *head, *head2;
+ struct inet_bind2_bucket *tb2 = NULL;
struct inet_bind_bucket *tb = NULL;
+ bool head2_lock_acquired = false;
int l3mdev;
l3mdev = inet_sk_bound_l3mdev(sk);
if (!port) {
- head = inet_csk_find_open_port(sk, &tb, &port);
+ head = inet_csk_find_open_port(sk, &tb, &tb2, &head2, &port);
if (!head)
return ret;
+
+ head2_lock_acquired = true;
+
+ if (tb && tb2)
+ goto success;
+ found_port = true;
+ } else {
+ head = &hinfo->bhash[inet_bhashfn(net, port,
+ hinfo->bhash_size)];
+ spin_lock_bh(&head->lock);
+ inet_bind_bucket_for_each(tb, &head->chain)
+ if (inet_bind_bucket_match(tb, net, port, l3mdev))
+ break;
+ }
+
+ if (!tb) {
+ tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, net,
+ head, port, l3mdev);
if (!tb)
- goto tb_not_found;
- goto success;
+ goto fail_unlock;
+ bhash_created = true;
}
- head = &hinfo->bhash[inet_bhashfn(net, port,
- hinfo->bhash_size)];
- spin_lock_bh(&head->lock);
- inet_bind_bucket_for_each(tb, &head->chain)
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port)
- goto tb_found;
-tb_not_found:
- tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
- net, head, port, l3mdev);
- if (!tb)
- goto fail_unlock;
-tb_found:
- if (!hlist_empty(&tb->owners)) {
- if (sk->sk_reuse == SK_FORCE_REUSE)
- goto success;
- if ((tb->fastreuse > 0 && reuse) ||
- sk_reuseport_match(tb, sk))
- goto success;
- if (inet_csk_bind_conflict(sk, tb, true, true))
+ if (!found_port) {
+ if (!hlist_empty(&tb->owners)) {
+ if (sk->sk_reuse == SK_FORCE_REUSE ||
+ (tb->fastreuse > 0 && reuse) ||
+ sk_reuseport_match(tb, sk))
+ check_bind_conflict = false;
+ }
+
+ if (check_bind_conflict && inet_use_bhash2_on_bind(sk)) {
+ if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, true, true))
+ goto fail_unlock;
+ }
+
+ head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+ spin_lock(&head2->lock);
+ head2_lock_acquired = true;
+ tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+ }
+
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
+ net, head2, port, l3mdev, sk);
+ if (!tb2)
goto fail_unlock;
+ bhash2_created = true;
}
+
+ if (!found_port && check_bind_conflict) {
+ if (inet_csk_bind_conflict(sk, tb, tb2, true, true))
+ goto fail_unlock;
+ }
+
success:
inet_csk_update_fastreuse(tb, sk);
if (!inet_csk(sk)->icsk_bind_hash)
- inet_bind_hash(sk, tb, port);
+ inet_bind_hash(sk, tb, tb2, port);
WARN_ON(inet_csk(sk)->icsk_bind_hash != tb);
+ WARN_ON(inet_csk(sk)->icsk_bind2_hash != tb2);
ret = 0;
fail_unlock:
+ if (ret) {
+ if (bhash_created)
+ inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+ if (bhash2_created)
+ inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+ tb2);
+ }
+ if (head2_lock_acquired)
+ spin_unlock(&head2->lock);
spin_unlock_bh(&head->lock);
return ret;
}
@@ -962,6 +1108,7 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
inet_sk_set_state(newsk, TCP_SYN_RECV);
newicsk->icsk_bind_hash = NULL;
+ newicsk->icsk_bind2_hash = NULL;
inet_sk(newsk)->inet_dport = inet_rsk(req)->ir_rmt_port;
inet_sk(newsk)->inet_num = inet_rsk(req)->ir_num;
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index b9d995b5ce24..60d77e234a68 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -92,12 +92,75 @@ void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket
}
}
+bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net,
+ unsigned short port, int l3mdev)
+{
+ return net_eq(ib_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev;
+}
+
+static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb,
+ struct net *net,
+ struct inet_bind_hashbucket *head,
+ unsigned short port, int l3mdev,
+ const struct sock *sk)
+{
+ write_pnet(&tb->ib_net, net);
+ tb->l3mdev = l3mdev;
+ tb->port = port;
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+ else
+#endif
+ tb->rcv_saddr = sk->sk_rcv_saddr;
+ INIT_HLIST_HEAD(&tb->owners);
+ hlist_add_head(&tb->node, &head->chain);
+}
+
+struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
+ struct net *net,
+ struct inet_bind_hashbucket *head,
+ unsigned short port,
+ int l3mdev,
+ const struct sock *sk)
+{
+ struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
+
+ if (tb)
+ inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk);
+
+ return tb;
+}
+
+/* Caller must hold hashbucket lock for this tb with local BH disabled */
+void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
+{
+ if (hlist_empty(&tb->owners)) {
+ __hlist_del(&tb->node);
+ kmem_cache_free(cachep, tb);
+ }
+}
+
+static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2,
+ const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ return ipv6_addr_equal(&tb2->v6_rcv_saddr,
+ &sk->sk_v6_rcv_saddr);
+#endif
+ return tb2->rcv_saddr == sk->sk_rcv_saddr;
+}
+
void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
- const unsigned short snum)
+ struct inet_bind2_bucket *tb2, unsigned short port)
{
- inet_sk(sk)->inet_num = snum;
+ inet_sk(sk)->inet_num = port;
sk_add_bind_node(sk, &tb->owners);
inet_csk(sk)->icsk_bind_hash = tb;
+ sk_add_bind2_node(sk, &tb2->owners);
+ inet_csk(sk)->icsk_bind2_hash = tb2;
}
/*
@@ -109,6 +172,9 @@ static void __inet_put_port(struct sock *sk)
const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
hashinfo->bhash_size);
struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
+ struct inet_bind_hashbucket *head2 =
+ inet_bhashfn_portaddr(hashinfo, sk, sock_net(sk),
+ inet_sk(sk)->inet_num);
struct inet_bind_bucket *tb;
spin_lock(&head->lock);
@@ -117,6 +183,17 @@ static void __inet_put_port(struct sock *sk)
inet_csk(sk)->icsk_bind_hash = NULL;
inet_sk(sk)->inet_num = 0;
inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
+
+ spin_lock(&head2->lock);
+ if (inet_csk(sk)->icsk_bind2_hash) {
+ struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash;
+
+ __sk_del_bind2_node(sk);
+ inet_csk(sk)->icsk_bind2_hash = NULL;
+ inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
+ }
+ spin_unlock(&head2->lock);
+
spin_unlock(&head->lock);
}
@@ -135,12 +212,21 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
const int bhash = inet_bhashfn(sock_net(sk), port,
table->bhash_size);
struct inet_bind_hashbucket *head = &table->bhash[bhash];
+ struct inet_bind_hashbucket *head2 =
+ inet_bhashfn_portaddr(table, child, sock_net(sk), port);
+ bool created_inet_bind_bucket = false;
+ bool update_fastreuse = false;
+ struct net *net = sock_net(sk);
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
int l3mdev;
spin_lock(&head->lock);
+ spin_lock(&head2->lock);
tb = inet_csk(sk)->icsk_bind_hash;
- if (unlikely(!tb)) {
+ tb2 = inet_csk(sk)->icsk_bind2_hash;
+ if (unlikely(!tb || !tb2)) {
+ spin_unlock(&head2->lock);
spin_unlock(&head->lock);
return -ENOENT;
}
@@ -153,25 +239,49 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
* as that of the child socket. We have to look up or
* create a new bind bucket for the child here. */
inet_bind_bucket_for_each(tb, &head->chain) {
- if (net_eq(ib_net(tb), sock_net(sk)) &&
- tb->l3mdev == l3mdev && tb->port == port)
+ if (inet_bind_bucket_match(tb, net, port, l3mdev))
break;
}
if (!tb) {
tb = inet_bind_bucket_create(table->bind_bucket_cachep,
- sock_net(sk), head, port,
- l3mdev);
+ net, head, port, l3mdev);
if (!tb) {
+ spin_unlock(&head2->lock);
spin_unlock(&head->lock);
return -ENOMEM;
}
+ created_inet_bind_bucket = true;
+ }
+ update_fastreuse = true;
+
+ goto bhash2_find;
+ } else if (!inet_bind2_bucket_addr_match(tb2, child)) {
+ l3mdev = inet_sk_bound_l3mdev(sk);
+
+bhash2_find:
+ tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child);
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
+ net, head2, port,
+ l3mdev, child);
+ if (!tb2)
+ goto error;
}
- inet_csk_update_fastreuse(tb, child);
}
- inet_bind_hash(child, tb, port);
+ if (update_fastreuse)
+ inet_csk_update_fastreuse(tb, child);
+ inet_bind_hash(child, tb, tb2, port);
+ spin_unlock(&head2->lock);
spin_unlock(&head->lock);
return 0;
+
+error:
+ if (created_inet_bind_bucket)
+ inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
+ spin_unlock(&head2->lock);
+ spin_unlock(&head->lock);
+ return -ENOMEM;
}
EXPORT_SYMBOL_GPL(__inet_inherit_port);
@@ -675,6 +785,112 @@ void inet_unhash(struct sock *sk)
}
EXPORT_SYMBOL_GPL(inet_unhash);
+static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb,
+ const struct net *net, unsigned short port,
+ int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family == AF_INET6)
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev &&
+ ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr);
+ else
+#endif
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
+}
+
+bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net,
+ unsigned short port, int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+ struct in6_addr addr_any = {};
+
+ if (sk->sk_family == AF_INET6)
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev &&
+ ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any);
+ else
+#endif
+ return net_eq(ib2_net(tb), net) && tb->port == port &&
+ tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
+}
+
+/* The socket's bhash2 hashbucket spinlock must be held when this is called */
+struct inet_bind2_bucket *
+inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net,
+ unsigned short port, int l3mdev, const struct sock *sk)
+{
+ struct inet_bind2_bucket *bhash2 = NULL;
+
+ inet_bind_bucket_for_each(bhash2, &head->chain)
+ if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
+ break;
+
+ return bhash2;
+}
+
+struct inet_bind_hashbucket *
+inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port)
+{
+ struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+ u32 hash;
+#if IS_ENABLED(CONFIG_IPV6)
+ struct in6_addr addr_any = {};
+
+ if (sk->sk_family == AF_INET6)
+ hash = ipv6_portaddr_hash(net, &addr_any, port);
+ else
+#endif
+ hash = ipv4_portaddr_hash(net, 0, port);
+
+ return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
+}
+
+int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk)
+{
+ struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+ struct inet_bind2_bucket *tb2, *new_tb2;
+ int l3mdev = inet_sk_bound_l3mdev(sk);
+ struct inet_bind_hashbucket *head2;
+ int port = inet_sk(sk)->inet_num;
+ struct net *net = sock_net(sk);
+
+ /* Allocate a bind2 bucket ahead of time to avoid permanently putting
+ * the bhash2 table in an inconsistent state if a new tb2 bucket
+ * allocation fails.
+ */
+ new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC);
+ if (!new_tb2)
+ return -ENOMEM;
+
+ head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+
+ if (prev_saddr) {
+ spin_lock_bh(&prev_saddr->lock);
+ __sk_del_bind2_node(sk);
+ inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+ inet_csk(sk)->icsk_bind2_hash);
+ spin_unlock_bh(&prev_saddr->lock);
+ }
+
+ spin_lock_bh(&head2->lock);
+ tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+ if (!tb2) {
+ tb2 = new_tb2;
+ inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk);
+ }
+ sk_add_bind2_node(sk, &tb2->owners);
+ inet_csk(sk)->icsk_bind2_hash = tb2;
+ spin_unlock_bh(&head2->lock);
+
+ if (tb2 != new_tb2)
+ kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2);
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr);
+
/* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm
* Note that we use 32bit integers (vs RFC 'short integers')
* because 2^16 is not a multiple of num_ephemeral and this
@@ -694,11 +910,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
struct sock *, __u16, struct inet_timewait_sock **))
{
struct inet_hashinfo *hinfo = death_row->hashinfo;
+ struct inet_bind_hashbucket *head, *head2;
struct inet_timewait_sock *tw = NULL;
- struct inet_bind_hashbucket *head;
int port = inet_sk(sk)->inet_num;
struct net *net = sock_net(sk);
+ struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
+ bool tb_created = false;
u32 remaining, offset;
int ret, i, low, high;
int l3mdev;
@@ -755,8 +973,7 @@ other_parity_scan:
* the established check is already unique enough.
*/
inet_bind_bucket_for_each(tb, &head->chain) {
- if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
- tb->port == port) {
+ if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
if (tb->fastreuse >= 0 ||
tb->fastreuseport >= 0)
goto next_port;
@@ -774,6 +991,7 @@ other_parity_scan:
spin_unlock_bh(&head->lock);
return -ENOMEM;
}
+ tb_created = true;
tb->fastreuse = -1;
tb->fastreuseport = -1;
goto ok;
@@ -789,6 +1007,20 @@ next_port:
return -EADDRNOTAVAIL;
ok:
+ /* Find the corresponding tb2 bucket since we need to
+ * add the socket to the bhash2 table as well
+ */
+ head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+ spin_lock(&head2->lock);
+
+ tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+ if (!tb2) {
+ tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
+ head2, port, l3mdev, sk);
+ if (!tb2)
+ goto error;
+ }
+
/* Here we want to add a little bit of randomness to the next source
* port that will be chosen. We use a max() with a random here so that
* on low contention the randomness is maximal and on high contention
@@ -798,7 +1030,10 @@ ok:
WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2);
/* Head lock still held and bh's disabled */
- inet_bind_hash(sk, tb, port);
+ inet_bind_hash(sk, tb, tb2, port);
+
+ spin_unlock(&head2->lock);
+
if (sk_unhashed(sk)) {
inet_sk(sk)->inet_sport = htons(port);
inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
@@ -810,6 +1045,13 @@ ok:
inet_twsk_deschedule_put(tw);
local_bh_enable();
return 0;
+
+error:
+ spin_unlock(&head2->lock);
+ if (tb_created)
+ inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+ spin_unlock_bh(&head->lock);
+ return -ENOMEM;
}
/*
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index ba62f8e8bd15..baf6adb723ad 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -4742,6 +4742,12 @@ void __init tcp_init(void)
SLAB_HWCACHE_ALIGN | SLAB_PANIC |
SLAB_ACCOUNT,
NULL);
+ tcp_hashinfo.bind2_bucket_cachep =
+ kmem_cache_create("tcp_bind2_bucket",
+ sizeof(struct inet_bind2_bucket), 0,
+ SLAB_HWCACHE_ALIGN | SLAB_PANIC |
+ SLAB_ACCOUNT,
+ NULL);
/* Size and allocate the main established and bind bucket
* hash tables.
@@ -4765,7 +4771,7 @@ void __init tcp_init(void)
panic("TCP: failed to alloc ehash_locks");
tcp_hashinfo.bhash =
alloc_large_system_hash("TCP bind",
- sizeof(struct inet_bind_hashbucket),
+ 2 * sizeof(struct inet_bind_hashbucket),
tcp_hashinfo.ehash_mask + 1,
17, /* one slot per 128 KB of memory */
0,
@@ -4774,9 +4780,12 @@ void __init tcp_init(void)
0,
64 * 1024);
tcp_hashinfo.bhash_size = 1U << tcp_hashinfo.bhash_size;
+ tcp_hashinfo.bhash2 = tcp_hashinfo.bhash + tcp_hashinfo.bhash_size;
for (i = 0; i < tcp_hashinfo.bhash_size; i++) {
spin_lock_init(&tcp_hashinfo.bhash[i].lock);
INIT_HLIST_HEAD(&tcp_hashinfo.bhash[i].chain);
+ spin_lock_init(&tcp_hashinfo.bhash2[i].lock);
+ INIT_HLIST_HEAD(&tcp_hashinfo.bhash2[i].chain);
}
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 0c83780dc9bf..cc2ad67f75be 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -199,11 +199,12 @@ static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr,
/* This will initiate an outgoing connection. */
int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
{
+ struct inet_bind_hashbucket *prev_addr_hashbucket = NULL;
struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
+ __be32 daddr, nexthop, prev_sk_rcv_saddr;
struct inet_sock *inet = inet_sk(sk);
struct tcp_sock *tp = tcp_sk(sk);
__be16 orig_sport, orig_dport;
- __be32 daddr, nexthop;
struct flowi4 *fl4;
struct rtable *rt;
int err;
@@ -246,10 +247,28 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
if (!inet_opt || !inet_opt->opt.srr)
daddr = fl4->daddr;
- if (!inet->inet_saddr)
+ if (!inet->inet_saddr) {
+ if (inet_csk(sk)->icsk_bind2_hash) {
+ prev_addr_hashbucket = inet_bhashfn_portaddr(&tcp_hashinfo,
+ sk, sock_net(sk),
+ inet->inet_num);
+ prev_sk_rcv_saddr = sk->sk_rcv_saddr;
+ }
inet->inet_saddr = fl4->saddr;
+ }
+
sk_rcv_saddr_set(sk, inet->inet_saddr);
+ if (prev_addr_hashbucket) {
+ err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+ if (err) {
+ inet->inet_saddr = 0;
+ sk_rcv_saddr_set(sk, prev_sk_rcv_saddr);
+ ip_rt_put(rt);
+ return err;
+ }
+ }
+
if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
/* Reset inherited state */
tp->rx_opt.ts_recent = 0;