Commit f4d4c49b authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'rtnetlink-rework-handler-registration'



Florian Westphal says:

====================
rtnetlink: rework handler (un)registering

Peter Zijlstra reported (referring to commit 019a3169,
"rtnetlink: add reference counting to prevent module unload while dump is in progress"):

 1) it not in fact a refcount, so using refcount_t is silly
 2) there is a distinct lack of memory barriers, so we can easily
    observe the decrement while the msg_handler is still in progress.
 3) waiting with a schedule()/yield() loop is complete crap and subject
    life-locks, imagine doing that rtnl_unregister_all() from a RT task.

In ancient times rtnetlink exposed a statically-sized table with
preset doit/dumpit handlers to be called for a protocol/type pair.

Later the rtnl_register interface was added and the table was allocated
on demand.  Eventually these were also used by modules.

Problem is that nothing prevents module unload while a netlink dump
is in progress.  netlink dumps can be span multiple recv calls and
netlink core saves the to-be-repeated dumper address for later invocation.

To prevent rmmod the netlink core expects callers to pass in the owning
module so a reference can be taken.

So far rtnetlink wasn't doing this, add new interface to pass THIS_MODULE.
Moreover, when converting parts of the rtnetlink handling to rcu this code
gained way too many READ_ONCE spots, remove them and the extra refcounting.

Take a module reference when running dumpit and doit callbacks
and never alter content of rtnl_link structures after they have been
published via rcu_assign_pointer.

Based partially on earlier patch from Peter.
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 9753c21f 16feebcf
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -13,10 +13,10 @@ enum rtnl_link_flags {
	RTNL_FLAG_DOIT_UNLOCKED = 1,
};

