diff --git a/include/net/udp.h b/include/net/udp.h
index a061d1b22ddc..ebe53658fbc2 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -402,13 +402,13 @@ void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len);
 int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb);
 void udp_skb_destructor(struct sock *sk, struct sk_buff *skb);
 struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags, int *off,
-			       int *err);
+			       int *err, bool reclaim);
 static inline struct sk_buff *skb_recv_udp(struct sock *sk, unsigned int flags,
 					   int *err)
 {
 	int off = 0;
 
-	return __skb_recv_udp(sk, flags, &off, err);
+	return __skb_recv_udp(sk, flags, &off, err, true);
 }
 
 enum skb_drop_reason udp_v4_early_demux(struct sk_buff *skb);
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 2ac7731e1e0a..1d5a5c31e0cc 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -595,7 +595,8 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb,
 	 * from skb_consume found in __tcp_bpf_recvmsg() after its been copied
 	 * into user buffers.
 	 */
-	skb_set_owner_r(skb, sk);
+	if (!sk_is_udp(sk))
+		skb_set_owner_r(skb, sk);
 	err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg, true);
 	if (err < 0)
 		kfree(msg);
@@ -615,7 +616,8 @@ static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb
 	if (unlikely(!msg))
 		return -EAGAIN;
 
-	skb_set_owner_r(skb, sk);
+	if (!sk_is_udp(sk))
+		skb_set_owner_r(skb, sk);
 	err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg, take_ref);
 	if (err < 0)
 		kfree(msg);
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index ee63af0ef42c..82826e6e67d7 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -1951,7 +1951,7 @@ int udp_ioctl(struct sock *sk, int cmd, int *karg)
 EXPORT_IPV6_MOD(udp_ioctl);
 
 struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
-			       int *off, int *err)
+			       int *off, int *err, bool reclaim)
 {
 	struct sk_buff_head *sk_queue = &sk->sk_receive_queue;
 	struct sk_buff_head *queue;
@@ -1974,8 +1974,14 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
 			skb = __skb_try_recv_from_queue(queue, flags, off, err,
 							&last);
 			if (skb) {
-				if (!(flags & MSG_PEEK))
+				if (!(flags & MSG_PEEK)) {
+					if (reclaim) {
+						/* released at the reclaim label */
+						spin_lock(&sk_queue->lock);
+						goto reclaim;
+					}
 					udp_skb_destructor(sk, skb);
+				}
 				spin_unlock_bh(&queue->lock);
 				return skb;
 			}
@@ -1995,8 +2001,18 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
 
 			skb = __skb_try_recv_from_queue(queue, flags, off, err,
 							&last);
-			if (skb && !(flags & MSG_PEEK))
-				udp_skb_dtor_locked(sk, skb);
+			if (skb && !(flags & MSG_PEEK)) {
+				if (reclaim) {
+					int size;
+reclaim:
+					size = udp_skb_truesize(skb);
+					sk_forward_alloc_add(sk, size);
+					atomic_sub(size, &sk->sk_rmem_alloc);
+				} else {
+					udp_skb_dtor_locked(sk, skb);
+				}
+			}
+
 			spin_unlock(&sk_queue->lock);
 			spin_unlock_bh(&queue->lock);
 			if (skb)
@@ -2067,7 +2083,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
 
 try_again:
 	off = sk_peek_offset(sk, flags);
-	skb = __skb_recv_udp(sk, flags, &off, &err);
+	skb = __skb_recv_udp(sk, flags, &off, &err, false);
 	if (!skb)
 		return err;
 
diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
index 0735d820e413..b6878d5a7bdd 100644
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -51,7 +51,9 @@ static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 	ret = udp_msg_has_data(sk, psock);
 	if (!ret) {
+		release_sock(sk);
 		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
+		lock_sock(sk);
 		ret = udp_msg_has_data(sk, psock);
 	}
 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
@@ -80,6 +82,7 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		goto out;
 	}
 
+	lock_sock(sk);
 msg_bytes_ready:
 	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
 	if (!copied) {
@@ -91,11 +94,17 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		if (data) {
 			if (psock_has_data(psock))
 				goto msg_bytes_ready;
+
+			release_sock(sk);
+
 			ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
 			goto out;
 		}
 		copied = -EAGAIN;
 	}
+
+	release_sock(sk);
+
 	ret = copied;
 out:
 	sk_psock_put(sk, psock);
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 794c13674e8a..c25dd32f0320 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -485,7 +485,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 
 try_again:
 	off = sk_peek_offset(sk, flags);
-	skb = __skb_recv_udp(sk, flags, &off, &err);
+	skb = __skb_recv_udp(sk, flags, &off, &err, false);
 	if (!skb)
 		return err;
 