diff options
Diffstat (limited to 'net/smc/af_smc.c')
-rw-r--r-- | net/smc/af_smc.c | 98 |
1 files changed, 67 insertions, 31 deletions
diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index cbbb947dbfcf..f3fdf3714f8b 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -47,6 +47,7 @@ static DEFINE_MUTEX(smc_create_lgr_pending); /* serialize link group */ static void smc_tcp_listen_work(struct work_struct *); +static void smc_connect_work(struct work_struct *); static void smc_set_keepalive(struct sock *sk, int val) { @@ -124,6 +125,12 @@ static int smc_release(struct socket *sock) goto out; smc = smc_sk(sk); + + /* cleanup for a dangling non-blocking connect */ + flush_work(&smc->connect_work); + kfree(smc->connect_info); + smc->connect_info = NULL; + if (sk->sk_state == SMC_LISTEN) /* smc_close_non_accepted() is called and acquires * sock lock for child sockets again @@ -188,6 +195,7 @@ static struct sock *smc_sock_alloc(struct net *net, struct socket *sock, sk->sk_protocol = protocol; smc = smc_sk(sk); INIT_WORK(&smc->tcp_listen_work, smc_tcp_listen_work); + INIT_WORK(&smc->connect_work, smc_connect_work); INIT_DELAYED_WORK(&smc->conn.tx_work, smc_tx_work); INIT_LIST_HEAD(&smc->accept_q); spin_lock_init(&smc->accept_q_lock); @@ -707,6 +715,35 @@ static int __smc_connect(struct smc_sock *smc) return 0; } +static void smc_connect_work(struct work_struct *work) +{ + struct smc_sock *smc = container_of(work, struct smc_sock, + connect_work); + int rc; + + lock_sock(&smc->sk); + rc = kernel_connect(smc->clcsock, &smc->connect_info->addr, + smc->connect_info->alen, smc->connect_info->flags); + if (smc->clcsock->sk->sk_err) { + smc->sk.sk_err = smc->clcsock->sk->sk_err; + goto out; + } + if (rc < 0) { + smc->sk.sk_err = -rc; + goto out; + } + + rc = __smc_connect(smc); + if (rc < 0) + smc->sk.sk_err = -rc; + +out: + smc->sk.sk_state_change(&smc->sk); + kfree(smc->connect_info); + smc->connect_info = NULL; + release_sock(&smc->sk); +} + static int smc_connect(struct socket *sock, struct sockaddr *addr, int alen, int flags) { @@ -736,15 +773,32 @@ static int smc_connect(struct socket *sock, struct sockaddr *addr, smc_copy_sock_settings_to_clc(smc); tcp_sk(smc->clcsock->sk)->syn_smc = 1; - rc = kernel_connect(smc->clcsock, addr, alen, flags); - if (rc) - goto out; + if (flags & O_NONBLOCK) { + if (smc->connect_info) { + rc = -EALREADY; + goto out; + } + smc->connect_info = kzalloc(alen + 2 * sizeof(int), GFP_KERNEL); + if (!smc->connect_info) { + rc = -ENOMEM; + goto out; + } + smc->connect_info->alen = alen; + smc->connect_info->flags = flags ^ O_NONBLOCK; + memcpy(&smc->connect_info->addr, addr, alen); + schedule_work(&smc->connect_work); + rc = -EINPROGRESS; + } else { + rc = kernel_connect(smc->clcsock, addr, alen, flags); + if (rc) + goto out; - rc = __smc_connect(smc); - if (rc < 0) - goto out; - else - rc = 0; /* success cases including fallback */ + rc = __smc_connect(smc); + if (rc < 0) + goto out; + else + rc = 0; /* success cases including fallback */ + } out: release_sock(sk); @@ -1455,39 +1509,23 @@ static __poll_t smc_accept_poll(struct sock *parent) return mask; } -static __poll_t smc_poll_mask(struct socket *sock, __poll_t events) +static __poll_t smc_poll(struct file *file, struct socket *sock, + poll_table *wait) { struct sock *sk = sock->sk; __poll_t mask = 0; struct smc_sock *smc; - int rc; if (!sk) return EPOLLNVAL; smc = smc_sk(sock->sk); - sock_hold(sk); - lock_sock(sk); if ((sk->sk_state == SMC_INIT) || smc->use_fallback) { /* delegate to CLC child sock */ - release_sock(sk); - mask = smc->clcsock->ops->poll_mask(smc->clcsock, events); - lock_sock(sk); + mask = smc->clcsock->ops->poll(file, smc->clcsock, wait); sk->sk_err = smc->clcsock->sk->sk_err; - if (sk->sk_err) { + if (sk->sk_err) mask |= EPOLLERR; - } else { - /* if non-blocking connect finished ... */ - if (sk->sk_state == SMC_INIT && - mask & EPOLLOUT && - smc->clcsock->sk->sk_state != TCP_CLOSE) { - rc = __smc_connect(smc); - if (rc < 0) - mask |= EPOLLERR; - /* success cases including fallback */ - mask |= EPOLLOUT | EPOLLWRNORM; - } - } } else { if (sk->sk_err) mask |= EPOLLERR; @@ -1516,8 +1554,6 @@ static __poll_t smc_poll_mask(struct socket *sock, __poll_t events) mask |= EPOLLPRI; } - release_sock(sk); - sock_put(sk); return mask; } @@ -1801,7 +1837,7 @@ static const struct proto_ops smc_sock_ops = { .socketpair = sock_no_socketpair, .accept = smc_accept, .getname = smc_getname, - .poll_mask = smc_poll_mask, + .poll = smc_poll, .ioctl = smc_ioctl, .listen = smc_listen, .shutdown = smc_shutdown, |