diff options
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 161 |
1 files changed, 99 insertions, 62 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 1f3d9789af30..83d67df33f0c 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -53,18 +53,14 @@ static int tls_do_decryption(struct sock *sk, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct strp_msg *rxm = strp_msg(skb); struct aead_request *aead_req; int ret; - unsigned int req_size = sizeof(struct aead_request) + - crypto_aead_reqsize(ctx->aead_recv); - aead_req = kzalloc(req_size, flags); + aead_req = aead_request_alloc(ctx->aead_recv, flags); if (!aead_req) return -ENOMEM; - aead_request_set_tfm(aead_req, ctx->aead_recv); aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); aead_request_set_crypt(aead_req, sgin, sgout, data_len + tls_ctx->rx.tag_size, @@ -74,19 +70,7 @@ static int tls_do_decryption(struct sock *sk, ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); - if (ret < 0) - goto out; - - rxm->offset += tls_ctx->rx.prepend_size; - rxm->full_len -= tls_ctx->rx.overhead_size; - tls_advance_record_sn(sk, &tls_ctx->rx); - - ctx->decrypted = true; - - ctx->saved_data_ready(sk); - -out: - kfree(aead_req); + aead_request_free(aead_req); return ret; } @@ -224,8 +208,7 @@ static int tls_push_record(struct sock *sk, int flags, struct aead_request *req; int rc; - req = kzalloc(sizeof(struct aead_request) + - crypto_aead_reqsize(ctx->aead_send), sk->sk_allocation); + req = aead_request_alloc(ctx->aead_send, sk->sk_allocation); if (!req) return -ENOMEM; @@ -267,7 +250,7 @@ static int tls_push_record(struct sock *sk, int flags, tls_advance_record_sn(sk, &tls_ctx->tx); out_req: - kfree(req); + aead_request_free(req); return rc; } @@ -328,7 +311,12 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, } } + /* Mark the end in the last sg entry if newly added */ + if (num_elem > *pages_used) + sg_mark_end(&to[num_elem - 1]); out: + if (rc) + iov_iter_revert(from, size - *size_used); *size_used = size; *pages_used = num_elem; @@ -377,6 +365,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) int record_room; bool full_record; int orig_size; + bool is_kvec = msg->msg_iter.type & ITER_KVEC; if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) return -ENOTSUPP; @@ -425,8 +414,7 @@ alloc_encrypted: try_to_copy -= required_size - ctx->sg_encrypted_size; full_record = true; } - - if (full_record || eor) { + if (!is_kvec && (full_record || eor)) { ret = zerocopy_from_iter(sk, &msg->msg_iter, try_to_copy, &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, @@ -438,15 +426,11 @@ alloc_encrypted: copied += try_to_copy; ret = tls_push_record(sk, msg->msg_flags, record_type); - if (!ret) - continue; - if (ret < 0) + if (ret) goto send_end; + continue; - copied -= try_to_copy; fallback_to_reg_send: - iov_iter_revert(&msg->msg_iter, - ctx->sg_plaintext_size - orig_size); trim_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, @@ -673,8 +657,38 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, return skb; } -static int decrypt_skb(struct sock *sk, struct sk_buff *skb, - struct scatterlist *sgout) +static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, + struct scatterlist *sgout, bool *zc) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct strp_msg *rxm = strp_msg(skb); + int err = 0; + +#ifdef CONFIG_TLS_DEVICE + err = tls_device_decrypted(sk, skb); + if (err < 0) + return err; +#endif + if (!ctx->decrypted) { + err = decrypt_skb(sk, skb, sgout); + if (err < 0) + return err; + } else { + *zc = false; + } + + rxm->offset += tls_ctx->rx.prepend_size; + rxm->full_len -= tls_ctx->rx.overhead_size; + tls_advance_record_sn(sk, &tls_ctx->rx); + ctx->decrypted = true; + ctx->saved_data_ready(sk); + + return err; +} + +int decrypt_skb(struct sock *sk, struct sk_buff *skb, + struct scatterlist *sgout) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); @@ -764,6 +778,7 @@ int tls_sw_recvmsg(struct sock *sk, bool cmsg = false; int target, err = 0; long timeo; + bool is_kvec = msg->msg_iter.type & ITER_KVEC; flags |= nonblock; @@ -807,7 +822,7 @@ int tls_sw_recvmsg(struct sock *sk, page_count = iov_iter_npages(&msg->msg_iter, MAX_SKB_FRAGS); to_copy = rxm->full_len - tls_ctx->rx.overhead_size; - if (to_copy <= len && page_count < MAX_SKB_FRAGS && + if (!is_kvec && to_copy <= len && page_count < MAX_SKB_FRAGS && likely(!(flags & MSG_PEEK))) { struct scatterlist sgin[MAX_SKB_FRAGS + 1]; int pages = 0; @@ -824,7 +839,7 @@ int tls_sw_recvmsg(struct sock *sk, if (err < 0) goto fallback_to_reg_recv; - err = decrypt_skb(sk, skb, sgin); + err = decrypt_skb_update(sk, skb, sgin, &zc); for (; pages > 0; pages--) put_page(sg_page(&sgin[pages])); if (err < 0) { @@ -833,7 +848,7 @@ int tls_sw_recvmsg(struct sock *sk, } } else { fallback_to_reg_recv: - err = decrypt_skb(sk, skb, NULL); + err = decrypt_skb_update(sk, skb, NULL, &zc); if (err < 0) { tls_err_abort(sk, EBADMSG); goto recv_end; @@ -888,6 +903,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, int err = 0; long timeo; int chunk; + bool zc; lock_sock(sk); @@ -904,7 +920,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, } if (!ctx->decrypted) { - err = decrypt_skb(sk, skb, NULL); + err = decrypt_skb_update(sk, skb, NULL, &zc); if (err < 0) { tls_err_abort(sk, EBADMSG); @@ -950,7 +966,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - char header[tls_ctx->rx.prepend_size]; + char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; struct strp_msg *rxm = strp_msg(skb); size_t cipher_overhead; size_t data_len = 0; @@ -960,6 +976,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) return 0; + /* Sanity-check size of on-stack buffer. */ + if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { + ret = -EINVAL; + goto read_failure; + } + /* Linearize header to local buffer */ ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); @@ -987,6 +1009,10 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) goto read_failure; } +#ifdef CONFIG_TLS_DEVICE + handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset, + *(u64*)tls_ctx->rx.rec_seq); +#endif return data_len + TLS_HEADER_SIZE; read_failure: @@ -999,16 +1025,13 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct strp_msg *rxm; - - rxm = strp_msg(skb); ctx->decrypted = false; ctx->recv_pkt = skb; strp_pause(strp); - strp->sk->sk_state_change(strp->sk); + ctx->saved_data_ready(strp->sk); } static void tls_data_ready(struct sock *sk) @@ -1024,23 +1047,20 @@ void tls_sw_free_resources_tx(struct sock *sk) struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - if (ctx->aead_send) - crypto_free_aead(ctx->aead_send); + crypto_free_aead(ctx->aead_send); tls_free_both_sg(sk); kfree(ctx); } -void tls_sw_free_resources_rx(struct sock *sk) +void tls_sw_release_resources_rx(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); if (ctx->aead_recv) { - if (ctx->recv_pkt) { - kfree_skb(ctx->recv_pkt); - ctx->recv_pkt = NULL; - } + kfree_skb(ctx->recv_pkt); + ctx->recv_pkt = NULL; crypto_free_aead(ctx->aead_recv); strp_stop(&ctx->strp); write_lock_bh(&sk->sk_callback_lock); @@ -1050,6 +1070,14 @@ void tls_sw_free_resources_rx(struct sock *sk) strp_done(&ctx->strp); lock_sock(sk); } +} + +void tls_sw_free_resources_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + + tls_sw_release_resources_rx(sk); kfree(ctx); } @@ -1074,28 +1102,38 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) } if (tx) { - sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); - if (!sw_ctx_tx) { - rc = -ENOMEM; - goto out; + if (!ctx->priv_ctx_tx) { + sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); + if (!sw_ctx_tx) { + rc = -ENOMEM; + goto out; + } + ctx->priv_ctx_tx = sw_ctx_tx; + } else { + sw_ctx_tx = + (struct tls_sw_context_tx *)ctx->priv_ctx_tx; } - crypto_init_wait(&sw_ctx_tx->async_wait); - ctx->priv_ctx_tx = sw_ctx_tx; } else { - sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); - if (!sw_ctx_rx) { - rc = -ENOMEM; - goto out; + if (!ctx->priv_ctx_rx) { + sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); + if (!sw_ctx_rx) { + rc = -ENOMEM; + goto out; + } + ctx->priv_ctx_rx = sw_ctx_rx; + } else { + sw_ctx_rx = + (struct tls_sw_context_rx *)ctx->priv_ctx_rx; } - crypto_init_wait(&sw_ctx_rx->async_wait); - ctx->priv_ctx_rx = sw_ctx_rx; } if (tx) { + crypto_init_wait(&sw_ctx_tx->async_wait); crypto_info = &ctx->crypto_send; cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; } else { + crypto_init_wait(&sw_ctx_rx->async_wait); crypto_info = &ctx->crypto_recv; cctx = &ctx->rx; aead = &sw_ctx_rx->aead_recv; @@ -1120,7 +1158,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) } /* Sanity-check the IV size for stack allocations. */ - if (iv_size > MAX_IV_SIZE) { + if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) { rc = -EINVAL; goto free_priv; } @@ -1138,12 +1176,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); cctx->rec_seq_size = rec_seq_size; - cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); + cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!cctx->rec_seq) { rc = -ENOMEM; goto free_iv; } - memcpy(cctx->rec_seq, rec_seq, rec_seq_size); if (sw_ctx_tx) { sg_init_table(sw_ctx_tx->sg_encrypted_data, |