Commit 5682d393 authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov
Browse files

inet_diag: Refactor inet_sk_diag_fill(), dump(), and dump_one()



In a latter patch, there is a need to update "cb->min_dump_alloc"
in inet_sk_diag_fill() as it learns the diffierent bpf_sk_storages
stored in a sk while dumping all sk(s) (e.g. tcp_hashinfo).

The inet_sk_diag_fill() currently does not take the "cb" as an argument.
One of the reason is inet_sk_diag_fill() is used by both dump_one()
and dump() (which belong to the "struct inet_diag_handler".  The dump_one()
interface does not pass the "cb" along.

This patch is to make dump_one() pass a "cb".  The "cb" is created in
inet_diag_cmd_exact().  The "nlh" and "in_skb" are stored in "cb" as
the dump() interface does.  The total number of args in
inet_sk_diag_fill() is also cut from 10 to 7 and
that helps many callers to pass fewer args.

In particular,
"struct user_namespace *user_ns", "u32 pid", and "u32 seq"
can be replaced by accessing "cb->nlh" and "cb->skb".

A similar argument reduction is also made to
inet_twsk_diag_fill() and inet_req_diag_fill().

inet_csk_diag_dump() and inet_csk_diag_fill() are also removed.
They are mostly equivalent to inet_sk_diag_fill().  Their repeated
usages are very limited.  Thus, inet_sk_diag_fill() is directly used
in those occasions.

Signed-off-by: default avatarMartin KaFai Lau <kafai@fb.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Acked-by: default avatarSong Liu <songliubraving@fb.com>
Link: https://lore.kernel.org/bpf/20200225230409.1975173-1-kafai@fb.com
parent d7f10df8
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
@@ -18,8 +18,7 @@ struct inet_diag_handler {
				const struct inet_diag_req_v2 *r,
				struct nlattr *bc);

	int		(*dump_one)(struct sk_buff *in_skb,
				    const struct nlmsghdr *nlh,
	int		(*dump_one)(struct netlink_callback *cb,
				    const struct inet_diag_req_v2 *req);

	void		(*idiag_get_info)(struct sock *sk,
@@ -42,16 +41,15 @@ struct inet_diag_handler {

struct inet_connection_sock;
int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
		      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
		      struct user_namespace *user_ns,
		      u32 pid, u32 seq, u16 nlmsg_flags,
		      const struct nlmsghdr *unlh, bool net_admin);
		      struct sk_buff *skb, struct netlink_callback *cb,
		      const struct inet_diag_req_v2 *req,
		      u16 nlmsg_flags, bool net_admin);
void inet_diag_dump_icsk(struct inet_hashinfo *h, struct sk_buff *skb,
			 struct netlink_callback *cb,
			 const struct inet_diag_req_v2 *r,
			 struct nlattr *bc);
int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
			    struct sk_buff *in_skb, const struct nlmsghdr *nlh,
			    struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *req);

struct sock *inet_diag_find_one_icsk(struct net *net,
+2 −3
Original line number Diff line number Diff line
@@ -51,11 +51,10 @@ static void dccp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
	inet_diag_dump_icsk(&dccp_hashinfo, skb, cb, r, bc);
}

static int dccp_diag_dump_one(struct sk_buff *in_skb,
			      const struct nlmsghdr *nlh,
static int dccp_diag_dump_one(struct netlink_callback *cb,
			      const struct inet_diag_req_v2 *req)
{
	return inet_diag_dump_one_icsk(&dccp_hashinfo, in_skb, nlh, req);
	return inet_diag_dump_one_icsk(&dccp_hashinfo, cb, req);
}

static const struct inet_diag_handler dccp_diag_handler = {
+44 −72
Original line number Diff line number Diff line
@@ -157,11 +157,9 @@ errout:
EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);

int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
		      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
		      struct user_namespace *user_ns,
		      u32 portid, u32 seq, u16 nlmsg_flags,
		      const struct nlmsghdr *unlh,
		      bool net_admin)
		      struct sk_buff *skb, struct netlink_callback *cb,
		      const struct inet_diag_req_v2 *req,
		      u16 nlmsg_flags, bool net_admin)
{
	const struct tcp_congestion_ops *ca_ops;
	const struct inet_diag_handler *handler;
@@ -174,8 +172,8 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
	handler = inet_diag_table[req->sdiag_protocol];
	BUG_ON(!handler);

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

@@ -187,7 +185,9 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
	r->idiag_timer = 0;
	r->idiag_retrans = 0;

	if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin))
	if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
				     sk_user_ns(NETLINK_CB(cb->skb).sk),
				     net_admin))
		goto errout;

	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
