diff options
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 90 |
1 files changed, 83 insertions, 7 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 43252a801c3f..277f7c209fed 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -39,6 +39,7 @@ #include <linux/netdevice.h> #include <linux/sched/signal.h> #include <linux/inetdevice.h> +#include <linux/inet_diag.h> #include <net/tls.h> @@ -251,14 +252,26 @@ static void tls_write_space(struct sock *sk) ctx->sk_write_space(sk); } -void tls_ctx_free(struct tls_context *ctx) +/** + * tls_ctx_free() - free TLS ULP context + * @sk: socket to with @ctx is attached + * @ctx: TLS context structure + * + * Free TLS context. If @sk is %NULL caller guarantees that the socket + * to which @ctx was attached has no outstanding references. + */ +void tls_ctx_free(struct sock *sk, struct tls_context *ctx) { if (!ctx) return; memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); - kfree(ctx); + + if (sk) + kfree_rcu(ctx, rcu); + else + kfree(ctx); } static void tls_sk_proto_cleanup(struct sock *sk, @@ -306,7 +319,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) write_lock_bh(&sk->sk_callback_lock); if (free_ctx) - icsk->icsk_ulp_data = NULL; + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); sk->sk_prot = ctx->sk_proto; if (sk->sk_write_space == tls_write_space) sk->sk_write_space = ctx->sk_write_space; @@ -321,7 +334,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ctx->sk_proto_close(sk, timeout); if (free_ctx) - tls_ctx_free(ctx); + tls_ctx_free(sk, ctx); } static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, @@ -610,7 +623,7 @@ static struct tls_context *create_ctx(struct sock *sk) if (!ctx) return NULL; - icsk->icsk_ulp_data = ctx; + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); ctx->setsockopt = sk->sk_prot->setsockopt; ctx->getsockopt = sk->sk_prot->getsockopt; ctx->sk_proto_close = sk->sk_prot->close; @@ -651,8 +664,8 @@ static void tls_hw_sk_destruct(struct sock *sk) ctx->sk_destruct(sk); /* Free ctx */ - tls_ctx_free(ctx); - icsk->icsk_ulp_data = NULL; + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + tls_ctx_free(sk, ctx); } static int tls_hw_prot(struct sock *sk) @@ -823,6 +836,67 @@ static void tls_update(struct sock *sk, struct proto *p) } } +static int tls_get_info(const struct sock *sk, struct sk_buff *skb) +{ + u16 version, cipher_type; + struct tls_context *ctx; + struct nlattr *start; + int err; + + start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); + if (!start) + return -EMSGSIZE; + + rcu_read_lock(); + ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); + if (!ctx) { + err = 0; + goto nla_failure; + } + version = ctx->prot_info.version; + if (version) { + err = nla_put_u16(skb, TLS_INFO_VERSION, version); + if (err) + goto nla_failure; + } + cipher_type = ctx->prot_info.cipher_type; + if (cipher_type) { + err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); + if (err) + goto nla_failure; + } + err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); + if (err) + goto nla_failure; + + err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); + if (err) + goto nla_failure; + + rcu_read_unlock(); + nla_nest_end(skb, start); + return 0; + +nla_failure: + rcu_read_unlock(); + nla_nest_cancel(skb, start); + return err; +} + +static size_t tls_get_info_size(const struct sock *sk) +{ + size_t size = 0; + + size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ + 0; + + return size; +} + void tls_register_device(struct tls_device *device) { spin_lock_bh(&device_spinlock); @@ -844,6 +918,8 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .owner = THIS_MODULE, .init = tls_init, .update = tls_update, + .get_info = tls_get_info, + .get_info_size = tls_get_info_size, }; static int __init tls_register(void) |