diff --git a/drivers/block/nbd.c b/drivers/block/nbd.c index fe63f3c55d0d..9003baf52be8 100644 --- a/drivers/block/nbd.c +++ b/drivers/block/nbd.c @@ -45,6 +45,8 @@ #include #include #include +#include +#include #define CREATE_TRACE_POINTS #include @@ -302,6 +304,19 @@ static int nbd_disconnected(struct nbd_config *config) test_bit(NBD_RT_DISCONNECT_REQUESTED, &config->runtime_flags); } +static void nbd_sock_shutdown(struct sock *sk) +{ + if (sk_is_stream_unix(sk)) { + kernel_sock_shutdown(sk->sk_socket, SHUT_RDWR); + return; + } + + if (lock_sock_try(sk)) { + inet_shutdown_locked(sk->sk_socket, SHUT_RDWR); + release_sock(sk); + } +} + static void nbd_mark_nsock_dead(struct nbd_device *nbd, struct nbd_sock *nsock, int notify) { @@ -315,7 +330,8 @@ static void nbd_mark_nsock_dead(struct nbd_device *nbd, struct nbd_sock *nsock, } } if (!nsock->dead) { - kernel_sock_shutdown(nsock->sock, SHUT_RDWR); + nbd_sock_shutdown(nsock->sock->sk); + if (atomic_dec_return(&nbd->config->live_connections) == 0) { if (test_and_clear_bit(NBD_RT_DISCONNECT_REQUESTED, &nbd->config->runtime_flags)) { @@ -548,6 +564,22 @@ static enum blk_eh_timer_return nbd_xmit_timeout(struct request *req) return BLK_EH_DONE; } +static int nbd_sock_sendmsg(struct socket *sock, struct msghdr *msg) +{ + struct sock *sk = sock->sk; + int err = -ERESTARTSYS; + + if (sk_is_stream_unix(sk)) + return sock_sendmsg(sock, msg); + + if (lock_sock_try(sk)) { + err = tcp_sendmsg_locked(sk, msg, msg_data_left(msg)); + release_sock(sk); + } + + return err; +} + static int __sock_xmit(struct nbd_device *nbd, struct socket *sock, int send, struct iov_iter *iter, int msg_flags, int *sent) { @@ -573,7 +605,7 @@ static int __sock_xmit(struct nbd_device *nbd, struct socket *sock, int send, msg.msg_flags = msg_flags | MSG_NOSIGNAL; if (send) - result = sock_sendmsg(sock, &msg); + result = nbd_sock_sendmsg(sock, &msg); else result = sock_recvmsg(sock, &msg, msg.msg_flags); @@ -1228,6 +1260,13 @@ static struct socket *nbd_get_socket(struct nbd_device *nbd, unsigned long fd, return NULL; } + if (READ_ONCE(sock->sk->sk_state) != TCP_ESTABLISHED) { + dev_err(disk_to_dev(nbd->disk), "Unsupported socket: not connected yet.\n"); + *err = -ENOTCONN; + sockfd_put(sock); + return NULL; + } + if (sock->ops->shutdown == sock_no_shutdown) { dev_err(disk_to_dev(nbd->disk), "Unsupported socket: shutdown callout must be supported.\n"); *err = -EINVAL; diff --git a/include/net/inet_common.h b/include/net/inet_common.h index 5dd2bf24449e..c085c39573c9 100644 --- a/include/net/inet_common.h +++ b/include/net/inet_common.h @@ -38,6 +38,7 @@ void inet_splice_eof(struct socket *sock); int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, int flags); int inet_shutdown(struct socket *sock, int how); +int inet_shutdown_locked(struct socket *sock, int how); int inet_listen(struct socket *sock, int backlog); int __inet_listen_sk(struct sock *sk, int backlog); void inet_sock_destruct(struct sock *sk); diff --git a/include/net/sock.h b/include/net/sock.h index 6c9a83016e95..203a60661fce 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1710,6 +1710,24 @@ static inline void lock_sock(struct sock *sk) } void __lock_sock(struct sock *sk); + +static inline bool lock_sock_try(struct sock *sk) +{ + if (!spin_trylock_bh(&sk->sk_lock.slock)) + return false; + + if (sk->sk_lock.owned) { + spin_unlock_bh(&sk->sk_lock.slock); + return false; + } + + sk->sk_lock.owned = 1; + spin_unlock_bh(&sk->sk_lock.slock); + + mutex_acquire(&sk->sk_lock.dep_map, 0, 1, _RET_IP_); + return true; +} + void __release_sock(struct sock *sk); void release_sock(struct sock *sk); diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 8036e76aa1e4..bde8ddcc28f0 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -896,21 +896,11 @@ int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, } EXPORT_SYMBOL(inet_recvmsg); -int inet_shutdown(struct socket *sock, int how) +static int __inet_shutdown(struct socket *sock, int how) { struct sock *sk = sock->sk; int err = 0; - /* This should really check to make sure - * the socket is a TCP socket. (WHY AC...) - */ - how++; /* maps 0->1 has the advantage of making bit 1 rcvs and - 1->2 bit 2 snds. - 2->3 */ - if ((how & ~SHUTDOWN_MASK) || !how) /* MAXINT->0 */ - return -EINVAL; - - lock_sock(sk); if (sock->state == SS_CONNECTING) { if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE)) @@ -947,11 +937,45 @@ int inet_shutdown(struct socket *sock, int how) /* Wake up anyone sleeping in poll. */ sk->sk_state_change(sk); + + return err; +} + +int inet_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + int err; + + /* maps 0->1 has the advantage of making bit 1 rcvs and + * 1->2 bit 2 snds. + * 2->3 + */ + how++; + + if ((how & ~SHUTDOWN_MASK) || !how) + return -EINVAL; + + lock_sock(sk); + err = __inet_shutdown(sock, how); release_sock(sk); + return err; } EXPORT_SYMBOL(inet_shutdown); +int inet_shutdown_locked(struct socket *sock, int how) +{ + sock_owned_by_me(sock->sk); + + how++; + + if ((how & ~SHUTDOWN_MASK) || !how) + return -EINVAL; + + return __inet_shutdown(sock, how); +} +EXPORT_SYMBOL_GPL(inet_shutdown_locked); + /* * ioctl() calls you can issue on an INET socket. Most of these are * device configuration and stuff and very rarely used. Some ioctls