Commit 97cf0ef9 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'improve-msg_control-kernel-vs-user-pointer-handling'



Christoph Hellwig says:

====================
improve msg_control kernel vs user pointer handling

this series replace the msg_control in the kernel msghdr structure
with an anonymous union and separate fields for kernel vs user
pointers.  In addition to helping a bit with type safety and reducing
sparse warnings, this also allows to remove the set_fs() in
kernel_recvmsg, helping with an eventual entire removal of set_fs().
====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 3242956b 1f466e1f
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -50,7 +50,17 @@ struct msghdr {
	void		*msg_name;	/* ptr to socket address structure */
	int		msg_namelen;	/* size of socket address structure */
	struct iov_iter	msg_iter;	/* data */
	void		*msg_control;	/* ancillary data */

	/*
	 * Ancillary data. msg_control_user is the user buffer used for the
	 * recv* side when msg_control_is_user is set, msg_control is the kernel
	 * buffer used for all other cases.
	 */
	union {
		void		*msg_control;
		void __user	*msg_control_user;
	};
	bool		msg_control_is_user : 1;
	__kernel_size_t	msg_controllen;	/* ancillary data buffer length */
	unsigned int	msg_flags;	/* flags on received message */
	struct kiocb	*msg_iocb;	/* ptr to iocb for async requests */
@@ -94,7 +104,10 @@ struct cmsghdr {

#define CMSG_ALIGN(len) ( ((len)+sizeof(long)-1) & ~(sizeof(long)-1) )

#define CMSG_DATA(cmsg)	((void *)((char *)(cmsg) + sizeof(struct cmsghdr)))
#define CMSG_DATA(cmsg) \
	((void *)(cmsg) + sizeof(struct cmsghdr))
#define CMSG_USER_DATA(cmsg) \
	((void __user *)(cmsg) + sizeof(struct cmsghdr))
#define CMSG_SPACE(len) (sizeof(struct cmsghdr) + CMSG_ALIGN(len))
#define CMSG_LEN(len) (sizeof(struct cmsghdr) + (len))

+3 −2
Original line number Diff line number Diff line
@@ -56,7 +56,8 @@ int __get_compat_msghdr(struct msghdr *kmsg,
	if (kmsg->msg_namelen > sizeof(struct sockaddr_storage))
		kmsg->msg_namelen = sizeof(struct sockaddr_storage);

	kmsg->msg_control = compat_ptr(msg.msg_control);
	kmsg->msg_control_is_user = true;
	kmsg->msg_control_user = compat_ptr(msg.msg_control);
	kmsg->msg_controllen = msg.msg_controllen;

	if (save_addr)
@@ -121,7 +122,7 @@ int get_compat_msghdr(struct msghdr *kmsg,
	((ucmlen) >= sizeof(struct compat_cmsghdr) && \
	 (ucmlen) <= (unsigned long) \
	 ((mhdr)->msg_controllen - \
	  ((char *)(ucmsg) - (char *)(mhdr)->msg_control)))
	  ((char __user *)(ucmsg) - (char __user *)(mhdr)->msg_control_user)))

static inline struct compat_cmsghdr __user *cmsg_compat_nxthdr(struct msghdr *msg,
		struct compat_cmsghdr __user *cmsg, int cmsg_len)
+78 −63
Original line number Diff line number Diff line
@@ -212,16 +212,12 @@ EXPORT_SYMBOL(__scm_send);

int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
{
	struct cmsghdr __user *cm
		= (__force struct cmsghdr __user *)msg->msg_control;
	struct cmsghdr cmhdr;
	int cmlen = CMSG_LEN(len);
	int err;

	if (MSG_CMSG_COMPAT & msg->msg_flags)
	if (msg->msg_flags & MSG_CMSG_COMPAT)
		return put_cmsg_compat(msg, level, type, len, data);

	if (cm==NULL || msg->msg_controllen < sizeof(*cm)) {
	if (!msg->msg_control || msg->msg_controllen < sizeof(struct cmsghdr)) {
		msg->msg_flags |= MSG_CTRUNC;
		return 0; /* XXX: return error? check spec. */
	}
@@ -229,23 +225,30 @@ int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
		msg->msg_flags |= MSG_CTRUNC;
		cmlen = msg->msg_controllen;
	}

	if (msg->msg_control_is_user) {
		struct cmsghdr __user *cm = msg->msg_control_user;
		struct cmsghdr cmhdr;

		cmhdr.cmsg_level = level;
		cmhdr.cmsg_type = type;
		cmhdr.cmsg_len = cmlen;
		if (copy_to_user(cm, &cmhdr, sizeof cmhdr) ||
		    copy_to_user(CMSG_USER_DATA(cm), data, cmlen - sizeof(*cm)))
			return -EFAULT;
	} else {
		struct cmsghdr *cm = msg->msg_control;

		cm->cmsg_level = level;
		cm->cmsg_type = type;
		cm->cmsg_len = cmlen;
		memcpy(CMSG_DATA(cm), data, cmlen - sizeof(*cm));
	}

	err = -EFAULT;
	if (copy_to_user(cm, &cmhdr, sizeof cmhdr))
		goto out;
	if (copy_to_user(CMSG_DATA(cm), data, cmlen - sizeof(struct cmsghdr)))
		goto out;
	cmlen = CMSG_SPACE(len);
	if (msg->msg_controllen < cmlen)
		cmlen = msg->msg_controllen;
	cmlen = min(CMSG_SPACE(len), msg->msg_controllen);
	msg->msg_control += cmlen;
	msg->msg_controllen -= cmlen;
	err = 0;
out:
	return err;
	return 0;
}
EXPORT_SYMBOL(put_cmsg);

@@ -277,59 +280,70 @@ void put_cmsg_scm_timestamping(struct msghdr *msg, struct scm_timestamping_inter
}
EXPORT_SYMBOL(put_cmsg_scm_timestamping);

static int __scm_install_fd(struct file *file, int __user *ufd, int o_flags)
{
	struct socket *sock;
	int new_fd;
	int error;

	error = security_file_receive(file);
	if (error)
		return error;

	new_fd = get_unused_fd_flags(o_flags);
	if (new_fd < 0)
		return new_fd;

	error = put_user(new_fd, ufd);
	if (error) {
		put_unused_fd(new_fd);
		return error;
	}

	/* Bump the usage count and install the file. */
	sock = sock_from_file(file, &error);
	if (sock) {
		sock_update_netprioidx(&sock->sk->sk_cgrp_data);
		sock_update_classid(&sock->sk->sk_cgrp_data);
	}
	fd_install(new_fd, get_file(file));
	return error;
}

static int scm_max_fds(struct msghdr *msg)
{
	if (msg->msg_controllen <= sizeof(struct cmsghdr))
		return 0;
	return (msg->msg_controllen - sizeof(struct cmsghdr)) / sizeof(int);
}

void scm_detach_fds(struct msghdr *msg, struct scm_cookie *scm)
{
	struct cmsghdr __user *cm
		= (__force struct cmsghdr __user*)msg->msg_control;

	int fdmax = 0;
	int fdnum = scm->fp->count;
	struct file **fp = scm->fp->fp;
	int __user *cmfptr;
	int o_flags = (msg->msg_flags & MSG_CMSG_CLOEXEC) ? O_CLOEXEC : 0;
	int fdmax = min_t(int, scm_max_fds(msg), scm->fp->count);
	int __user *cmsg_data = CMSG_USER_DATA(cm);
	int err = 0, i;

	if (MSG_CMSG_COMPAT & msg->msg_flags) {
	if (msg->msg_flags & MSG_CMSG_COMPAT) {
		scm_detach_fds_compat(msg, scm);
		return;
	}

	if (msg->msg_controllen > sizeof(struct cmsghdr))
		fdmax = ((msg->msg_controllen - sizeof(struct cmsghdr))
			 / sizeof(int));

	if (fdnum < fdmax)
		fdmax = fdnum;
	/* no use for FD passing from kernel space callers */
	if (WARN_ON_ONCE(!msg->msg_control_is_user))
		return;

	for (i=0, cmfptr=(__force int __user *)CMSG_DATA(cm); i<fdmax;
	     i++, cmfptr++)
	{
		struct socket *sock;
		int new_fd;
		err = security_file_receive(fp[i]);
	for (i = 0; i < fdmax; i++) {
		err = __scm_install_fd(scm->fp->fp[i], cmsg_data + i, o_flags);
		if (err)
			break;
		err = get_unused_fd_flags(MSG_CMSG_CLOEXEC & msg->msg_flags
					  ? O_CLOEXEC : 0);
		if (err < 0)
			break;
		new_fd = err;
		err = put_user(new_fd, cmfptr);
		if (err) {
			put_unused_fd(new_fd);
			break;
		}
		/* Bump the usage count and install the file. */
		sock = sock_from_file(fp[i], &err);
		if (sock) {
			sock_update_netprioidx(&sock->sk->sk_cgrp_data);
			sock_update_classid(&sock->sk->sk_cgrp_data);
		}
		fd_install(new_fd, get_file(fp[i]));
	}

	if (i > 0)
	{
	if (i > 0)  {
		int cmlen = CMSG_LEN(i * sizeof(int));

		err = put_user(SOL_SOCKET, &cm->cmsg_level);
		if (!err)
			err = put_user(SCM_RIGHTS, &cm->cmsg_type);
@@ -343,12 +357,13 @@ void scm_detach_fds(struct msghdr *msg, struct scm_cookie *scm)
			msg->msg_controllen -= cmlen;
		}
	}
	if (i < fdnum || (fdnum && fdmax <= 0))

	if (i < scm->fp->count || (scm->fp->count && fdmax <= 0))
		msg->msg_flags |= MSG_CTRUNC;

	/*
	 * All of the files that fit in the message have had their
	 * usage counts incremented, so we just free the list.
	 * All of the files that fit in the message have had their usage counts
	 * incremented, so we just free the list.
	 */
	__scm_destroy(scm);
}
+2 −1
Original line number Diff line number Diff line
@@ -1492,7 +1492,8 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
		if (sk->sk_type != SOCK_STREAM)
			return -ENOPROTOOPT;

		msg.msg_control = (__force void *) optval;
		msg.msg_control_is_user = true;
		msg.msg_control_user = optval;
		msg.msg_controllen = len;
		msg.msg_flags = flags;

+6 −16
Original line number Diff line number Diff line
@@ -924,14 +924,9 @@ EXPORT_SYMBOL(sock_recvmsg);
int kernel_recvmsg(struct socket *sock, struct msghdr *msg,
		   struct kvec *vec, size_t num, size_t size, int flags)
{
	mm_segment_t oldfs = get_fs();
	int result;

	msg->msg_control_is_user = false;
	iov_iter_kvec(&msg->msg_iter, READ, vec, num, size);
	set_fs(KERNEL_DS);
	result = sock_recvmsg(sock, msg, flags);
	set_fs(oldfs);
	return result;
	return sock_recvmsg(sock, msg, flags);
}
EXPORT_SYMBOL(kernel_recvmsg);

@@ -2239,7 +2234,8 @@ int __copy_msghdr_from_user(struct msghdr *kmsg,
	if (copy_from_user(&msg, umsg, sizeof(*umsg)))
		return -EFAULT;

	kmsg->msg_control = (void __force *)msg.msg_control;
	kmsg->msg_control_is_user = true;
	kmsg->msg_control_user = msg.msg_control;
	kmsg->msg_controllen = msg.msg_controllen;
	kmsg->msg_flags = msg.msg_flags;

@@ -2331,16 +2327,10 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
				goto out;
		}
		err = -EFAULT;
		/*
		 * Careful! Before this, msg_sys->msg_control contains a user pointer.
		 * Afterwards, it will be a kernel pointer. Thus the compiler-assisted
		 * checking falls down on this.
		 */
		if (copy_from_user(ctl_buf,
				   (void __user __force *)msg_sys->msg_control,
				   ctl_len))
		if (copy_from_user(ctl_buf, msg_sys->msg_control_user, ctl_len))
			goto out_freectl;
		msg_sys->msg_control = ctl_buf;
		msg_sys->msg_control_is_user = false;
	}
	msg_sys->msg_flags = flags;