Commit 2bcd3502 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'wg-fixes'



Jason A. Donenfeld says:

====================
wireguard fixes for 5.8-rc3

This series contains two fixes, one cosmetic and one quite important:

1) Avoid the `if ((x = f()) == y)` pattern, from Frank
   Werner-Krippendorf.

2) Mitigate a potential memory leak by creating circular netns
   references, while also making the netns semantics a bit more
   robust.

Patch (2) has a "Fixes:" line and should be backported to stable.
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents f7fb92ac 900575aa
Loading
Loading
Loading
Loading
+27 −31
Original line number Diff line number Diff line
@@ -45,17 +45,18 @@ static int wg_open(struct net_device *dev)
	if (dev_v6)
		dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;

	mutex_lock(&wg->device_update_lock);
	ret = wg_socket_init(wg, wg->incoming_port);
	if (ret < 0)
		return ret;
	mutex_lock(&wg->device_update_lock);
		goto out;
	list_for_each_entry(peer, &wg->peer_list, peer_list) {
		wg_packet_send_staged_packets(peer);
		if (peer->persistent_keepalive_interval)
			wg_packet_send_keepalive(peer);
	}
out:
	mutex_unlock(&wg->device_update_lock);
	return 0;
	return ret;
}

#ifdef CONFIG_PM_SLEEP
@@ -225,6 +226,7 @@ static void wg_destruct(struct net_device *dev)
	list_del(&wg->device_list);
	rtnl_unlock();
	mutex_lock(&wg->device_update_lock);
	rcu_assign_pointer(wg->creating_net, NULL);
	wg->incoming_port = 0;
	wg_socket_reinit(wg, NULL, NULL);
	/* The final references are cleared in the below calls to destroy_workqueue. */
@@ -240,13 +242,11 @@ static void wg_destruct(struct net_device *dev)
	skb_queue_purge(&wg->incoming_handshakes);
	free_percpu(dev->tstats);
	free_percpu(wg->incoming_handshakes_worker);
	if (wg->have_creating_net_ref)
		put_net(wg->creating_net);
	kvfree(wg->index_hashtable);
	kvfree(wg->peer_hashtable);
	mutex_unlock(&wg->device_update_lock);

	pr_debug("%s: Interface deleted\n", dev->name);
	pr_debug("%s: Interface destroyed\n", dev->name);
	free_netdev(dev);
}

@@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
	struct wg_device *wg = netdev_priv(dev);
	int ret = -ENOMEM;

	wg->creating_net = src_net;
	rcu_assign_pointer(wg->creating_net, src_net);
	init_rwsem(&wg->static_identity.lock);
	mutex_init(&wg->socket_update_lock);
	mutex_init(&wg->device_update_lock);
@@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __read_mostly = {
	.newlink		= wg_newlink,
};

static int wg_netdevice_notification(struct notifier_block *nb,
				     unsigned long action, void *data)
static void wg_netns_pre_exit(struct net *net)
{
	struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
	struct wg_device *wg = netdev_priv(dev);

	ASSERT_RTNL();

	if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
		return 0;
	struct wg_device *wg;

	if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
		put_net(wg->creating_net);
		wg->have_creating_net_ref = false;
	} else if (dev_net(dev) != wg->creating_net &&
		   !wg->have_creating_net_ref) {
		wg->have_creating_net_ref = true;
		get_net(wg->creating_net);
	rtnl_lock();
	list_for_each_entry(wg, &device_list, device_list) {
		if (rcu_access_pointer(wg->creating_net) == net) {
			pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
			netif_carrier_off(wg->dev);
			mutex_lock(&wg->device_update_lock);
			rcu_assign_pointer(wg->creating_net, NULL);
			wg_socket_reinit(wg, NULL, NULL);
			mutex_unlock(&wg->device_update_lock);
		}
	return 0;
	}
	rtnl_unlock();
}

static struct notifier_block netdevice_notifier = {
	.notifier_call = wg_netdevice_notification
static struct pernet_operations pernet_ops = {
	.pre_exit = wg_netns_pre_exit
};

int __init wg_device_init(void)
@@ -429,18 +425,18 @@ int __init wg_device_init(void)
		return ret;
#endif

