diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_main.c | 58 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 64 |
2 files changed, 84 insertions, 38 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 311cec8e533d..78cb4a584080 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -55,8 +55,10 @@ enum { static struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); +static struct proto *saved_tcpv4_prot; +static DEFINE_MUTEX(tcpv4_prot_mutex); static LIST_HEAD(device_list); -static DEFINE_MUTEX(device_mutex); +static DEFINE_SPINLOCK(device_spinlock); static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static struct proto_ops tls_sw_proto_ops; @@ -538,11 +540,14 @@ static struct tls_context *create_ctx(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); struct tls_context *ctx; - ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); + ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); if (!ctx) return NULL; icsk->icsk_ulp_data = ctx; + ctx->setsockopt = sk->sk_prot->setsockopt; + ctx->getsockopt = sk->sk_prot->getsockopt; + ctx->sk_proto_close = sk->sk_prot->close; return ctx; } @@ -552,7 +557,7 @@ static int tls_hw_prot(struct sock *sk) struct tls_device *dev; int rc = 0; - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_for_each_entry(dev, &device_list, dev_list) { if (dev->feature && dev->feature(dev)) { ctx = create_ctx(sk); @@ -570,7 +575,7 @@ static int tls_hw_prot(struct sock *sk) } } out: - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); return rc; } @@ -579,12 +584,17 @@ static void tls_hw_unhash(struct sock *sk) struct tls_context *ctx = tls_get_ctx(sk); struct tls_device *dev; - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_for_each_entry(dev, &device_list, dev_list) { - if (dev->unhash) + if (dev->unhash) { + kref_get(&dev->kref); + spin_unlock_bh(&device_spinlock); dev->unhash(dev, sk); + kref_put(&dev->kref, dev->release); + spin_lock_bh(&device_spinlock); + } } - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); ctx->unhash(sk); } @@ -595,12 +605,17 @@ static int tls_hw_hash(struct sock *sk) int err; err = ctx->hash(sk); - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_for_each_entry(dev, &device_list, dev_list) { - if (dev->hash) + if (dev->hash) { + kref_get(&dev->kref); + spin_unlock_bh(&device_spinlock); err |= dev->hash(dev, sk); + kref_put(&dev->kref, dev->release); + spin_lock_bh(&device_spinlock); + } } - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); if (err) tls_hw_unhash(sk); @@ -675,9 +690,6 @@ static int tls_init(struct sock *sk) rc = -ENOMEM; goto out; } - ctx->setsockopt = sk->sk_prot->setsockopt; - ctx->getsockopt = sk->sk_prot->getsockopt; - ctx->sk_proto_close = sk->sk_prot->close; /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ if (ip_ver == TLSV6 && @@ -690,6 +702,16 @@ static int tls_init(struct sock *sk) mutex_unlock(&tcpv6_prot_mutex); } + if (ip_ver == TLSV4 && + unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) { + mutex_lock(&tcpv4_prot_mutex); + if (likely(sk->sk_prot != saved_tcpv4_prot)) { + build_protos(tls_prots[TLSV4], sk->sk_prot); + smp_store_release(&saved_tcpv4_prot, sk->sk_prot); + } + mutex_unlock(&tcpv4_prot_mutex); + } + ctx->tx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE; update_sk_prot(sk, ctx); @@ -699,17 +721,17 @@ out: void tls_register_device(struct tls_device *device) { - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_add_tail(&device->dev_list, &device_list); - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); } EXPORT_SYMBOL(tls_register_device); void tls_unregister_device(struct tls_device *device) { - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_del(&device->dev_list); - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); } EXPORT_SYMBOL(tls_unregister_device); @@ -721,8 +743,6 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { static int __init tls_register(void) { - build_protos(tls_prots[TLSV4], &tcp_prot); - tls_sw_proto_ops = inet_stream_ops; tls_sw_proto_ops.splice_read = tls_sw_splice_read; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 7b1af8b59cd2..11cdc8f7db63 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -686,16 +686,24 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, struct sk_psock *psock; struct sock *sk_redir; struct tls_rec *rec; + bool enospc, policy; int err = 0, send; - bool enospc; + u32 delta = 0; + policy = !(flags & MSG_SENDPAGE_NOPOLICY); psock = sk_psock_get(sk); - if (!psock) + if (!psock || !policy) return tls_push_record(sk, flags, record_type); more_data: enospc = sk_msg_full(msg); - if (psock->eval == __SK_NONE) + if (psock->eval == __SK_NONE) { + delta = msg->sg.size; psock->eval = sk_psock_msg_verdict(sk, psock, msg); + if (delta < msg->sg.size) + delta -= msg->sg.size; + else + delta = 0; + } if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && !enospc && !full_record) { err = -ENOSPC; @@ -743,7 +751,7 @@ more_data: msg->apply_bytes -= send; if (msg->sg.size == 0) tls_free_open_rec(sk); - *copied -= send; + *copied -= (send + delta); err = -EACCES; } @@ -935,10 +943,12 @@ fallback_to_reg_send: tls_ctx->tx.overhead_size); } - ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_pl, - try_to_copy); - if (ret < 0) - goto trim_sgl; + if (try_to_copy) { + ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, + msg_pl, try_to_copy); + if (ret < 0) + goto trim_sgl; + } /* Open records defined only if successfully copied, otherwise * we would trim the sg but not reset the open record frags. @@ -1010,8 +1020,8 @@ send_end: return copied ? copied : ret; } -int tls_sw_sendpage(struct sock *sk, struct page *page, - int offset, size_t size, int flags) +int tls_sw_do_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) { long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); @@ -1026,15 +1036,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, int ret = 0; bool eor; - if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | - MSG_SENDPAGE_NOTLAST)) - return -ENOTSUPP; - - /* No MSG_EOR from splice, only look at MSG_MORE */ eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); - - lock_sock(sk); - sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); /* Wait till there is any pending write on socket */ @@ -1138,10 +1140,34 @@ wait_for_memory: } sendpage_end: ret = sk_stream_error(sk, flags, ret); - release_sock(sk); return copied ? copied : ret; } +int tls_sw_sendpage_locked(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) + return -ENOTSUPP; + + return tls_sw_do_sendpage(sk, page, offset, size, flags); +} + +int tls_sw_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + int ret; + + if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) + return -ENOTSUPP; + + lock_sock(sk); + ret = tls_sw_do_sendpage(sk, page, offset, size, flags); + release_sock(sk); + return ret; +} + static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock, int flags, long timeo, int *err) { |