Commit f499a021 authored by Jens Axboe's avatar Jens Axboe
Browse files

io_uring: ensure async punted connect requests copy data



Just like commit f67676d1 for read/write requests, this one ensures
that the sockaddr data has been copied for IORING_OP_CONNECT if we need
to punt the request to async context.

Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 03b1230c
Loading
Loading
Loading
Loading
+47 −4
Original line number Diff line number Diff line
@@ -308,6 +308,10 @@ struct io_timeout {
	struct io_timeout_data		*data;
};

struct io_async_connect {
	struct sockaddr_storage		address;
};

struct io_async_msghdr {
	struct iovec			fast_iov[UIO_FASTIOV];
	struct iovec			*iov;
@@ -327,6 +331,7 @@ struct io_async_ctx {
	union {
		struct io_async_rw	rw;
		struct io_async_msghdr	msg;
		struct io_async_connect	connect;
	};
};

@@ -2195,11 +2200,26 @@ static int io_accept(struct io_kiocb *req, const struct io_uring_sqe *sqe,
#endif
}

static int io_connect_prep(struct io_kiocb *req, struct io_async_ctx *io)
{
#if defined(CONFIG_NET)
	const struct io_uring_sqe *sqe = req->sqe;
	struct sockaddr __user *addr;
	int addr_len;

	addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
	addr_len = READ_ONCE(sqe->addr2);
	return move_addr_to_kernel(addr, addr_len, &io->connect.address);
#else
	return 0;
#endif
}

static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe,
		      struct io_kiocb **nxt, bool force_nonblock)
{
#if defined(CONFIG_NET)
	struct sockaddr __user *addr;
	struct io_async_ctx __io, *io;
	unsigned file_flags;
	int addr_len, ret;

@@ -2208,15 +2228,35 @@ static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe,
	if (sqe->ioprio || sqe->len || sqe->buf_index || sqe->rw_flags)
		return -EINVAL;

	addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
	addr_len = READ_ONCE(sqe->addr2);
	file_flags = force_nonblock ? O_NONBLOCK : 0;

	ret = __sys_connect_file(req->file, addr, addr_len, file_flags);
	if (ret == -EAGAIN && force_nonblock)
	if (req->io) {
		io = req->io;
	} else {
		ret = io_connect_prep(req, &__io);
		if (ret)
			goto out;
		io = &__io;
	}

	ret = __sys_connect_file(req->file, &io->connect.address, addr_len,
					file_flags);
	if (ret == -EAGAIN && force_nonblock) {
		io = kmalloc(sizeof(*io), GFP_KERNEL);
		if (!io) {
			ret = -ENOMEM;
			goto out;
		}
		memcpy(&io->connect, &__io.connect, sizeof(io->connect));
		req->io = io;
		memcpy(&io->sqe, req->sqe, sizeof(*req->sqe));
		req->sqe = &io->sqe;
		return -EAGAIN;
	}
	if (ret == -ERESTARTSYS)
		ret = -EINTR;
out:
	if (ret < 0 && (req->flags & REQ_F_LINK))
		req->flags |= REQ_F_FAIL_LINK;
	io_cqring_add_event(req, ret);
@@ -2832,6 +2872,9 @@ static int io_req_defer_prep(struct io_kiocb *req, struct io_async_ctx *io)
	case IORING_OP_RECVMSG:
		ret = io_recvmsg_prep(req, io);
		break;
	case IORING_OP_CONNECT:
		ret = io_connect_prep(req, io);
		break;
	default:
		req->io = io;
		return 0;
+2 −3
Original line number Diff line number Diff line
@@ -406,9 +406,8 @@ extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
			 int __user *upeer_addrlen, int flags);
extern int __sys_socket(int family, int type, int protocol);
extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen);
extern int __sys_connect_file(struct file *file,
			struct sockaddr __user *uservaddr, int addrlen,
			int file_flags);
extern int __sys_connect_file(struct file *file, struct sockaddr_storage *addr,
			      int addrlen, int file_flags);
extern int __sys_connect(int fd, struct sockaddr __user *uservaddr,
			 int addrlen);
extern int __sys_listen(int fd, int backlog);
+8 −8
Original line number Diff line number Diff line
@@ -1826,26 +1826,22 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr,
 *	include the -EINPROGRESS status for such sockets.
 */

int __sys_connect_file(struct file *file, struct sockaddr __user *uservaddr,
int __sys_connect_file(struct file *file, struct sockaddr_storage *address,
		       int addrlen, int file_flags)
{
	struct socket *sock;
	struct sockaddr_storage address;
	int err;

	sock = sock_from_file(file, &err);
	if (!sock)
		goto out;
	err = move_addr_to_kernel(uservaddr, addrlen, &address);
	if (err < 0)
		goto out;

	err =
	    security_socket_connect(sock, (struct sockaddr *)&address, addrlen);
	    security_socket_connect(sock, (struct sockaddr *)address, addrlen);
	if (err)
		goto out;

	err = sock->ops->connect(sock, (struct sockaddr *)&address, addrlen,
	err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen,
				 sock->file->f_flags | file_flags);
out:
	return err;
@@ -1858,7 +1854,11 @@ int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)

	f = fdget(fd);
	if (f.file) {
		ret = __sys_connect_file(f.file, uservaddr, addrlen, 0);
		struct sockaddr_storage address;

		ret = move_addr_to_kernel(uservaddr, addrlen, &address);
		if (!ret)
			ret = __sys_connect_file(f.file, &address, addrlen, 0);
		if (f.flags)
			fput(f.file);
	}