Commit 71489e21 authored by Joe Stringer's avatar Joe Stringer Committed by Alexei Starovoitov
Browse files

net: Track socket refcounts in skb_steal_sock()



Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
the determination of whether the socket is reference counted in the case
where it is prefetched by earlier logic such as early_demux.

Signed-off-by: default avatarJoe Stringer <joe@wand.net.nz>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Acked-by: default avatarMartin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz
parent cf7fbe66
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -85,9 +85,8 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
					      int iif, int sdif,
					      bool *refcounted)
{
	struct sock *sk = skb_steal_sock(skb);
	struct sock *sk = skb_steal_sock(skb, refcounted);

	*refcounted = true;
	if (sk)
		return sk;

+1 −2
Original line number Diff line number Diff line
@@ -379,10 +379,9 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
					     const int sdif,
					     bool *refcounted)
{
	struct sock *sk = skb_steal_sock(skb);
	struct sock *sk = skb_steal_sock(skb, refcounted);
	const struct iphdr *iph = ip_hdr(skb);

	*refcounted = true;
	if (sk)
		return sk;

+9 −1
Original line number Diff line number Diff line
@@ -2537,15 +2537,23 @@ skb_sk_is_prefetched(struct sk_buff *skb)
#endif /* CONFIG_INET */
}

static inline struct sock *skb_steal_sock(struct sk_buff *skb)
/**
 * skb_steal_sock
 * @skb to steal the socket from
 * @refcounted is set to true if the socket is reference-counted
 */
static inline struct sock *
skb_steal_sock(struct sk_buff *skb, bool *refcounted)
{
	if (skb->sk) {
		struct sock *sk = skb->sk;

		*refcounted = true;
		skb->destructor = NULL;
		skb->sk = NULL;
		return sk;
	}
	*refcounted = false;
	return NULL;
}

+4 −2
Original line number Diff line number Diff line
@@ -2288,6 +2288,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
	struct rtable *rt = skb_rtable(skb);
	__be32 saddr, daddr;
	struct net *net = dev_net(skb->dev);
	bool refcounted;

	/*
	 *  Validate the packet.
@@ -2313,7 +2314,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
	if (udp4_csum_init(skb, uh, proto))
		goto csum_error;

	sk = skb_steal_sock(skb);
	sk = skb_steal_sock(skb, &refcounted);
	if (sk) {
		struct dst_entry *dst = skb_dst(skb);
		int ret;
@@ -2322,6 +2323,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
			udp_sk_rx_dst_set(sk, dst);

		ret = udp_unicast_rcv_skb(sk, skb, uh);
		if (refcounted)
			sock_put(sk);
		return ret;
	}
+6 −3
Original line number Diff line number Diff line
@@ -843,6 +843,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
	struct net *net = dev_net(skb->dev);
	struct udphdr *uh;
	struct sock *sk;
	bool refcounted;
	u32 ulen = 0;

	if (!pskb_may_pull(skb, sizeof(struct udphdr)))
@@ -879,7 +880,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
		goto csum_error;

	/* Check if the socket is already available, e.g. due to early demux */
	sk = skb_steal_sock(skb);
	sk = skb_steal_sock(skb, &refcounted);
	if (sk) {
		struct dst_entry *dst = skb_dst(skb);
		int ret;
@@ -888,11 +889,13 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
			udp6_sk_rx_dst_set(sk, dst);

		if (!uh->check && !udp_sk(sk)->no_check6_rx) {
			if (refcounted)
				sock_put(sk);
			goto report_csum_error;
		}

		ret = udp6_unicast_rcv_skb(sk, skb, uh);
		if (refcounted)
			sock_put(sk);
		return ret;
	}