int __rtnl_register(int protocol, int msgtype,
		    rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
void rtnl_register(int protocol, int msgtype,
		   rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
int rtnl_register_module(struct module *owner, int protocol, int msgtype,
			 rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
int rtnl_unregister(int protocol, int msgtype);
void rtnl_unregister_all(int protocol);

+3 −3
Original line number Diff line number Diff line
@@ -760,9 +760,9 @@ static int br_mdb_del(struct sk_buff *skb, struct nlmsghdr *nlh,

void br_mdb_init(void)
{
	rtnl_register(PF_BRIDGE, RTM_GETMDB, NULL, br_mdb_dump, 0);
	rtnl_register(PF_BRIDGE, RTM_NEWMDB, br_mdb_add, NULL, 0);
	rtnl_register(PF_BRIDGE, RTM_DELMDB, br_mdb_del, NULL, 0);
	rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_GETMDB, NULL, br_mdb_dump, 0);
	rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_NEWMDB, br_mdb_add, NULL, 0);
	rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_DELMDB, br_mdb_del, NULL, 0);
}

void br_mdb_uninit(void)
+10 −4
Original line number Diff line number Diff line
@@ -1014,6 +1014,8 @@ static struct pernet_operations cangw_pernet_ops = {

static __init int cgw_module_init(void)
{
	int ret;

	/* sanitize given module parameter */
	max_hops = clamp_t(unsigned int, max_hops, CGW_MIN_HOPS, CGW_MAX_HOPS);

@@ -1031,15 +1033,19 @@ static __init int cgw_module_init(void)
	notifier.notifier_call = cgw_notifier;
	register_netdevice_notifier(&notifier);

	if (__rtnl_register(PF_CAN, RTM_GETROUTE, NULL, cgw_dump_jobs, 0)) {
	ret = rtnl_register_module(THIS_MODULE, PF_CAN, RTM_GETROUTE,
				   NULL, cgw_dump_jobs, 0);
	if (ret) {
		unregister_netdevice_notifier(&notifier);
		kmem_cache_destroy(cgw_cache);
		return -ENOBUFS;
	}

	/* Only the first call to __rtnl_register can fail */
	__rtnl_register(PF_CAN, RTM_NEWROUTE, cgw_create_job, NULL, 0);
	__rtnl_register(PF_CAN, RTM_DELROUTE, cgw_remove_job, NULL, 0);
	/* Only the first call to rtnl_register_module can fail */
	rtnl_register_module(THIS_MODULE, PF_CAN, RTM_NEWROUTE,
			     cgw_create_job, NULL, 0);
	rtnl_register_module(THIS_MODULE, PF_CAN, RTM_DELROUTE,
			     cgw_remove_job, NULL, 0);

	return 0;
}
+172 −98
Original line number Diff line number Diff line
@@ -62,7 +62,9 @@
struct rtnl_link {
	rtnl_doit_func		doit;
	rtnl_dumpit_func	dumpit;
	struct module		*owner;
	unsigned int		flags;
	struct rcu_head		rcu;
};

static DEFINE_MUTEX(rtnl_mutex);
@@ -127,8 +129,7 @@ bool lockdep_rtnl_is_held(void)
EXPORT_SYMBOL(lockdep_rtnl_is_held);
#endif /* #ifdef CONFIG_PROVE_LOCKING */

static struct rtnl_link __rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];
static struct rtnl_link __rcu **rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];

static inline int rtm_msgindex(int msgtype)
{
@@ -144,70 +145,125 @@ static inline int rtm_msgindex(int msgtype)
	return msgindex;
}

/**
 * __rtnl_register - Register a rtnetlink message type
 * @protocol: Protocol family or PF_UNSPEC
 * @msgtype: rtnetlink message type
 * @doit: Function pointer called for each request message
 * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
 * @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
 *
 * Registers the specified function pointers (at least one of them has
 * to be non-NULL) to be called whenever a request message for the
 * specified protocol family and message type is received.
 *
 * The special protocol family PF_UNSPEC may be used to define fallback
 * function pointers for the case when no entry for the specific protocol
 * family exists.
 *
 * Returns 0 on success or a negative error code.
 */
int __rtnl_register(int protocol, int msgtype,
static struct rtnl_link *rtnl_get_link(int protocol, int msgtype)
{
	struct rtnl_link **tab;

	if (protocol >= ARRAY_SIZE(rtnl_msg_handlers))
		protocol = PF_UNSPEC;

	tab = rcu_dereference_rtnl(rtnl_msg_handlers[protocol]);
	if (!tab)
		tab = rcu_dereference_rtnl(rtnl_msg_handlers[PF_UNSPEC]);

	return tab[msgtype];
}

static int rtnl_register_internal(struct module *owner,
				  int protocol, int msgtype,
				  rtnl_doit_func doit, rtnl_dumpit_func dumpit,
				  unsigned int flags)
{
	struct rtnl_link *tab;
	struct rtnl_link **tab, *link, *old;
	int msgindex;
	int ret = -ENOBUFS;

	BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
	msgindex = rtm_msgindex(msgtype);

	tab = rcu_dereference_raw(rtnl_msg_handlers[protocol]);
	rtnl_lock();
	tab = rtnl_msg_handlers[protocol];
	if (tab == NULL) {
		tab = kcalloc(RTM_NR_MSGTYPES, sizeof(*tab), GFP_KERNEL);
		if (tab == NULL)
			return -ENOBUFS;
		tab = kcalloc(RTM_NR_MSGTYPES, sizeof(void *), GFP_KERNEL);
		if (!tab)
			goto unlock;

		/* ensures we see the 0 stores */
		rcu_assign_pointer(rtnl_msg_handlers[protocol], tab);
	}

	old = rtnl_dereference(tab[msgindex]);
	if (old) {
		link = kmemdup(old, sizeof(*old), GFP_KERNEL);
		if (!link)
			goto unlock;
	} else {
		link = kzalloc(sizeof(*link), GFP_KERNEL);
		if (!link)
			goto unlock;
	}

	WARN_ON(link->owner && link->owner != owner);
	link->owner = owner;

	WARN_ON(doit && link->doit && link->doit != doit);
	if (doit)
		tab[msgindex].doit = doit;
		link->doit = doit;
	WARN_ON(dumpit && link->dumpit && link->dumpit != dumpit);
	if (dumpit)
		tab[msgindex].dumpit = dumpit;
	tab[msgindex].flags |= flags;
		link->dumpit = dumpit;

	return 0;
	link->flags |= flags;

	/* publish protocol:msgtype */
	rcu_assign_pointer(tab[msgindex], link);
	ret = 0;
	if (old)
		kfree_rcu(old, rcu);
unlock:
	rtnl_unlock();
	return ret;
}

/**
 * rtnl_register_module - Register a rtnetlink message type
 *
 * @owner: module registering the hook (THIS_MODULE)
 * @protocol: Protocol family or PF_UNSPEC
 * @msgtype: rtnetlink message type
 * @doit: Function pointer called for each request message
 * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
 * @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
 *
 * Like rtnl_register, but for use by removable modules.
 */
int rtnl_register_module(struct module *owner,
			 int protocol, int msgtype,
			 rtnl_doit_func doit, rtnl_dumpit_func dumpit,
			 unsigned int flags)
{
	return rtnl_register_internal(owner, protocol, msgtype,
				      doit, dumpit, flags);
}
EXPORT_SYMBOL_GPL(__rtnl_register);
EXPORT_SYMBOL_GPL(rtnl_register_module);

/**
 * rtnl_register - Register a rtnetlink message type
 * @protocol: Protocol family or PF_UNSPEC
 * @msgtype: rtnetlink message type
 * @doit: Function pointer called for each request message
 * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
 * @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
 *
 * Registers the specified function pointers (at least one of them has
 * to be non-NULL) to be called whenever a request message for the
 * specified protocol family and message type is received.
 *
 * Identical to __rtnl_register() but panics on failure. This is useful
 * as failure of this function is very unlikely, it can only happen due
 * to lack of memory when allocating the chain to store all message
 * handlers for a protocol. Meant for use in init functions where lack
 * of memory implies no sense in continuing.
 * The special protocol family PF_UNSPEC may be used to define fallback
 * function pointers for the case when no entry for the specific protocol
 * family exists.
 */
void rtnl_register(int protocol, int msgtype,
		   rtnl_doit_func doit, rtnl_dumpit_func dumpit,
		   unsigned int flags)
{
	if (__rtnl_register(protocol, msgtype, doit, dumpit, flags) < 0)
		panic("Unable to register rtnetlink message handler, "
		      "protocol = %d, message type = %d\n",
		      protocol, msgtype);
	int err;

	err = rtnl_register_internal(NULL, protocol, msgtype, doit, dumpit,
				     flags);
	if (err)
		pr_err("Unable to register rtnetlink message handler, "
		       "protocol = %d, message type = %d\n", protocol, msgtype);
}
EXPORT_SYMBOL_GPL(rtnl_register);

@@ -220,24 +276,25 @@ EXPORT_SYMBOL_GPL(rtnl_register);
 */
int rtnl_unregister(int protocol, int msgtype)
{
	struct rtnl_link *handlers;
	struct rtnl_link **tab, *link;
	int msgindex;

	BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
	msgindex = rtm_msgindex(msgtype);

	rtnl_lock();
	handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
	if (!handlers) {
	tab = rtnl_dereference(rtnl_msg_handlers[protocol]);
	if (!tab) {
		rtnl_unlock();
		return -ENOENT;
	}

	handlers[msgindex].doit = NULL;
	handlers[msgindex].dumpit = NULL;
	handlers[msgindex].flags = 0;
	link = tab[msgindex];
	rcu_assign_pointer(tab[msgindex], NULL);
	rtnl_unlock();

	kfree_rcu(link, rcu);

	return 0;
}
EXPORT_SYMBOL_GPL(rtnl_unregister);
@@ -251,20 +308,27 @@ EXPORT_SYMBOL_GPL(rtnl_unregister);
 */
void rtnl_unregister_all(int protocol)
{
	struct rtnl_link *handlers;
	struct rtnl_link **tab, *link;
	int msgindex;

	BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);

	rtnl_lock();
	handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
	tab = rtnl_msg_handlers[protocol];
	RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL);
	for (msgindex = 0; msgindex < RTM_NR_MSGTYPES; msgindex++) {
		link = tab[msgindex];
		if (!link)
			continue;

		rcu_assign_pointer(tab[msgindex], NULL);
		kfree_rcu(link, rcu);
	}
	rtnl_unlock();

	synchronize_net();

	while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 1)
		schedule();
	kfree(handlers);
	kfree(tab);
}
EXPORT_SYMBOL_GPL(rtnl_unregister_all);