@@ -312,30 +312,19 @@ errout:
}
EXPORT_SYMBOL_GPL(inet_sk_diag_fill);

static int inet_csk_diag_fill(struct sock *sk,
			      struct sk_buff *skb,
			      const struct inet_diag_req_v2 *req,
			      struct user_namespace *user_ns,
			      u32 portid, u32 seq, u16 nlmsg_flags,
			      const struct nlmsghdr *unlh,
			      bool net_admin)
{
	return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, user_ns,
				 portid, seq, nlmsg_flags, unlh, net_admin);
}

static int inet_twsk_diag_fill(struct sock *sk,
			       struct sk_buff *skb,
			       u32 portid, u32 seq, u16 nlmsg_flags,
			       const struct nlmsghdr *unlh)
			       struct netlink_callback *cb,
			       u16 nlmsg_flags)
{
	struct inet_timewait_sock *tw = inet_twsk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
			cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
			sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

@@ -359,16 +348,16 @@ static int inet_twsk_diag_fill(struct sock *sk,
}

static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
			      u32 portid, u32 seq, u16 nlmsg_flags,
			      const struct nlmsghdr *unlh, bool net_admin)
			      struct netlink_callback *cb,
			      u16 nlmsg_flags, bool net_admin)
{
	struct request_sock *reqsk = inet_reqsk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

@@ -397,21 +386,18 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
}

static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
			struct netlink_callback *cb,
			const struct inet_diag_req_v2 *r,
			struct user_namespace *user_ns,
			u32 portid, u32 seq, u16 nlmsg_flags,
			const struct nlmsghdr *unlh, bool net_admin)
			u16 nlmsg_flags, bool net_admin)
{
	if (sk->sk_state == TCP_TIME_WAIT)
		return inet_twsk_diag_fill(sk, skb, portid, seq,
					   nlmsg_flags, unlh);
		return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags);

	if (sk->sk_state == TCP_NEW_SYN_RECV)
		return inet_req_diag_fill(sk, skb, portid, seq,
					  nlmsg_flags, unlh, net_admin);
		return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);

	return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
				  nlmsg_flags, unlh, net_admin);
	return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
				 net_admin);
}

struct sock *inet_diag_find_one_icsk(struct net *net,
@@ -459,10 +445,10 @@ struct sock *inet_diag_find_one_icsk(struct net *net,
EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);

int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
			    struct sk_buff *in_skb,
			    const struct nlmsghdr *nlh,
			    struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *req)
{
	struct sk_buff *in_skb = cb->skb;
	bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
	struct net *net = sock_net(in_skb->sk);
	struct sk_buff *rep;
@@ -479,10 +465,7 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
		goto out;
	}

	err = sk_diag_fill(sk, rep, req,
			   sk_user_ns(NETLINK_CB(in_skb).sk),
			   NETLINK_CB(in_skb).portid,
			   nlh->nlmsg_seq, 0, nlh, net_admin);
	err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
	if (err < 0) {
		WARN_ON(err == -EMSGSIZE);
		nlmsg_free(rep);
@@ -509,14 +492,19 @@ static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
	int err;

	handler = inet_diag_lock_handler(req->sdiag_protocol);
	if (IS_ERR(handler))
	if (IS_ERR(handler)) {
		err = PTR_ERR(handler);
	else if (cmd == SOCK_DIAG_BY_FAMILY)
		err = handler->dump_one(in_skb, nlh, req);
	else if (cmd == SOCK_DESTROY && handler->destroy)
	} else if (cmd == SOCK_DIAG_BY_FAMILY) {
		struct netlink_callback cb = {
			.nlh = nlh,
			.skb = in_skb,
		};
		err = handler->dump_one(&cb, req);
	} else if (cmd == SOCK_DESTROY && handler->destroy) {
		err = handler->destroy(in_skb, req);
	else
	} else {
		err = -EOPNOTSUPP;
	}
	inet_diag_unlock_handler(handler);

	return err;
@@ -847,23 +835,6 @@ static int inet_diag_bc_audit(const struct nlattr *attr,
	return len == 0 ? 0 : -EINVAL;
}

static int inet_csk_diag_dump(struct sock *sk,
			      struct sk_buff *skb,
			      struct netlink_callback *cb,
			      const struct inet_diag_req_v2 *r,
			      const struct nlattr *bc,
			      bool net_admin)
{
	if (!inet_diag_bc_sk(bc, sk))
		return 0;

	return inet_csk_diag_fill(sk, skb, r,
				  sk_user_ns(NETLINK_CB(cb->skb).sk),
				  NETLINK_CB(cb->skb).portid,
				  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh,
				  net_admin);
}

static void twsk_build_assert(void)
{
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
@@ -935,8 +906,12 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
				    r->id.idiag_sport)
					goto next_listen;

				if (inet_csk_diag_dump(sk, skb, cb, r,
						       bc, net_admin) < 0) {
				if (!inet_diag_bc_sk(bc, sk))
					goto next_listen;

				if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
						      cb, r, NLM_F_MULTI,
						      net_admin) < 0) {
					spin_unlock(&ilb->lock);
					goto done;
				}
