Commit 4257c8ca authored by Jens Axboe's avatar Jens Axboe
Browse files

net: separate out the msghdr copy from ___sys_{send,recv}msg()



This is in preparation for enabling the io_uring helpers for sendmsg
and recvmsg to first copy the header for validation before continuing
with the operation.

There should be no functional changes in this patch.

Acked-by: default avatarDavid S. Miller <davem@davemloft.net>
Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 8042d6ce
Loading
Loading
Loading
Loading
+95 −46
Original line number Diff line number Diff line
@@ -2264,15 +2264,10 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
	return err < 0 ? err : 0;
}

static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
			 struct msghdr *msg_sys, unsigned int flags,
			 struct used_address *used_address,
static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
			   unsigned int flags, struct used_address *used_address,
			   unsigned int allowed_msghdr_flags)
{
	struct compat_msghdr __user *msg_compat =
	    (struct compat_msghdr __user *)msg;
	struct sockaddr_storage address;
	struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
	unsigned char ctl[sizeof(struct cmsghdr) + 20]
				__aligned(sizeof(__kernel_size_t));
	/* 20 is size of ipv6_pktinfo */
@@ -2280,19 +2275,10 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
	int ctl_len;
	ssize_t err;

	msg_sys->msg_name = &address;

	if (MSG_CMSG_COMPAT & flags)
		err = get_compat_msghdr(msg_sys, msg_compat, NULL, &iov);
	else
		err = copy_msghdr_from_user(msg_sys, msg, NULL, &iov);
	if (err < 0)
		return err;

	err = -ENOBUFS;

	if (msg_sys->msg_controllen > INT_MAX)
		goto out_freeiov;
		goto out;
	flags |= (msg_sys->msg_flags & allowed_msghdr_flags);
	ctl_len = msg_sys->msg_controllen;
	if ((MSG_CMSG_COMPAT & flags) && ctl_len) {
@@ -2300,7 +2286,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
		    cmsghdr_from_user_compat_to_kern(msg_sys, sock->sk, ctl,
						     sizeof(ctl));
		if (err)
			goto out_freeiov;
			goto out;
		ctl_buf = msg_sys->msg_control;
		ctl_len = msg_sys->msg_controllen;
	} else if (ctl_len) {
@@ -2309,7 +2295,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
		if (ctl_len > sizeof(ctl)) {
			ctl_buf = sock_kmalloc(sock->sk, ctl_len, GFP_KERNEL);
			if (ctl_buf == NULL)
				goto out_freeiov;
				goto out;
		}
		err = -EFAULT;
		/*
@@ -2355,7 +2341,47 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
out_freectl:
	if (ctl_buf != ctl)
		sock_kfree_s(sock->sk, ctl_buf, ctl_len);
out_freeiov:
out:
	return err;
}

static int sendmsg_copy_msghdr(struct msghdr *msg,
			       struct user_msghdr __user *umsg, unsigned flags,
			       struct iovec **iov)
{
	int err;

	if (flags & MSG_CMSG_COMPAT) {
		struct compat_msghdr __user *msg_compat;

		msg_compat = (struct compat_msghdr __user *) umsg;
		err = get_compat_msghdr(msg, msg_compat, NULL, iov);
	} else {
		err = copy_msghdr_from_user(msg, umsg, NULL, iov);
	}
	if (err < 0)
		return err;

	return 0;
}

static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
			 struct msghdr *msg_sys, unsigned int flags,
			 struct used_address *used_address,
			 unsigned int allowed_msghdr_flags)
{
	struct sockaddr_storage address;
	struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
	ssize_t err;

	msg_sys->msg_name = &address;

	err = sendmsg_copy_msghdr(msg_sys, msg, flags, &iov);
	if (err < 0)
		return err;

	err = ____sys_sendmsg(sock, msg_sys, flags, used_address,
				allowed_msghdr_flags);
	kfree(iov);
	return err;
}
@@ -2474,33 +2500,41 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg,
	return __sys_sendmmsg(fd, mmsg, vlen, flags, true);
}

static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
			 struct msghdr *msg_sys, unsigned int flags, int nosec)
static int recvmsg_copy_msghdr(struct msghdr *msg,
			       struct user_msghdr __user *umsg, unsigned flags,
			       struct sockaddr __user **uaddr,
			       struct iovec **iov)
{
	ssize_t err;

	if (MSG_CMSG_COMPAT & flags) {
		struct compat_msghdr __user *msg_compat;

		msg_compat = (struct compat_msghdr __user *) umsg;
		err = get_compat_msghdr(msg, msg_compat, uaddr, iov);
	} else {
		err = copy_msghdr_from_user(msg, umsg, uaddr, iov);
	}
	if (err < 0)
		return err;

	return 0;
}

static int ____sys_recvmsg(struct socket *sock, struct msghdr *msg_sys,
			   struct user_msghdr __user *msg,
			   struct sockaddr __user *uaddr,
			   unsigned int flags, int nosec)
{
	struct compat_msghdr __user *msg_compat =
					(struct compat_msghdr __user *) msg;
	struct iovec iovstack[UIO_FASTIOV];
	struct iovec *iov = iovstack;
	int __user *uaddr_len = COMPAT_NAMELEN(msg);
	struct sockaddr_storage addr;
	unsigned long cmsg_ptr;
	int len;
	ssize_t err;

	/* kernel mode address */
	struct sockaddr_storage addr;

	/* user mode address pointers */
	struct sockaddr __user *uaddr;
	int __user *uaddr_len = COMPAT_NAMELEN(msg);

	msg_sys->msg_name = &addr;

	if (MSG_CMSG_COMPAT & flags)
		err = get_compat_msghdr(msg_sys, msg_compat, &uaddr, &iov);
	else
		err = copy_msghdr_from_user(msg_sys, msg, &uaddr, &iov);
	if (err < 0)
		return err;

	cmsg_ptr = (unsigned long)msg_sys->msg_control;
	msg_sys->msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);

@@ -2511,7 +2545,7 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
		flags |= MSG_DONTWAIT;
	err = (nosec ? sock_recvmsg_nosec : sock_recvmsg)(sock, msg_sys, flags);
	if (err < 0)
		goto out_freeiov;
		goto out;
	len = err;

	if (uaddr != NULL) {
@@ -2519,12 +2553,12 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
					msg_sys->msg_namelen, uaddr,
					uaddr_len);
		if (err < 0)
			goto out_freeiov;
			goto out;
	}
	err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT),
			 COMPAT_FLAGS(msg));
	if (err)
		goto out_freeiov;
		goto out;
	if (MSG_CMSG_COMPAT & flags)
		err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
				 &msg_compat->msg_controllen);
@@ -2532,10 +2566,25 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
		err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
				 &msg->msg_controllen);
	if (err)
		goto out_freeiov;
		goto out;
	err = len;
out:
	return err;
}

static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
			 struct msghdr *msg_sys, unsigned int flags, int nosec)
{
	struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
	/* user mode address pointers */
	struct sockaddr __user *uaddr;
	ssize_t err;

	err = recvmsg_copy_msghdr(msg_sys, msg, flags, &uaddr, &iov);
	if (err < 0)
		return err;

out_freeiov:
	err = ____sys_recvmsg(sock, msg_sys, msg, uaddr, flags, nosec);
	kfree(iov);
	return err;
}