Commit e4202511 authored by Florian Westphal's avatar Florian Westphal Committed by David S. Miller
Browse files

rtnetlink: get reference on module before invoking handlers



Add yet another rtnl_register function.  It will be used by modules
that can be removed.

The passed module struct is used to prevent module unload while
a netlink dump is in progress or when a DOIT_UNLOCKED doit callback
is called.

Cc: Peter Zijlstra <peterz@infradead.org>
Signed-off-by: default avatarFlorian Westphal <fw@strlen.de>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent addf9b90
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -17,6 +17,8 @@ int __rtnl_register(int protocol, int msgtype,
		    rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
void rtnl_register(int protocol, int msgtype,
		   rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
int rtnl_register_module(struct module *owner, int protocol, int msgtype,
			 rtnl_doit_func, rtnl_dumpit_func, unsigned int flags);
int rtnl_unregister(int protocol, int msgtype);
void rtnl_unregister_all(int protocol);

+78 −35
Original line number Diff line number Diff line
@@ -62,6 +62,7 @@
struct rtnl_link {
	rtnl_doit_func		doit;
	rtnl_dumpit_func	dumpit;
	struct module		*owner;
	unsigned int		flags;
	struct rcu_head		rcu;
};
@@ -129,7 +130,6 @@ EXPORT_SYMBOL(lockdep_rtnl_is_held);
#endif /* #ifdef CONFIG_PROVE_LOCKING */

static struct rtnl_link __rcu **rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];

static inline int rtm_msgindex(int msgtype)
{
@@ -159,25 +159,8 @@ static struct rtnl_link *rtnl_get_link(int protocol, int msgtype)
	return tab[msgtype];
}

/**
 * __rtnl_register - Register a rtnetlink message type
 * @protocol: Protocol family or PF_UNSPEC
 * @msgtype: rtnetlink message type
 * @doit: Function pointer called for each request message
 * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
 * @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
 *
 * Registers the specified function pointers (at least one of them has
 * to be non-NULL) to be called whenever a request message for the
 * specified protocol family and message type is received.
 *
 * The special protocol family PF_UNSPEC may be used to define fallback
 * function pointers for the case when no entry for the specific protocol
 * family exists.
 *
 * Returns 0 on success or a negative error code.
 */
int __rtnl_register(int protocol, int msgtype,
static int rtnl_register_internal(struct module *owner,
				  int protocol, int msgtype,
				  rtnl_doit_func doit, rtnl_dumpit_func dumpit,
				  unsigned int flags)
{
@@ -210,6 +193,9 @@ int __rtnl_register(int protocol, int msgtype,
			goto unlock;
	}

	WARN_ON(link->owner && link->owner != owner);
	link->owner = owner;

	WARN_ON(doit && link->doit && link->doit != doit);
	if (doit)
		link->doit = doit;
@@ -228,6 +214,54 @@ unlock:
	rtnl_unlock();
	return ret;
}

/**
 * rtnl_register_module - Register a rtnetlink message type
 *
 * @owner: module registering the hook (THIS_MODULE)
 * @protocol: Protocol family or PF_UNSPEC
 * @msgtype: rtnetlink message type
 * @doit: Function pointer called for each request message
 * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
 * @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
 *
 * Like rtnl_register, but for use by removable modules.
 */
int rtnl_register_module(struct module *owner,
			 int protocol, int msgtype,
			 rtnl_doit_func doit, rtnl_dumpit_func dumpit,
			 unsigned int flags)
{
	return rtnl_register_internal(owner, protocol, msgtype,
				      doit, dumpit, flags);
}
EXPORT_SYMBOL_GPL(rtnl_register_module);

/**
 * __rtnl_register - Register a rtnetlink message type
 * @protocol: Protocol family or PF_UNSPEC
 * @msgtype: rtnetlink message type
 * @doit: Function pointer called for each request message
 * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message
 * @flags: rtnl_link_flags to modifiy behaviour of doit/dumpit functions
 *
 * Registers the specified function pointers (at least one of them has
 * to be non-NULL) to be called whenever a request message for the
 * specified protocol family and message type is received.
 *
 * The special protocol family PF_UNSPEC may be used to define fallback
 * function pointers for the case when no entry for the specific protocol
 * family exists.
 *
 * Returns 0 on success or a negative error code.
 */
int __rtnl_register(int protocol, int msgtype,
		    rtnl_doit_func doit, rtnl_dumpit_func dumpit,
		    unsigned int flags)
{
	return rtnl_register_internal(NULL, protocol, msgtype,
				      doit, dumpit, flags);
}
EXPORT_SYMBOL_GPL(__rtnl_register);

/**
@@ -311,8 +345,6 @@ void rtnl_unregister_all(int protocol)

	synchronize_net();

	while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 1)
		schedule();
	kfree(tab);
}
EXPORT_SYMBOL_GPL(rtnl_unregister_all);
@@ -4372,6 +4404,7 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
{
	struct net *net = sock_net(skb->sk);
	struct rtnl_link *link;
	struct module *owner;
	int err = -EOPNOTSUPP;
	rtnl_doit_func doit;
	unsigned int flags;
@@ -4408,24 +4441,32 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
			if (!link || !link->dumpit)
				goto err_unlock;
		}
		owner = link->owner;
		dumpit = link->dumpit;

		refcount_inc(&rtnl_msg_handlers_ref[family]);

		if (type == RTM_GETLINK - RTM_BASE)
			min_dump_alloc = rtnl_calcit(skb, nlh);

		err = 0;
		/* need to do this before rcu_read_unlock() */
		if (!try_module_get(owner))
			err = -EPROTONOSUPPORT;

		rcu_read_unlock();

		rtnl = net->rtnl;
		{
		if (err == 0) {
			struct netlink_dump_control c = {
				.dump		= dumpit,
				.min_dump_alloc	= min_dump_alloc,
				.module		= owner,
			};
			err = netlink_dump_start(rtnl, skb, nlh, &c);
			/* netlink_dump_start() will keep a reference on
			 * module if dump is still in progress.
			 */
			module_put(owner);
		}
		refcount_dec(&rtnl_msg_handlers_ref[family]);
		return err;
	}

@@ -4437,14 +4478,19 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
			goto out_unlock;
	}

	owner = link->owner;
	if (!try_module_get(owner)) {
		err = -EPROTONOSUPPORT;
		goto out_unlock;
	}

	flags = link->flags;
	if (flags & RTNL_FLAG_DOIT_UNLOCKED) {
		refcount_inc(&rtnl_msg_handlers_ref[family]);
		doit = link->doit;
		rcu_read_unlock();
		if (doit)
			err = doit(skb, nlh, extack);
		refcount_dec(&rtnl_msg_handlers_ref[family]);
		module_put(owner);
		return err;
	}
	rcu_read_unlock();
@@ -4455,6 +4501,8 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
		err = link->doit(skb, nlh, extack);
	rtnl_unlock();

	module_put(owner);

	return err;

out_unlock:
@@ -4546,11 +4594,6 @@ static struct pernet_operations rtnetlink_net_ops = {

void __init rtnetlink_init(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(rtnl_msg_handlers_ref); i++)
		refcount_set(&rtnl_msg_handlers_ref[i], 1);

	if (register_pernet_subsys(&rtnetlink_net_ops))
		panic("rtnetlink_init: cannot initialize rtnetlink\n");