Commit 3f0465a9 authored by Pablo Neira Ayuso's avatar Pablo Neira Ayuso
Browse files

netfilter: nf_tables: dynamically allocate hooks per net_device in flowtables



Use a list of hooks per device instead an array.

Signed-off-by: default avatarPablo Neira Ayuso <pablo@netfilter.org>
parent 71a8a63b
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -963,6 +963,12 @@ struct nft_stats {
	struct u64_stats_sync	syncp;
};

struct nft_hook {
	struct list_head	list;
	struct nf_hook_ops	ops;
	struct rcu_head		rcu;
};

/**
 *	struct nft_base_chain - nf_tables base chain
 *
@@ -1173,7 +1179,7 @@ struct nft_flowtable {
					use:30;
	u64				handle;
	/* runtime data below here */
	struct nf_hook_ops		*ops ____cacheline_aligned;
	struct list_head		hook_list ____cacheline_aligned;
	struct nf_flowtable		data;
};

+151 −102
Original line number Diff line number Diff line
@@ -1508,6 +1508,76 @@ static void nf_tables_chain_destroy(struct nft_ctx *ctx)
	}
}

static struct nft_hook *nft_netdev_hook_alloc(struct net *net,
					      const struct nlattr *attr)
{
	struct net_device *dev;
	char ifname[IFNAMSIZ];
	struct nft_hook *hook;
	int err;

	hook = kmalloc(sizeof(struct nft_hook), GFP_KERNEL);
	if (!hook) {
		err = -ENOMEM;
		goto err_hook_alloc;
	}

	nla_strlcpy(ifname, attr, IFNAMSIZ);
	dev = __dev_get_by_name(net, ifname);
	if (!dev) {
		err = -ENOENT;
		goto err_hook_dev;
	}
	hook->ops.dev = dev;

	return hook;

err_hook_dev:
	kfree(hook);
err_hook_alloc:
	return ERR_PTR(err);
}

static int nf_tables_parse_netdev_hooks(struct net *net,
					const struct nlattr *attr,
					struct list_head *hook_list)
{
	struct nft_hook *hook, *next;
	const struct nlattr *tmp;
	int rem, n = 0, err;

	nla_for_each_nested(tmp, attr, rem) {
		if (nla_type(tmp) != NFTA_DEVICE_NAME) {
			err = -EINVAL;
			goto err_hook;
		}

		hook = nft_netdev_hook_alloc(net, tmp);
		if (IS_ERR(hook)) {
			err = PTR_ERR(hook);
			goto err_hook;
		}
		list_add_tail(&hook->list, hook_list);
		n++;

		if (n == NFT_FLOWTABLE_DEVICE_MAX) {
			err = -EFBIG;
			goto err_hook;
		}
	}
	if (!n)
		return -EINVAL;

	return 0;

err_hook:
	list_for_each_entry_safe(hook, next, hook_list, list) {
		list_del(&hook->list);
		kfree(hook);
	}
	return err;
}

struct nft_chain_hook {
	u32				num;
	s32				priority;
@@ -5628,43 +5698,6 @@ nft_flowtable_lookup_byhandle(const struct nft_table *table,
       return ERR_PTR(-ENOENT);
}

static int nf_tables_parse_devices(const struct nft_ctx *ctx,
				   const struct nlattr *attr,
				   struct net_device *dev_array[], int *len)
{
	const struct nlattr *tmp;
	struct net_device *dev;
	char ifname[IFNAMSIZ];
	int rem, n = 0, err;

	nla_for_each_nested(tmp, attr, rem) {
		if (nla_type(tmp) != NFTA_DEVICE_NAME) {
			err = -EINVAL;
			goto err1;
		}

		nla_strlcpy(ifname, tmp, IFNAMSIZ);
		dev = __dev_get_by_name(ctx->net, ifname);
		if (!dev) {
			err = -ENOENT;
			goto err1;
		}

		dev_array[n++] = dev;
		if (n == NFT_FLOWTABLE_DEVICE_MAX) {
			err = -EFBIG;
			goto err1;
		}
	}
	if (!len)
		return -EINVAL;