@@ -1014,11 +989,8 @@ next_normal:
		res = 0;
		for (idx = 0; idx < accum; idx++) {
			if (res >= 0) {
				res = sk_diag_fill(sk_arr[idx], skb, r,
					   sk_user_ns(NETLINK_CB(cb->skb).sk),
					   NETLINK_CB(cb->skb).portid,
					   cb->nlh->nlmsg_seq, NLM_F_MULTI,
					   cb->nlh, net_admin);
				res = sk_diag_fill(sk_arr[idx], skb, cb, r,
						   NLM_F_MULTI, net_admin);
				if (res < 0)
					num = num_arr[idx];
			}
+6 −12
Original line number Diff line number Diff line
@@ -87,15 +87,16 @@ out_unlock:
	return sk ? sk : ERR_PTR(-ENOENT);
}

static int raw_diag_dump_one(struct sk_buff *in_skb,
			     const struct nlmsghdr *nlh,
static int raw_diag_dump_one(struct netlink_callback *cb,
			     const struct inet_diag_req_v2 *r)
{
	struct net *net = sock_net(in_skb->sk);
	struct sk_buff *in_skb = cb->skb;
	struct sk_buff *rep;
	struct sock *sk;
	struct net *net;
	int err;

	net = sock_net(in_skb->sk);
	sk = raw_sock_get(net, r);
	if (IS_ERR(sk))
		return PTR_ERR(sk);
@@ -108,10 +109,7 @@ static int raw_diag_dump_one(struct sk_buff *in_skb,
		return -ENOMEM;
	}

	err = inet_sk_diag_fill(sk, NULL, rep, r,
				sk_user_ns(NETLINK_CB(in_skb).sk),
				NETLINK_CB(in_skb).portid,
				nlh->nlmsg_seq, 0, nlh,
	err = inet_sk_diag_fill(sk, NULL, rep, cb, r, 0,
				netlink_net_capable(in_skb, CAP_NET_ADMIN));
	sock_put(sk);

@@ -136,11 +134,7 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb,
	if (!inet_diag_bc_sk(bc, sk))
		return 0;

	return inet_sk_diag_fill(sk, NULL, skb, r,
				 sk_user_ns(NETLINK_CB(cb->skb).sk),
				 NETLINK_CB(cb->skb).portid,
				 cb->nlh->nlmsg_seq, NLM_F_MULTI,
				 cb->nlh, net_admin);
	return inet_sk_diag_fill(sk, NULL, skb, cb, r, NLM_F_MULTI, net_admin);
}

static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
+2 −2
Original line number Diff line number Diff line
@@ -184,10 +184,10 @@ static void tcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
	inet_diag_dump_icsk(&tcp_hashinfo, skb, cb, r, bc);
}

static int tcp_diag_dump_one(struct sk_buff *in_skb, const struct nlmsghdr *nlh,
static int tcp_diag_dump_one(struct netlink_callback *cb,
			     const struct inet_diag_req_v2 *req)
{
	return inet_diag_dump_one_icsk(&tcp_hashinfo, in_skb, nlh, req);
	return inet_diag_dump_one_icsk(&tcp_hashinfo, cb, req);
}

#ifdef CONFIG_INET_DIAG_DESTROY
Loading