Commit 53cb0995 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'wireguard-fixes'



Jason A. Donenfeld says:

====================
wireguard fixes for 5.7-rc7

Hopefully these are the last fixes for 5.7:

1) A trivial bump in the selftest harness to support gcc-10.
   build.wireguard.com is still on gcc-9 but I'll probably switch to
   gcc-10 in the coming weeks.

2) A concurrency fix regarding userspace modifying the pre-shared key at
   the same time as packets are being processed, reported by Matt
   Dunwoodie.

3) We were previously clearing skb->hash on egress, which broke
   fq_codel, cake, and other things that actually make use of the flow
   hash for queueing, reported by Dave Taht and Toke Høiland-Jørgensen.

4) A fix for the increased memory usage caused by (3). This can be
   thought of as part of patch (3), but because of its separate
   reasoning and breadth, I thought it was a bit cleaner to put it in
   a standalone commit.

Fixes (2), (3), and (4) are -stable material.
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 20a785aa a9e90d99
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ enum cookie_values {
};

enum counter_values {
	COUNTER_BITS_TOTAL = 2048,
	COUNTER_BITS_TOTAL = 8192,
	COUNTER_REDUNDANT_BITS = BITS_PER_LONG,
	COUNTER_WINDOW_SIZE = COUNTER_BITS_TOTAL - COUNTER_REDUNDANT_BITS
};
+9 −13
Original line number Diff line number Diff line
@@ -104,6 +104,7 @@ static struct noise_keypair *keypair_create(struct wg_peer *peer)

	if (unlikely(!keypair))
		return NULL;
	spin_lock_init(&keypair->receiving_counter.lock);
	keypair->internal_id = atomic64_inc_return(&keypair_counter);
	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
	keypair->entry.peer = peer;
@@ -358,25 +359,16 @@ out:
	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}

static void symmetric_key_init(struct noise_symmetric_key *key)
{
	spin_lock_init(&key->counter.receive.lock);
	atomic64_set(&key->counter.counter, 0);
	memset(key->counter.receive.backtrack, 0,
	       sizeof(key->counter.receive.backtrack));
	key->birthdate = ktime_get_coarse_boottime_ns();
	key->is_valid = true;
}

static void derive_keys(struct noise_symmetric_key *first_dst,
			struct noise_symmetric_key *second_dst,
			const u8 chaining_key[NOISE_HASH_LEN])
{
	u64 birthdate = ktime_get_coarse_boottime_ns();
	kdf(first_dst->key, second_dst->key, NULL, NULL,
	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
	    chaining_key);
	symmetric_key_init(first_dst);
	symmetric_key_init(second_dst);
	first_dst->birthdate = second_dst->birthdate = birthdate;
	first_dst->is_valid = second_dst->is_valid = true;
}

static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
@@ -715,6 +707,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
	u8 e[NOISE_PUBLIC_KEY_LEN];
	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
	u8 static_private[NOISE_PUBLIC_KEY_LEN];
	u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];

	down_read(&wg->static_identity.lock);

@@ -733,6 +726,8 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
	memcpy(ephemeral_private, handshake->ephemeral_private,
	       NOISE_PUBLIC_KEY_LEN);
	memcpy(preshared_key, handshake->preshared_key,
	       NOISE_SYMMETRIC_KEY_LEN);
	up_read(&handshake->lock);

	if (state != HANDSHAKE_CREATED_INITIATION)
