diff options
-rw-r--r-- | include/linux/filter.h | 1 | ||||
-rw-r--r-- | include/net/sock_reuseport.h | 3 | ||||
-rw-r--r-- | net/core/filter.c | 87 | ||||
-rw-r--r-- | net/core/sock_reuseport.c | 36 | ||||
-rw-r--r-- | net/ipv4/inet_hashtables.c | 14 | ||||
-rw-r--r-- | net/ipv4/udp.c | 4 | ||||
-rw-r--r-- | net/ipv6/inet6_hashtables.c | 14 | ||||
-rw-r--r-- | net/ipv6/udp.c | 4 |
8 files changed, 106 insertions, 57 deletions
diff --git a/include/linux/filter.h b/include/linux/filter.h index 70e9d57677fe..5d565c50bcb2 100644 --- a/include/linux/filter.h +++ b/include/linux/filter.h @@ -753,6 +753,7 @@ int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk); int sk_attach_bpf(u32 ufd, struct sock *sk); int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk); int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk); +void sk_reuseport_prog_free(struct bpf_prog *prog); int sk_detach_filter(struct sock *sk); int sk_get_filter(struct sock *sk, struct sock_filter __user *filter, unsigned int len); diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h index 73b569556be6..8a5f70c7cdf2 100644 --- a/include/net/sock_reuseport.h +++ b/include/net/sock_reuseport.h @@ -34,8 +34,7 @@ extern struct sock *reuseport_select_sock(struct sock *sk, u32 hash, struct sk_buff *skb, int hdr_len); -extern struct bpf_prog *reuseport_attach_prog(struct sock *sk, - struct bpf_prog *prog); +extern int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog); int reuseport_get_id(struct sock_reuseport *reuse); #endif /* _SOCK_REUSEPORT_H */ diff --git a/net/core/filter.c b/net/core/filter.c index 142595b4e0d1..22906b31d43f 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -1453,30 +1453,6 @@ static int __sk_attach_prog(struct bpf_prog *prog, struct sock *sk) return 0; } -static int __reuseport_attach_prog(struct bpf_prog *prog, struct sock *sk) -{ - struct bpf_prog *old_prog; - int err; - - if (bpf_prog_size(prog->len) > sysctl_optmem_max) - return -ENOMEM; - - if (sk_unhashed(sk) && sk->sk_reuseport) { - err = reuseport_alloc(sk, false); - if (err) - return err; - } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) { - /* The socket wasn't bound with SO_REUSEPORT */ - return -EINVAL; - } - - old_prog = reuseport_attach_prog(sk, prog); - if (old_prog) - bpf_prog_destroy(old_prog); - - return 0; -} - static struct bpf_prog *__get_filter(struct sock_fprog *fprog, struct sock *sk) { @@ -1550,13 +1526,15 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk) if (IS_ERR(prog)) return PTR_ERR(prog); - err = __reuseport_attach_prog(prog, sk); - if (err < 0) { + if (bpf_prog_size(prog->len) > sysctl_optmem_max) + err = -ENOMEM; + else + err = reuseport_attach_prog(sk, prog); + + if (err) __bpf_prog_release(prog); - return err; - } - return 0; + return err; } static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk) @@ -1586,19 +1564,58 @@ int sk_attach_bpf(u32 ufd, struct sock *sk) int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk) { - struct bpf_prog *prog = __get_bpf(ufd, sk); + struct bpf_prog *prog; int err; + if (sock_flag(sk, SOCK_FILTER_LOCKED)) + return -EPERM; + + prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER); + if (IS_ERR(prog) && PTR_ERR(prog) == -EINVAL) + prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SK_REUSEPORT); if (IS_ERR(prog)) return PTR_ERR(prog); - err = __reuseport_attach_prog(prog, sk); - if (err < 0) { - bpf_prog_put(prog); - return err; + if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) { + /* Like other non BPF_PROG_TYPE_SOCKET_FILTER + * bpf prog (e.g. sockmap). It depends on the + * limitation imposed by bpf_prog_load(). + * Hence, sysctl_optmem_max is not checked. + */ + if ((sk->sk_type != SOCK_STREAM && + sk->sk_type != SOCK_DGRAM) || + (sk->sk_protocol != IPPROTO_UDP && + sk->sk_protocol != IPPROTO_TCP) || + (sk->sk_family != AF_INET && + sk->sk_family != AF_INET6)) { + err = -ENOTSUPP; + goto err_prog_put; + } + } else { + /* BPF_PROG_TYPE_SOCKET_FILTER */ + if (bpf_prog_size(prog->len) > sysctl_optmem_max) { + err = -ENOMEM; + goto err_prog_put; + } } - return 0; + err = reuseport_attach_prog(sk, prog); +err_prog_put: + if (err) + bpf_prog_put(prog); + + return err; +} + +void sk_reuseport_prog_free(struct bpf_prog *prog) +{ + if (!prog) + return; + + if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) + bpf_prog_put(prog); + else + bpf_prog_destroy(prog); } struct bpf_scratchpad { diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c index d260167f5f77..ba5cba56f574 100644 --- a/net/core/sock_reuseport.c +++ b/net/core/sock_reuseport.c @@ -9,6 +9,7 @@ #include <net/sock_reuseport.h> #include <linux/bpf.h> #include <linux/idr.h> +#include <linux/filter.h> #include <linux/rcupdate.h> #define INIT_SOCKS 128 @@ -133,8 +134,7 @@ static void reuseport_free_rcu(struct rcu_head *head) struct sock_reuseport *reuse; reuse = container_of(head, struct sock_reuseport, rcu); - if (reuse->prog) - bpf_prog_destroy(reuse->prog); + sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1)); if (reuse->reuseport_id) ida_simple_remove(&reuseport_ida, reuse->reuseport_id); kfree(reuse); @@ -219,9 +219,9 @@ void reuseport_detach_sock(struct sock *sk) } EXPORT_SYMBOL(reuseport_detach_sock); -static struct sock *run_bpf(struct sock_reuseport *reuse, u16 socks, - struct bpf_prog *prog, struct sk_buff *skb, - int hdr_len) +static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks, + struct bpf_prog *prog, struct sk_buff *skb, + int hdr_len) { struct sk_buff *nskb = NULL; u32 index; @@ -282,9 +282,15 @@ struct sock *reuseport_select_sock(struct sock *sk, /* paired with smp_wmb() in reuseport_add_sock() */ smp_rmb(); - if (prog && skb) - sk2 = run_bpf(reuse, socks, prog, skb, hdr_len); + if (!prog || !skb) + goto select_by_hash; + + if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) + sk2 = bpf_run_sk_reuseport(reuse, sk, prog, skb, hash); + else + sk2 = run_bpf_filter(reuse, socks, prog, skb, hdr_len); +select_by_hash: /* no bpf or invalid bpf result: fall back to hash usage */ if (!sk2) sk2 = reuse->socks[reciprocal_scale(hash, socks)]; @@ -296,12 +302,21 @@ out: } EXPORT_SYMBOL(reuseport_select_sock); -struct bpf_prog * -reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog) +int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog) { struct sock_reuseport *reuse; struct bpf_prog *old_prog; + if (sk_unhashed(sk) && sk->sk_reuseport) { + int err = reuseport_alloc(sk, false); + + if (err) + return err; + } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) { + /* The socket wasn't bound with SO_REUSEPORT */ + return -EINVAL; + } + spin_lock_bh(&reuseport_lock); reuse = rcu_dereference_protected(sk->sk_reuseport_cb, lockdep_is_held(&reuseport_lock)); @@ -310,6 +325,7 @@ reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog) rcu_assign_pointer(reuse->prog, prog); spin_unlock_bh(&reuseport_lock); - return old_prog; + sk_reuseport_prog_free(old_prog); + return 0; } EXPORT_SYMBOL(reuseport_attach_prog); diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index 370e24463fb7..f5c9ef2586de 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -328,7 +328,7 @@ struct sock *__inet_lookup_listener(struct net *net, saddr, sport, daddr, hnum, dif, sdif); if (result) - return result; + goto done; /* Lookup lhash2 with INADDR_ANY */ @@ -337,9 +337,10 @@ struct sock *__inet_lookup_listener(struct net *net, if (ilb2->count > ilb->count) goto port_lookup; - return inet_lhash2_lookup(net, ilb2, skb, doff, - saddr, sport, daddr, hnum, - dif, sdif); + result = inet_lhash2_lookup(net, ilb2, skb, doff, + saddr, sport, daddr, hnum, + dif, sdif); + goto done; port_lookup: sk_for_each_rcu(sk, &ilb->head) { @@ -352,12 +353,15 @@ port_lookup: result = reuseport_select_sock(sk, phash, skb, doff); if (result) - return result; + goto done; } result = sk; hiscore = score; } } +done: + if (unlikely(IS_ERR(result))) + return NULL; return result; } EXPORT_SYMBOL_GPL(__inet_lookup_listener); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 038dd7909051..f4e35b2ff8b8 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -499,6 +499,8 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, daddr, hnum, dif, sdif, exact_dif, hslot2, skb); } + if (unlikely(IS_ERR(result))) + return NULL; return result; } begin: @@ -513,6 +515,8 @@ begin: saddr, sport); result = reuseport_select_sock(sk, hash, skb, sizeof(struct udphdr)); + if (unlikely(IS_ERR(result))) + return NULL; if (result) return result; } diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c index 595ad408dba0..3d7c7460a0c5 100644 --- a/net/ipv6/inet6_hashtables.c +++ b/net/ipv6/inet6_hashtables.c @@ -191,7 +191,7 @@ struct sock *inet6_lookup_listener(struct net *net, saddr, sport, daddr, hnum, dif, sdif); if (result) - return result; + goto done; /* Lookup lhash2 with in6addr_any */ @@ -200,9 +200,10 @@ struct sock *inet6_lookup_listener(struct net *net, if (ilb2->count > ilb->count) goto port_lookup; - return inet6_lhash2_lookup(net, ilb2, skb, doff, - saddr, sport, daddr, hnum, - dif, sdif); + result = inet6_lhash2_lookup(net, ilb2, skb, doff, + saddr, sport, daddr, hnum, + dif, sdif); + goto done; port_lookup: sk_for_each(sk, &ilb->head) { @@ -214,12 +215,15 @@ port_lookup: result = reuseport_select_sock(sk, phash, skb, doff); if (result) - return result; + goto done; } result = sk; hiscore = score; } } +done: + if (unlikely(IS_ERR(result))) + return NULL; return result; } EXPORT_SYMBOL_GPL(inet6_lookup_listener); diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index f6b96956a8ed..83f4c77c79d8 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -235,6 +235,8 @@ struct sock *__udp6_lib_lookup(struct net *net, exact_dif, hslot2, skb); } + if (unlikely(IS_ERR(result))) + return NULL; return result; } begin: @@ -249,6 +251,8 @@ begin: saddr, sport); result = reuseport_select_sock(sk, hash, skb, sizeof(struct udphdr)); + if (unlikely(IS_ERR(result))) + return NULL; if (result) return result; } |