@@ -2973,18 +3037,26 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb)
		s_idx = 1;

	for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) {
		struct rtnl_link **tab;
		int type = cb->nlh->nlmsg_type-RTM_BASE;
		struct rtnl_link *handlers;
		struct rtnl_link *link;
		rtnl_dumpit_func dumpit;

		if (idx < s_idx || idx == PF_PACKET)
			continue;

		handlers = rtnl_dereference(rtnl_msg_handlers[idx]);
		if (!handlers)
		if (type < 0 || type >= RTM_NR_MSGTYPES)
			continue;

		dumpit = READ_ONCE(handlers[type].dumpit);
		tab = rcu_dereference_rtnl(rtnl_msg_handlers[idx]);
		if (!tab)
			continue;

		link = tab[type];
		if (!link)
			continue;

		dumpit = link->dumpit;
		if (!dumpit)
			continue;

@@ -4314,7 +4386,8 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
			     struct netlink_ext_ack *extack)
{
	struct net *net = sock_net(skb->sk);
	struct rtnl_link *handlers;
	struct rtnl_link *link;
	struct module *owner;
	int err = -EOPNOTSUPP;
	rtnl_doit_func doit;
	unsigned int flags;
@@ -4338,79 +4411,85 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
	if (kind != 2 && !netlink_net_capable(skb, CAP_NET_ADMIN))
		return -EPERM;

	if (family >= ARRAY_SIZE(rtnl_msg_handlers))
		family = PF_UNSPEC;

	rcu_read_lock();
	handlers = rcu_dereference(rtnl_msg_handlers[family]);
	if (!handlers) {
		family = PF_UNSPEC;
		handlers = rcu_dereference(rtnl_msg_handlers[family]);
	}