	ret = register_netdevice_notifier(&netdevice_notifier);
	ret = register_pernet_device(&pernet_ops);
	if (ret)
		goto error_pm;

	ret = rtnl_link_register(&link_ops);
	if (ret)
		goto error_netdevice;
		goto error_pernet;

	return 0;

error_netdevice:
	unregister_netdevice_notifier(&netdevice_notifier);
error_pernet:
	unregister_pernet_device(&pernet_ops);
error_pm:
#ifdef CONFIG_PM_SLEEP
	unregister_pm_notifier(&pm_notifier);
@@ -451,7 +447,7 @@ error_pm:
void wg_device_uninit(void)
{
	rtnl_link_unregister(&link_ops);
	unregister_netdevice_notifier(&netdevice_notifier);
	unregister_pernet_device(&pernet_ops);
#ifdef CONFIG_PM_SLEEP
	unregister_pm_notifier(&pm_notifier);
#endif
+1 −2
Original line number Diff line number Diff line
@@ -40,7 +40,7 @@ struct wg_device {
	struct net_device *dev;
	struct crypt_queue encrypt_queue, decrypt_queue;
	struct sock __rcu *sock4, *sock6;
	struct net *creating_net;
	struct net __rcu *creating_net;
	struct noise_static_identity static_identity;
	struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
	struct workqueue_struct *packet_crypt_wq;
@@ -56,7 +56,6 @@ struct wg_device {
	unsigned int num_peers, device_update_gen;
	u32 fwmark;
	u16 incoming_port;
	bool have_creating_net_ref;
};

int wg_device_init(void);
+9 −5
Original line number Diff line number Diff line
@@ -511,11 +511,15 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
	if (flags & ~__WGDEVICE_F_ALL)
		goto out;

	ret = -EPERM;
	if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
	     info->attrs[WGDEVICE_A_FWMARK]) &&
	    !ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
	if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
		struct net *net;
		rcu_read_lock();
		net = rcu_dereference(wg->creating_net);
		ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
		rcu_read_unlock();
		if (ret)
			goto out;
	}

	++wg->device_update_gen;

+2 −2
Original line number Diff line number Diff line
@@ -617,8 +617,8 @@ wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
	handshake->remote_index = src->sender_index;
	if ((s64)(handshake->last_initiation_consumption -
	    (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
	initiation_consumption = ktime_get_coarse_boottime_ns();
	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
		handshake->last_initiation_consumption = initiation_consumption;
	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
	up_write(&handshake->lock);
+18 −7
Original line number Diff line number Diff line
@@ -347,6 +347,7 @@ static void set_sock_opts(struct socket *sock)

int wg_socket_init(struct wg_device *wg, u16 port)
{
	struct net *net;
	int ret;
	struct udp_tunnel_sock_cfg cfg = {
		.sk_user_data = wg,
@@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg, u16 port)
	};
#endif

	rcu_read_lock();
	net = rcu_dereference(wg->creating_net);
	net = net ? maybe_get_net(net) : NULL;
	rcu_read_unlock();
	if (unlikely(!net))
		return -ENONET;

#if IS_ENABLED(CONFIG_IPV6)
retry:
#endif

	ret = udp_sock_create(wg->creating_net, &port4, &new4);
	ret = udp_sock_create(net, &port4, &new4);
	if (ret < 0) {
		pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
		return ret;
		goto out;
	}
	set_sock_opts(new4);
	setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
	setup_udp_tunnel_sock(net, new4, &cfg);

#if IS_ENABLED(CONFIG_IPV6)
	if (ipv6_mod_enabled()) {
		port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
		ret = udp_sock_create(wg->creating_net, &port6, &new6);
		ret = udp_sock_create(net, &port6, &new6);
		if (ret < 0) {
			udp_tunnel_sock_release(new4);
			if (ret == -EADDRINUSE && !port && retries++ < 100)
				goto retry;
			pr_err("%s: Could not create IPv6 socket\n",
			       wg->dev->name);
			return ret;
			goto out;
		}
		set_sock_opts(new6);
		setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
		setup_udp_tunnel_sock(net, new6, &cfg);
	}
#endif

	wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
	return 0;
	ret = 0;
out:
	put_net(net);
	return ret;
}

void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
Loading