@@ -750,7 +745,7 @@ wg_noise_handshake_consume_response(struct message_handshake_response *src,
		goto fail;

	/* psk */
	mix_psk(chaining_key, hash, key, handshake->preshared_key);
	mix_psk(chaining_key, hash, key, preshared_key);

	/* {} */
	if (!message_decrypt(NULL, src->encrypted_nothing,
@@ -783,6 +778,7 @@ out:
	memzero_explicit(chaining_key, NOISE_HASH_LEN);
	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
	memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
	up_read(&wg->static_identity.lock);
	return ret_peer;
}
+6 −8
Original line number Diff line number Diff line
@@ -15,18 +15,14 @@
#include <linux/mutex.h>
#include <linux/kref.h>

union noise_counter {
	struct {
struct noise_replay_counter {
	u64 counter;
		unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
	spinlock_t lock;
	} receive;
	atomic64_t counter;
	unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
};

struct noise_symmetric_key {
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	union noise_counter counter;
	u64 birthdate;
	bool is_valid;
};
@@ -34,7 +30,9 @@ struct noise_symmetric_key {
struct noise_keypair {
	struct index_hashtable_entry entry;
	struct noise_symmetric_key sending;
	atomic64_t sending_counter;
	struct noise_symmetric_key receiving;
	struct noise_replay_counter receiving_counter;
	__le32 remote_index;
	bool i_am_the_initiator;
	struct kref refcount;
+9 −1
Original line number Diff line number Diff line
@@ -87,12 +87,20 @@ static inline bool wg_check_packet_protocol(struct sk_buff *skb)
	return real_protocol && skb->protocol == real_protocol;
}

static inline void wg_reset_packet(struct sk_buff *skb)
static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating)
{
	u8 l4_hash = skb->l4_hash;
	u8 sw_hash = skb->sw_hash;
	u32 hash = skb->hash;
	skb_scrub_packet(skb, true);
	memset(&skb->headers_start, 0,
	       offsetof(struct sk_buff, headers_end) -
		       offsetof(struct sk_buff, headers_start));
	if (encapsulating) {
		skb->l4_hash = l4_hash;
		skb->sw_hash = sw_hash;
		skb->hash = hash;
	}
	skb->queue_mapping = 0;
	skb->nohdr = 0;
	skb->peeked = 0;
+22 −22
Original line number Diff line number Diff line
@@ -245,20 +245,20 @@ static void keep_key_fresh(struct wg_peer *peer)
	}
}

static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
{
	struct scatterlist sg[MAX_SKB_FRAGS + 8];
	struct sk_buff *trailer;
	unsigned int offset;
	int num_frags;

	if (unlikely(!key))
	if (unlikely(!keypair))
		return false;

	if (unlikely(!READ_ONCE(key->is_valid) ||
		  wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
		  key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
		WRITE_ONCE(key->is_valid, false);
	if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
		  wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
		  keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
		WRITE_ONCE(keypair->receiving.is_valid, false);
		return false;
	}

@@ -283,7 +283,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)

	if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
					         PACKET_CB(skb)->nonce,
						 key->key))
						 keypair->receiving.key))
		return false;

	/* Another ugly situation of pushing and pulling the header so as to
@@ -298,41 +298,41 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
}

/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
static bool counter_validate(union noise_counter *counter, u64 their_counter)
static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
{
	unsigned long index, index_current, top, i;
	bool ret = false;

	spin_lock_bh(&counter->receive.lock);
	spin_lock_bh(&counter->lock);

	if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
	if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
		     their_counter >= REJECT_AFTER_MESSAGES))
		goto out;

	++their_counter;

	if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
		     counter->receive.counter))
		     counter->counter))
		goto out;

	index = their_counter >> ilog2(BITS_PER_LONG);

	if (likely(their_counter > counter->receive.counter)) {
		index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
	if (likely(their_counter > counter->counter)) {
		index_current = counter->counter >> ilog2(BITS_PER_LONG);
		top = min_t(unsigned long, index - index_current,
			    COUNTER_BITS_TOTAL / BITS_PER_LONG);
		for (i = 1; i <= top; ++i)
			counter->receive.backtrack[(i + index_current) &
			counter->backtrack[(i + index_current) &
				((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
		counter->receive.counter = their_counter;
		counter->counter = their_counter;
	}

	index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
	ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
				&counter->receive.backtrack[index]);
				&counter->backtrack[index]);

out:
	spin_unlock_bh(&counter->receive.lock);
	spin_unlock_bh(&counter->lock);
	return ret;
}

@@ -472,19 +472,19 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget)
		if (unlikely(state != PACKET_STATE_CRYPTED))
			goto next;

		if (unlikely(!counter_validate(&keypair->receiving.counter,
		if (unlikely(!counter_validate(&keypair->receiving_counter,
					       PACKET_CB(skb)->nonce))) {
			net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
					    peer->device->dev->name,
					    PACKET_CB(skb)->nonce,
					    keypair->receiving.counter.receive.counter);
					    keypair->receiving_counter.counter);
			goto next;
		}

		if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
			goto next;

		wg_reset_packet(skb);
		wg_reset_packet(skb, false);
		wg_packet_consume_data_done(peer, skb, &endpoint);
		free = false;

@@ -511,8 +511,8 @@ void wg_packet_decrypt_worker(struct work_struct *work)
	struct sk_buff *skb;

	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
		enum packet_state state = likely(decrypt_packet(skb,
				&PACKET_CB(skb)->keypair->receiving)) ?
		enum packet_state state =
			likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
				PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
		wg_queue_enqueue_per_peer_napi(skb, state);
		if (need_resched())
Loading