Commit 22cad158 authored by Pavel Begunkov's avatar Pavel Begunkov Committed by Jens Axboe
Browse files

io_uring: fix cached_sq_head in io_timeout()



io_timeout() can be executed asynchronously by a worker and without
holding ctx->uring_lock

1. using ctx->cached_sq_head there is racy there
2. it should count events from a moment of timeout's submission, but
not execution

Use req->sequence.

Signed-off-by: default avatarPavel Begunkov <asml.silence@gmail.com>
Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 8e2e1faf
Loading
Loading
Loading
Loading
+8 −7
Original line number Diff line number Diff line
@@ -4714,6 +4714,7 @@ static int io_timeout(struct io_kiocb *req)
	struct io_timeout_data *data;
	struct list_head *entry;
	unsigned span = 0;
	u32 seq = req->sequence;

	data = &req->io->timeout;

@@ -4730,7 +4731,7 @@ static int io_timeout(struct io_kiocb *req)
		goto add;
	}

	req->sequence = ctx->cached_sq_head + count - 1;
	req->sequence = seq + count;
	data->seq_offset = count;

	/*
@@ -4740,7 +4741,7 @@ static int io_timeout(struct io_kiocb *req)
	spin_lock_irq(&ctx->completion_lock);
	list_for_each_prev(entry, &ctx->timeout_list) {
		struct io_kiocb *nxt = list_entry(entry, struct io_kiocb, list);
		unsigned nxt_sq_head;
		unsigned nxt_seq;
		long long tmp, tmp_nxt;
		u32 nxt_offset = nxt->io->timeout.seq_offset;

@@ -4748,18 +4749,18 @@ static int io_timeout(struct io_kiocb *req)
			continue;

		/*
		 * Since cached_sq_head + count - 1 can overflow, use type long
		 * Since seq + count can overflow, use type long
		 * long to store it.
		 */
		tmp = (long long)ctx->cached_sq_head + count - 1;
		nxt_sq_head = nxt->sequence - nxt_offset + 1;
		tmp_nxt = (long long)nxt_sq_head + nxt_offset - 1;
		tmp = (long long)seq + count;
		nxt_seq = nxt->sequence - nxt_offset;
		tmp_nxt = (long long)nxt_seq + nxt_offset;

		/*
		 * cached_sq_head may overflow, and it will never overflow twice
		 * once there is some timeout req still be valid.
		 */
		if (ctx->cached_sq_head < nxt_sq_head)
		if (seq < nxt_seq)
			tmp += UINT_MAX;

		if (tmp > tmp_nxt)