Commit 2eaa8575 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'net-tls-fix-scatter-gather-list-issues'



Jakub Kicinski says:

====================
net: tls: fix scatter-gather list issues

This series kicked of by a syzbot report fixes three issues around
scatter gather handling in the TLS code. First patch fixes a use-
-after-free situation which may occur if record was freed on error.
This could have already happened in BPF paths, and patch 2 now makes
the same condition occur in non-BPF code.

Patch 2 fixes the problem spotted by syzbot. If encryption failed
we have to clean the end markings from scatter gather list. As
suggested by John the patch frees the record entirely and caller
may retry copying data from user space buffer again.

Third patch fixes a bug in the TLS 1.3 code spotted while working
on patch 2. TLS 1.3 may effectively overflow the SG list which
leads to the BUG() in sg_page() being triggered.

Patch 4 adds a test case which triggers this bug reliably.

Next two patches are small cleanups of dead code and code which
makes dangerous assumptions.

Last but not least two minor improvements to the sockmap tests.

Tested:
 - bpf/test_sockmap
 - net/tls
 - syzbot repro (which used error injection, hence no direct
   selftest is added to preserve it).
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 81b6b964 e5dc9dd3
Loading
Loading
Loading
Loading
+14 −14
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@
#include <net/strparser.h>

#define MAX_MSG_FRAGS			MAX_SKB_FRAGS
#define NR_MSG_FRAG_IDS			(MAX_MSG_FRAGS + 1)

enum __sk_action {
	__SK_DROP = 0,
@@ -29,13 +30,15 @@ struct sk_msg_sg {
	u32				size;
	u32				copybreak;
	unsigned long			copy;
	/* The extra element is used for chaining the front and sections when
	 * the list becomes partitioned (e.g. end < start). The crypto APIs
	 * require the chaining.
	/* The extra two elements:
	 * 1) used for chaining the front and sections when the list becomes
	 *    partitioned (e.g. end < start). The crypto APIs require the
	 *    chaining;
	 * 2) to chain tailer SG entries after the message.
	 */
	struct scatterlist		data[MAX_MSG_FRAGS + 1];
	struct scatterlist		data[MAX_MSG_FRAGS + 2];
};
static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS);
static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);

/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
@@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)

static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
	return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
	return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}

#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = MAX_MSG_FRAGS - 1;	\
			var = NR_MSG_FRAG_IDS - 1;	\
		else					\
			var--;				\
	} while (0)
@@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end)
#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == MAX_MSG_FRAGS)		\
		if (var == NR_MSG_FRAG_IDS)		\
			var = 0;			\
	} while (0)

@@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)

static inline void sk_msg_init(struct sk_msg *msg)
{
	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
	memset(msg, 0, sizeof(*msg));
	sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
	sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}

static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
@@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)

static inline bool sk_msg_full(const struct sk_msg *msg)
{
	return (msg->sg.end == msg->sg.start) && msg->sg.size;
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}

static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
	if (sk_msg_full(msg))
		return MAX_MSG_FRAGS;

	return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}

+1 −2
Original line number Diff line number Diff line
@@ -100,7 +100,6 @@ struct tls_rec {
	struct list_head list;
	int tx_ready;
	int tx_flags;
	int inplace_crypto;

	struct sk_msg msg_plaintext;
	struct sk_msg msg_encrypted;
@@ -377,7 +376,7 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx,
		int flags);
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
			    int flags);
bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);

static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{
+4 −4
Original line number Diff line number Diff line
@@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
	WARN_ON_ONCE(last_sge == first_sge);
	shift = last_sge > first_sge ?
		last_sge - first_sge - 1 :
		MAX_SKB_FRAGS - first_sge + last_sge - 1;
		NR_MSG_FRAG_IDS - first_sge + last_sge - 1;
	if (!shift)
		goto out;

@@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
	do {
		u32 move_from;

		if (i + shift >= MAX_MSG_FRAGS)
			move_from = i + shift - MAX_MSG_FRAGS;
		if (i + shift >= NR_MSG_FRAG_IDS)
			move_from = i + shift - NR_MSG_FRAG_IDS;
		else
			move_from = i + shift;
		if (move_from == msg->sg.end)
@@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
	} while (1);

	msg->sg.end = msg->sg.end - shift > msg->sg.end ?
		      msg->sg.end - shift + MAX_MSG_FRAGS :
		      msg->sg.end - shift + NR_MSG_FRAG_IDS :
		      msg->sg.end - shift;
out:
	msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
+1 −1
Original line number Diff line number Diff line
@@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.size = copied;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
	msg->sg.end = num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
+1 −1
Original line number Diff line number Diff line
@@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;
Loading