	err = 0;
err1:
	*len = n;
	return err;
}

static const struct nla_policy nft_flowtable_hook_policy[NFTA_FLOWTABLE_HOOK_MAX + 1] = {
	[NFTA_FLOWTABLE_HOOK_NUM]	= { .type = NLA_U32 },
	[NFTA_FLOWTABLE_HOOK_PRIORITY]	= { .type = NLA_U32 },
@@ -5675,11 +5708,10 @@ static int nf_tables_flowtable_parse_hook(const struct nft_ctx *ctx,
					  const struct nlattr *attr,
					  struct nft_flowtable *flowtable)
{
	struct net_device *dev_array[NFT_FLOWTABLE_DEVICE_MAX];
	struct nlattr *tb[NFTA_FLOWTABLE_HOOK_MAX + 1];
	struct nf_hook_ops *ops;
	struct nft_hook *hook;
	int hooknum, priority;
	int err, n = 0, i;
	int err;

	err = nla_parse_nested_deprecated(tb, NFTA_FLOWTABLE_HOOK_MAX, attr,
					  nft_flowtable_hook_policy, NULL);
@@ -5697,27 +5729,21 @@ static int nf_tables_flowtable_parse_hook(const struct nft_ctx *ctx,

	priority = ntohl(nla_get_be32(tb[NFTA_FLOWTABLE_HOOK_PRIORITY]));

	err = nf_tables_parse_devices(ctx, tb[NFTA_FLOWTABLE_HOOK_DEVS],
				      dev_array, &n);
	err = nf_tables_parse_netdev_hooks(ctx->net,
					   tb[NFTA_FLOWTABLE_HOOK_DEVS],
					   &flowtable->hook_list);
	if (err < 0)
		return err;

	ops = kcalloc(n, sizeof(struct nf_hook_ops), GFP_KERNEL);
	if (!ops)
		return -ENOMEM;

	flowtable->hooknum		= hooknum;
	flowtable->data.priority	= priority;
	flowtable->ops			= ops;
	flowtable->ops_len		= n;

	for (i = 0; i < n; i++) {
		flowtable->ops[i].pf		= NFPROTO_NETDEV;
		flowtable->ops[i].hooknum	= hooknum;
		flowtable->ops[i].priority	= priority;
		flowtable->ops[i].priv		= &flowtable->data;
		flowtable->ops[i].hook		= flowtable->data.type->hook;
		flowtable->ops[i].dev		= dev_array[i];
	list_for_each_entry(hook, &flowtable->hook_list, list) {
		hook->ops.pf		= NFPROTO_NETDEV;
		hook->ops.hooknum	= hooknum;
		hook->ops.priority	= priority;
		hook->ops.priv		= &flowtable->data;
		hook->ops.hook		= flowtable->data.type->hook;
	}

	return err;
@@ -5757,14 +5783,51 @@ nft_flowtable_type_get(struct net *net, u8 family)
static void nft_unregister_flowtable_net_hooks(struct net *net,
					       struct nft_flowtable *flowtable)
{
	int i;
	struct nft_hook *hook;

	for (i = 0; i < flowtable->ops_len; i++) {
		if (!flowtable->ops[i].dev)
			continue;
	list_for_each_entry(hook, &flowtable->hook_list, list)
		nf_unregister_net_hook(net, &hook->ops);
}

static int nft_register_flowtable_net_hooks(struct net *net,
					    struct nft_table *table,
					    struct nft_flowtable *flowtable)
{
	struct nft_hook *hook, *hook2, *next;
	struct nft_flowtable *ft;
	int err, i = 0;

	list_for_each_entry(hook, &flowtable->hook_list, list) {
		list_for_each_entry(ft, &table->flowtables, list) {
			list_for_each_entry(hook2, &ft->hook_list, list) {
				if (hook->ops.dev == hook2->ops.dev &&
				    hook->ops.pf == hook2->ops.pf) {
					err = -EBUSY;
					goto err_unregister_net_hooks;
				}
			}
		}

		err = nf_register_net_hook(net, &hook->ops);
		if (err < 0)
			goto err_unregister_net_hooks;

		i++;
	}

		nf_unregister_net_hook(net, &flowtable->ops[i]);
	return 0;

err_unregister_net_hooks:
	list_for_each_entry_safe(hook, next, &flowtable->hook_list, list) {
		if (i-- <= 0)
			break;

		nf_unregister_net_hook(net, &hook->ops);
		list_del_rcu(&hook->list);
		kfree_rcu(hook, rcu);
	}

	return err;
}

static int nf_tables_newflowtable(struct net *net, struct sock *nlsk,
@@ -5775,12 +5838,13 @@ static int nf_tables_newflowtable(struct net *net, struct sock *nlsk,
{
	const struct nfgenmsg *nfmsg = nlmsg_data(nlh);
	const struct nf_flowtable_type *type;
	struct nft_flowtable *flowtable, *ft;
	u8 genmask = nft_genmask_next(net);
	int family = nfmsg->nfgen_family;
	struct nft_flowtable *flowtable;
	struct nft_hook *hook, *next;
	struct nft_table *table;
	struct nft_ctx ctx;
	int err, i, k;
	int err;

	if (!nla[NFTA_FLOWTABLE_TABLE] ||
	    !nla[NFTA_FLOWTABLE_NAME] ||
@@ -5819,6 +5883,7 @@ static int nf_tables_newflowtable(struct net *net, struct sock *nlsk,

	flowtable->table = table;
	flowtable->handle = nf_tables_alloc_handle(table);
	INIT_LIST_HEAD(&flowtable->hook_list);

	flowtable->name = nla_strdup(nla[NFTA_FLOWTABLE_NAME], GFP_KERNEL);
	if (!flowtable->name) {
@@ -5842,43 +5907,24 @@ static int nf_tables_newflowtable(struct net *net, struct sock *nlsk,
	if (err < 0)
		goto err4;

	for (i = 0; i < flowtable->ops_len; i++) {
		if (!flowtable->ops[i].dev)
			continue;

		list_for_each_entry(ft, &table->flowtables, list) {
			for (k = 0; k < ft->ops_len; k++) {
				if (!ft->ops[k].dev)
					continue;

				if (flowtable->ops[i].dev == ft->ops[k].dev &&
				    flowtable->ops[i].pf == ft->ops[k].pf) {
					err = -EBUSY;
					goto err5;
				}
			}
		}

		err = nf_register_net_hook(net, &flowtable->ops[i]);
	err = nft_register_flowtable_net_hooks(ctx.net, table, flowtable);
	if (err < 0)
			goto err5;
	}
		goto err4;

	err = nft_trans_flowtable_add(&ctx, NFT_MSG_NEWFLOWTABLE, flowtable);
	if (err < 0)
		goto err6;
		goto err5;

	list_add_tail_rcu(&flowtable->list, &table->flowtables);
	table->use++;

	return 0;
err6:
	i = flowtable->ops_len;
err5:
	for (k = i - 1; k >= 0; k--)
		nf_unregister_net_hook(net, &flowtable->ops[k]);

	kfree(flowtable->ops);
	list_for_each_entry_safe(hook, next, &flowtable->hook_list, list) {
		nf_unregister_net_hook(net, &hook->ops);
		list_del_rcu(&hook->list);
		kfree_rcu(hook, rcu);
	}
err4:
	flowtable->data.type->free(&flowtable->data);
err3:
@@ -5945,8 +5991,8 @@ static int nf_tables_fill_flowtable_info(struct sk_buff *skb, struct net *net,
{
	struct nlattr *nest, *nest_devs;
	struct nfgenmsg *nfmsg;
	struct nft_hook *hook;
	struct nlmsghdr *nlh;
	int i;

	event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event);
	nlh = nlmsg_put(skb, portid, seq, event, sizeof(struct nfgenmsg), flags);
@@ -5976,11 +6022,8 @@ static int nf_tables_fill_flowtable_info(struct sk_buff *skb, struct net *net,
	if (!nest_devs)
		goto nla_put_failure;

	for (i = 0; i < flowtable->ops_len; i++) {
		const struct net_device *dev = READ_ONCE(flowtable->ops[i].dev);

		if (dev &&
		    nla_put_string(skb, NFTA_DEVICE_NAME, dev->name))
	list_for_each_entry_rcu(hook, &flowtable->hook_list, list) {
		if (nla_put_string(skb, NFTA_DEVICE_NAME, hook->ops.dev->name))
			goto nla_put_failure;
	}
	nla_nest_end(skb, nest_devs);
@@ -6171,7 +6214,12 @@ err:

static void nf_tables_flowtable_destroy(struct nft_flowtable *flowtable)
{
	kfree(flowtable->ops);
	struct nft_hook *hook, *next;

	list_for_each_entry_safe(hook, next, &flowtable->hook_list, list) {
		list_del_rcu(&hook->list);
		kfree(hook);
	}
	kfree(flowtable->name);
	flowtable->data.type->free(&flowtable->data);
	module_put(flowtable->data.type->owner);
@@ -6211,14 +6259,15 @@ nla_put_failure:
static void nft_flowtable_event(unsigned long event, struct net_device *dev,
				struct nft_flowtable *flowtable)
{
	int i;
	struct nft_hook *hook;

	for (i = 0; i < flowtable->ops_len; i++) {
		if (flowtable->ops[i].dev != dev)
	list_for_each_entry(hook, &flowtable->hook_list, list) {
		if (hook->ops.dev != dev)
			continue;

		nf_unregister_net_hook(dev_net(dev), &flowtable->ops[i]);
		flowtable->ops[i].dev = NULL;
		nf_unregister_net_hook(dev_net(dev), &hook->ops);
		list_del_rcu(&hook->list);
		kfree_rcu(hook, rcu);
		break;
	}
}