	if (kind == 2 && nlh->nlmsg_flags&NLM_F_DUMP) {
		struct sock *rtnl;
		rtnl_dumpit_func dumpit;
		u16 min_dump_alloc = 0;

		dumpit = READ_ONCE(handlers[type].dumpit);
		if (!dumpit) {
		link = rtnl_get_link(family, type);
		if (!link || !link->dumpit) {
			family = PF_UNSPEC;
			handlers = rcu_dereference(rtnl_msg_handlers[PF_UNSPEC]);
			if (!handlers)
				goto err_unlock;

			dumpit = READ_ONCE(handlers[type].dumpit);
			if (!dumpit)
			link = rtnl_get_link(family, type);
			if (!link || !link->dumpit)
				goto err_unlock;
		}

		refcount_inc(&rtnl_msg_handlers_ref[family]);
		owner = link->owner;
		dumpit = link->dumpit;

		if (type == RTM_GETLINK - RTM_BASE)
			min_dump_alloc = rtnl_calcit(skb, nlh);

		err = 0;
		/* need to do this before rcu_read_unlock() */
		if (!try_module_get(owner))
			err = -EPROTONOSUPPORT;

		rcu_read_unlock();

		rtnl = net->rtnl;
		{
		if (err == 0) {
			struct netlink_dump_control c = {
				.dump		= dumpit,
				.min_dump_alloc	= min_dump_alloc,
				.module		= owner,
			};
			err = netlink_dump_start(rtnl, skb, nlh, &c);
			/* netlink_dump_start() will keep a reference on
			 * module if dump is still in progress.
			 */
			module_put(owner);
		}
		refcount_dec(&rtnl_msg_handlers_ref[family]);
		return err;
	}

	doit = READ_ONCE(handlers[type].doit);
	if (!doit) {
	link = rtnl_get_link(family, type);
	if (!link || !link->doit) {
		family = PF_UNSPEC;
		handlers = rcu_dereference(rtnl_msg_handlers[family]);
		link = rtnl_get_link(PF_UNSPEC, type);
		if (!link || !link->doit)
			goto out_unlock;
	}

	owner = link->owner;
	if (!try_module_get(owner)) {
		err = -EPROTONOSUPPORT;
		goto out_unlock;
	}

	flags = READ_ONCE(handlers[type].flags);
	flags = link->flags;
	if (flags & RTNL_FLAG_DOIT_UNLOCKED) {
		refcount_inc(&rtnl_msg_handlers_ref[family]);
		doit = READ_ONCE(handlers[type].doit);
		doit = link->doit;
		rcu_read_unlock();
		if (doit)
			err = doit(skb, nlh, extack);
		refcount_dec(&rtnl_msg_handlers_ref[family]);
		module_put(owner);
		return err;
	}

	rcu_read_unlock();

	rtnl_lock();
	handlers = rtnl_dereference(rtnl_msg_handlers[family]);
	if (handlers) {
		doit = READ_ONCE(handlers[type].doit);
		if (doit)
			err = doit(skb, nlh, extack);
	}
	link = rtnl_get_link(family, type);
	if (link && link->doit)
		err = link->doit(skb, nlh, extack);
	rtnl_unlock();

	module_put(owner);

	return err;

out_unlock:
	rcu_read_unlock();
	return err;

err_unlock:
@@ -4498,11 +4577,6 @@ static struct pernet_operations rtnetlink_net_ops = {

void __init rtnetlink_init(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(rtnl_msg_handlers_ref); i++)
		refcount_set(&rtnl_msg_handlers_ref[i], 1);

	if (register_pernet_subsys(&rtnetlink_net_ops))
		panic("rtnetlink_init: cannot initialize rtnetlink\n");

+6 −3
Original line number Diff line number Diff line
@@ -1418,9 +1418,12 @@ void __init dn_dev_init(void)

	dn_dev_devices_on();

	rtnl_register(PF_DECnet, RTM_NEWADDR, dn_nl_newaddr, NULL, 0);
	rtnl_register(PF_DECnet, RTM_DELADDR, dn_nl_deladdr, NULL, 0);
	rtnl_register(PF_DECnet, RTM_GETADDR, NULL, dn_nl_dump_ifaddr, 0);
	rtnl_register_module(THIS_MODULE, PF_DECnet, RTM_NEWADDR,
			     dn_nl_newaddr, NULL, 0);
	rtnl_register_module(THIS_MODULE, PF_DECnet, RTM_DELADDR,
			     dn_nl_deladdr, NULL, 0);
	rtnl_register_module(THIS_MODULE, PF_DECnet, RTM_GETADDR,
			     NULL, dn_nl_dump_ifaddr, 0);

	proc_create("decnet_dev", S_IRUGO, init_net.proc_net, &dn_dev_seq_fops);

Loading