diff options
-rw-r--r-- | include/linux/skmsg.h | 12 | ||||
-rw-r--r-- | include/linux/socket.h | 1 | ||||
-rw-r--r-- | include/net/tls.h | 9 | ||||
-rw-r--r-- | net/core/filter.c | 24 | ||||
-rw-r--r-- | net/core/skmsg.c | 23 | ||||
-rw-r--r-- | net/ipv4/tcp_bpf.c | 15 | ||||
-rw-r--r-- | net/tls/tls_main.c | 14 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 43 | ||||
-rw-r--r-- | tools/testing/selftests/bpf/test_verifier.c | 2 |
9 files changed, 112 insertions, 31 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index eb8f6cb84c10..178a3933a71b 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -36,9 +36,7 @@ struct sk_msg_sg { struct scatterlist data[MAX_MSG_FRAGS + 1]; }; -/* UAPI in filter.c depends on struct sk_msg_sg being first element. If - * this is moved filter.c also must be updated. - */ +/* UAPI in filter.c depends on struct sk_msg_sg being first element. */ struct sk_msg { struct sk_msg_sg sg; void *data; @@ -419,6 +417,14 @@ static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock) sk_psock_drop(sk, psock); } +static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock) +{ + if (psock->parser.enabled) + psock->parser.saved_data_ready(sk); + else + sk->sk_data_ready(sk); +} + static inline void psock_set_prog(struct bpf_prog **pprog, struct bpf_prog *prog) { diff --git a/include/linux/socket.h b/include/linux/socket.h index 8b571e9b9f76..84c48a3c0227 100644 --- a/include/linux/socket.h +++ b/include/linux/socket.h @@ -286,6 +286,7 @@ struct ucred { #define MSG_NOSIGNAL 0x4000 /* Do not generate SIGPIPE */ #define MSG_MORE 0x8000 /* Sender will send more */ #define MSG_WAITFORONE 0x10000 /* recvmmsg(): block until 1+ packets avail */ +#define MSG_SENDPAGE_NOPOLICY 0x10000 /* sendpage() internal : do no apply policy */ #define MSG_SENDPAGE_NOTLAST 0x20000 /* sendpage() internal : not the last page */ #define MSG_BATCH 0x40000 /* sendmmsg(): more messages coming */ #define MSG_EOF MSG_FIN diff --git a/include/net/tls.h b/include/net/tls.h index bab5627ff5e3..23601f3b02ee 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -454,6 +454,15 @@ tls_offload_ctx_tx(const struct tls_context *tls_ctx) return (struct tls_offload_context_tx *)tls_ctx->priv_ctx_tx; } +static inline bool tls_sw_has_ctx_tx(const struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (!ctx) + return false; + return !!tls_sw_ctx_tx(ctx); +} + static inline struct tls_offload_context_rx * tls_offload_ctx_rx(const struct tls_context *tls_ctx) { diff --git a/net/core/filter.c b/net/core/filter.c index 3a3b21726fb5..447dd1bad31f 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -6313,6 +6313,9 @@ static bool sk_msg_is_valid_access(int off, int size, if (type == BPF_WRITE) return false; + if (off % size != 0) + return false; + switch (off) { case offsetof(struct sk_msg_md, data): info->reg_type = PTR_TO_PACKET; @@ -6324,16 +6327,20 @@ static bool sk_msg_is_valid_access(int off, int size, if (size != sizeof(__u64)) return false; break; - default: + case bpf_ctx_range(struct sk_msg_md, family): + case bpf_ctx_range(struct sk_msg_md, remote_ip4): + case bpf_ctx_range(struct sk_msg_md, local_ip4): + case bpf_ctx_range_till(struct sk_msg_md, remote_ip6[0], remote_ip6[3]): + case bpf_ctx_range_till(struct sk_msg_md, local_ip6[0], local_ip6[3]): + case bpf_ctx_range(struct sk_msg_md, remote_port): + case bpf_ctx_range(struct sk_msg_md, local_port): + case bpf_ctx_range(struct sk_msg_md, size): if (size != sizeof(__u32)) return false; - } - - if (off < 0 || off >= sizeof(struct sk_msg_md)) - return false; - if (off % size != 0) + break; + default: return false; - + } return true; } @@ -7418,6 +7425,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, int off; #endif + /* convert ctx uses the fact sg element is first in struct */ + BUILD_BUG_ON(offsetof(struct sk_msg, sg) != 0); + switch (si->off) { case offsetof(struct sk_msg_md, data): *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data), diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 56a99d0c9aa0..86c9726fced8 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -403,7 +403,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) msg->skb = skb; sk_psock_queue_msg(psock, msg); - sk->sk_data_ready(sk); + sk_psock_data_ready(sk, psock); return copied; } @@ -572,6 +572,7 @@ void sk_psock_drop(struct sock *sk, struct sk_psock *psock) { rcu_assign_sk_user_data(sk, NULL); sk_psock_cork_free(psock); + sk_psock_zap_ingress(psock); sk_psock_restore_proto(sk, psock); write_lock_bh(&sk->sk_callback_lock); @@ -669,6 +670,22 @@ static void sk_psock_verdict_apply(struct sk_psock *psock, bool ingress; switch (verdict) { + case __SK_PASS: + sk_other = psock->sk; + if (sock_flag(sk_other, SOCK_DEAD) || + !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { + goto out_free; + } + if (atomic_read(&sk_other->sk_rmem_alloc) <= + sk_other->sk_rcvbuf) { + struct tcp_skb_cb *tcp = TCP_SKB_CB(skb); + + tcp->bpf.flags |= BPF_F_INGRESS; + skb_queue_tail(&psock->ingress_skb, skb); + schedule_work(&psock->work); + break; + } + goto out_free; case __SK_REDIRECT: sk_other = tcp_skb_bpf_redirect_fetch(skb); if (unlikely(!sk_other)) @@ -735,7 +752,7 @@ static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb) } /* Called with socket lock held. */ -static void sk_psock_data_ready(struct sock *sk) +static void sk_psock_strp_data_ready(struct sock *sk) { struct sk_psock *psock; @@ -783,7 +800,7 @@ void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock) return; parser->saved_data_ready = sk->sk_data_ready; - sk->sk_data_ready = sk_psock_data_ready; + sk->sk_data_ready = sk_psock_strp_data_ready; sk->sk_write_space = sk_psock_write_space; parser->enabled = true; } diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index a47c1cdf90fc..1bb7321a256d 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -8,6 +8,7 @@ #include <linux/wait.h> #include <net/inet_common.h> +#include <net/tls.h> static bool tcp_bpf_stream_read(const struct sock *sk) { @@ -198,7 +199,7 @@ static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, msg->sg.start = i; msg->sg.size -= apply_bytes; sk_psock_queue_msg(psock, tmp); - sk->sk_data_ready(sk); + sk_psock_data_ready(sk, psock); } else { sk_msg_free(sk, tmp); kfree(tmp); @@ -218,6 +219,8 @@ static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes, u32 off; while (1) { + bool has_tx_ulp; + sge = sk_msg_elem(msg, msg->sg.start); size = (apply && apply_bytes < sge->length) ? apply_bytes : sge->length; @@ -226,7 +229,15 @@ static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes, tcp_rate_check_app_limited(sk); retry: - ret = do_tcp_sendpages(sk, page, off, size, flags); + has_tx_ulp = tls_sw_has_ctx_tx(sk); + if (has_tx_ulp) { + flags |= MSG_SENDPAGE_NOPOLICY; + ret = kernel_sendpage_locked(sk, + page, off, size, flags); + } else { + ret = do_tcp_sendpages(sk, page, off, size, flags); + } + if (ret <= 0) return ret; if (apply) diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 311cec8e533d..acff12999c06 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -55,6 +55,8 @@ enum { static struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); +static struct proto *saved_tcpv4_prot; +static DEFINE_MUTEX(tcpv4_prot_mutex); static LIST_HEAD(device_list); static DEFINE_MUTEX(device_mutex); static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; @@ -690,6 +692,16 @@ static int tls_init(struct sock *sk) mutex_unlock(&tcpv6_prot_mutex); } + if (ip_ver == TLSV4 && + unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) { + mutex_lock(&tcpv4_prot_mutex); + if (likely(sk->sk_prot != saved_tcpv4_prot)) { + build_protos(tls_prots[TLSV4], sk->sk_prot); + smp_store_release(&saved_tcpv4_prot, sk->sk_prot); + } + mutex_unlock(&tcpv4_prot_mutex); + } + ctx->tx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE; update_sk_prot(sk, ctx); @@ -721,8 +733,6 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { static int __init tls_register(void) { - build_protos(tls_prots[TLSV4], &tcp_prot); - tls_sw_proto_ops = inet_stream_ops; tls_sw_proto_ops.splice_read = tls_sw_splice_read; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index d4ecc66464e6..5aee9ae5ca53 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -686,12 +686,13 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, struct sk_psock *psock; struct sock *sk_redir; struct tls_rec *rec; + bool enospc, policy; int err = 0, send; u32 delta = 0; - bool enospc; + policy = !(flags & MSG_SENDPAGE_NOPOLICY); psock = sk_psock_get(sk); - if (!psock) + if (!psock || !policy) return tls_push_record(sk, flags, record_type); more_data: enospc = sk_msg_full(msg); @@ -1017,8 +1018,8 @@ send_end: return copied ? copied : ret; } -int tls_sw_sendpage(struct sock *sk, struct page *page, - int offset, size_t size, int flags) +int tls_sw_do_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) { long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); @@ -1033,15 +1034,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, int ret = 0; bool eor; - if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | - MSG_SENDPAGE_NOTLAST)) - return -ENOTSUPP; - - /* No MSG_EOR from splice, only look at MSG_MORE */ eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); - - lock_sock(sk); - sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); /* Wait till there is any pending write on socket */ @@ -1145,10 +1138,34 @@ wait_for_memory: } sendpage_end: ret = sk_stream_error(sk, flags, ret); - release_sock(sk); return copied ? copied : ret; } +int tls_sw_sendpage_locked(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) + return -ENOTSUPP; + + return tls_sw_do_sendpage(sk, page, offset, size, flags); +} + +int tls_sw_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + int ret; + + if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) + return -ENOTSUPP; + + lock_sock(sk); + ret = tls_sw_do_sendpage(sk, page, offset, size, flags); + release_sock(sk); + return ret; +} + static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock, int flags, long timeo, int *err) { diff --git a/tools/testing/selftests/bpf/test_verifier.c b/tools/testing/selftests/bpf/test_verifier.c index b246931c46ef..dbd31750b214 100644 --- a/tools/testing/selftests/bpf/test_verifier.c +++ b/tools/testing/selftests/bpf/test_verifier.c @@ -1879,7 +1879,7 @@ static struct bpf_test tests[] = { offsetof(struct sk_msg_md, size) + 4), BPF_EXIT_INSN(), }, - .errstr = "R0 !read_ok", + .errstr = "invalid bpf_context access", .result = REJECT, .prog_type = BPF_PROG_TYPE_SK_MSG, }, |