Commit df1c6316 authored by David Ahern's avatar David Ahern Committed by David S. Miller
Browse files

net: mpls: Limit memory allocation for mpls_route



Limit memory allocation size for mpls_route to 4096.

Signed-off-by: default avatarDavid Ahern <dsa@cumulusnetworks.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 59b20966
Loading
Loading
Loading
Loading
+21 −10
Original line number Diff line number Diff line
@@ -26,6 +26,9 @@

#define MAX_NEW_LABELS 2

/* max memory we will use for mpls_route */
#define MAX_MPLS_ROUTE_MEM	4096

/* Maximum number of labels to look ahead at when selecting a path of
 * a multipath route
 */
@@ -477,14 +480,20 @@ static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels)
{
	u8 nh_size = MPLS_NH_SIZE(max_labels, max_alen);
	struct mpls_route *rt;
	size_t size;

	size = sizeof(*rt) + num_nh * nh_size;
	if (size > MAX_MPLS_ROUTE_MEM)
		return ERR_PTR(-EINVAL);

	rt = kzalloc(size, GFP_KERNEL);
	if (!rt)
		return ERR_PTR(-ENOMEM);

	rt = kzalloc(sizeof(*rt) + num_nh * nh_size, GFP_KERNEL);
	if (rt) {
	rt->rt_nhn = num_nh;
	rt->rt_nhn_alive = num_nh;
	rt->rt_nh_size = nh_size;
	rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels);
	}

	return rt;
}
@@ -898,8 +907,10 @@ static int mpls_route_add(struct mpls_route_config *cfg)

	err = -ENOMEM;
	rt = mpls_rt_alloc(nhs, max_via_alen, MAX_NEW_LABELS);
	if (!rt)
	if (IS_ERR(rt)) {
		err = PTR_ERR(rt);
		goto errout;
	}

	rt->rt_protocol = cfg->rc_protocol;
	rt->rt_payload_type = cfg->rc_payload_type;
@@ -1970,7 +1981,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
	if (limit > MPLS_LABEL_IPV4NULL) {
		struct net_device *lo = net->loopback_dev;
		rt0 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS);
		if (!rt0)
		if (IS_ERR(rt0))
			goto nort0;
		RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo);
		rt0->rt_protocol = RTPROT_KERNEL;
@@ -1984,7 +1995,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
	if (limit > MPLS_LABEL_IPV6NULL) {
		struct net_device *lo = net->loopback_dev;
		rt2 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS);
		if (!rt2)
		if (IS_ERR(rt2))
			goto nort2;
		RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo);
		rt2->rt_protocol = RTPROT_KERNEL;