io_uring: pass in io_kiocb to fill/add CQ handlers

This is in preparation for handling CQ ring overflow a bit smarter. We
should not have any functional changes in this patch. Most of the
changes are fairly straightforward; the ones that stick out a bit are
those where a __io_free_req() call is replaced with one that takes the
reference count into account. If the request hasn't been submitted yet,
we know it's
safe to simply ignore references and free it. But let's clean these up
too, as later patches will depend on the caller doing the right thing if
the completion logging grabs a reference to the request.

Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
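
As a rough illustration of the reference-count convention that the new
io_double_put_req() helper relies on, here is a minimal userspace sketch.
It is not kernel code: fake_req, fake_double_put_req and fake_free_req are
made-up stand-ins for io_kiocb, io_double_put_req and __io_free_req. The
idea it models is confirmed by the patch below: every request carries a
submit reference and a completion reference, and when a request dies
before it is submitted, both are dropped in one go instead of calling
__io_free_req() directly.

/*
 * Minimal userspace sketch of the reference-count convention used by
 * io_double_put_req() in the patch below; fake_* names are made up,
 * not kernel APIs.
 */
#include <stdatomic.h>
#include <stdio.h>
#include <stdlib.h>

struct fake_req {
	atomic_int refs;		/* stands in for req->refs */
	unsigned long long user_data;
};

static void fake_free_req(struct fake_req *req)
{
	/* stand-in for __io_free_req(): only runs once refs hit zero */
	printf("freeing req, user_data=%llu\n", req->user_data);
	free(req);
}

/* stand-in for io_double_put_req(): drop submit + complete references */
static void fake_double_put_req(struct fake_req *req)
{
	/* analogue of refcount_sub_and_test(2, &req->refs) */
	if (atomic_fetch_sub(&req->refs, 2) == 2)
		fake_free_req(req);
}

int main(void)
{
	struct fake_req *req = malloc(sizeof(*req));

	if (!req)
		return 1;
	atomic_init(&req->refs, 2);	/* submit + complete reference */
	req->user_data = 42;

	/*
	 * Request failed before submission: post the CQE (elided here),
	 * then drop both references at once.
	 */
	fake_double_put_req(req);
	return 0;
}

This mirrors the refcount_sub_and_test(2, &req->refs) in io_double_put_req()
below: posting the completion and freeing the request are always mediated by
the reference count, which is what later overflow handling will depend on.
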
diff --git a/fs/io_uring.c b/fs/io_uring.c
index d8e15cc..91103fc 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -369,10 +369,10 @@
 };
 
 static void io_wq_submit_work(struct io_wq_work **workptr);
-static void io_cqring_fill_event(struct io_ring_ctx *ctx, u64 ki_user_data,
-				 long res);
+static void io_cqring_fill_event(struct io_kiocb *req, long res);
 static void __io_free_req(struct io_kiocb *req);
 static void io_put_req(struct io_kiocb *req, struct io_kiocb **nxtptr);
+static void io_double_put_req(struct io_kiocb *req);
 
 static struct kmem_cache *req_cachep;
 
@@ -535,8 +535,8 @@
 	if (ret != -1) {
 		atomic_inc(&req->ctx->cq_timeouts);
 		list_del_init(&req->list);
-		io_cqring_fill_event(req->ctx, req->user_data, 0);
-		__io_free_req(req);
+		io_cqring_fill_event(req, 0);
+		io_put_req(req, NULL);
 	}
 }
 
@@ -588,12 +588,12 @@
 	return &rings->cqes[tail & ctx->cq_mask];
 }
 
-static void io_cqring_fill_event(struct io_ring_ctx *ctx, u64 ki_user_data,
-				 long res)
+static void io_cqring_fill_event(struct io_kiocb *req, long res)
 {
+	struct io_ring_ctx *ctx = req->ctx;
 	struct io_uring_cqe *cqe;
 
-	trace_io_uring_complete(ctx, ki_user_data, res);
+	trace_io_uring_complete(ctx, req->user_data, res);
 
 	/*
 	 * If we can't get a cq entry, userspace overflowed the
@@ -602,7 +602,7 @@
 	 */
 	cqe = io_get_cqring(ctx);
 	if (cqe) {
-		WRITE_ONCE(cqe->user_data, ki_user_data);
+		WRITE_ONCE(cqe->user_data, req->user_data);
 		WRITE_ONCE(cqe->res, res);
 		WRITE_ONCE(cqe->flags, 0);
 	} else {
@@ -621,13 +621,13 @@
 		eventfd_signal(ctx->cq_ev_fd, 1);
 }
 
-static void io_cqring_add_event(struct io_ring_ctx *ctx, u64 user_data,
-				long res)
+static void io_cqring_add_event(struct io_kiocb *req, long res)
 {
+	struct io_ring_ctx *ctx = req->ctx;
 	unsigned long flags;
 
 	spin_lock_irqsave(&ctx->completion_lock, flags);
-	io_cqring_fill_event(ctx, user_data, res);
+	io_cqring_fill_event(req, res);
 	io_commit_cqring(ctx);
 	spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
@@ -721,10 +721,10 @@
 
 	ret = hrtimer_try_to_cancel(&req->timeout.timer);
 	if (ret != -1) {
-		io_cqring_fill_event(ctx, req->user_data, -ECANCELED);
+		io_cqring_fill_event(req, -ECANCELED);
 		io_commit_cqring(ctx);
 		req->flags &= ~REQ_F_LINK;
-		__io_free_req(req);
+		io_put_req(req, NULL);
 		return true;
 	}
 
@@ -795,8 +795,8 @@
 		    link->submit.sqe->opcode == IORING_OP_LINK_TIMEOUT) {
 			io_link_cancel_timeout(ctx, link);
 		} else {
-			io_cqring_fill_event(ctx, link->user_data, -ECANCELED);
-			__io_free_req(link);
+			io_cqring_fill_event(link, -ECANCELED);
+			io_double_put_req(link);
 		}
 	}
 
@@ -866,6 +866,13 @@
 	}
 }
 
+static void io_double_put_req(struct io_kiocb *req)
+{
+	/* drop both submit and complete references */
+	if (refcount_sub_and_test(2, &req->refs))
+		__io_free_req(req);
+}
+
 static unsigned io_cqring_events(struct io_ring_ctx *ctx)
 {
 	struct io_rings *rings = ctx->rings;
@@ -898,7 +905,7 @@
 		req = list_first_entry(done, struct io_kiocb, list);
 		list_del(&req->list);
 
-		io_cqring_fill_event(ctx, req->user_data, req->result);
+		io_cqring_fill_event(req, req->result);
 		(*nr_events)++;
 
 		if (refcount_dec_and_test(&req->refs)) {
@@ -1094,7 +1101,7 @@
 
 	if ((req->flags & REQ_F_LINK) && res != req->result)
 		req->flags |= REQ_F_FAIL_LINK;
-	io_cqring_add_event(req->ctx, req->user_data, res);
+	io_cqring_add_event(req, res);
 }
 
 static void io_complete_rw(struct kiocb *kiocb, long res, long res2)
@@ -1595,15 +1602,14 @@
 /*
  * IORING_OP_NOP just posts a completion event, nothing else.
  */
-static int io_nop(struct io_kiocb *req, u64 user_data)
+static int io_nop(struct io_kiocb *req)
 {
 	struct io_ring_ctx *ctx = req->ctx;
-	long err = 0;
 
 	if (unlikely(ctx->flags & IORING_SETUP_IOPOLL))
 		return -EINVAL;
 
-	io_cqring_add_event(ctx, user_data, err);
+	io_cqring_add_event(req, 0);
 	io_put_req(req, NULL);
 	return 0;
 }
@@ -1650,7 +1656,7 @@
 
 	if (ret < 0 && (req->flags & REQ_F_LINK))
 		req->flags |= REQ_F_FAIL_LINK;
-	io_cqring_add_event(req->ctx, sqe->user_data, ret);
+	io_cqring_add_event(req, ret);
 	io_put_req(req, nxt);
 	return 0;
 }
@@ -1697,7 +1703,7 @@
 
 	if (ret < 0 && (req->flags & REQ_F_LINK))
 		req->flags |= REQ_F_FAIL_LINK;
-	io_cqring_add_event(req->ctx, sqe->user_data, ret);
+	io_cqring_add_event(req, ret);
 	io_put_req(req, nxt);
 	return 0;
 }
@@ -1733,7 +1739,7 @@
 			return ret;
 	}
 
-	io_cqring_add_event(req->ctx, sqe->user_data, ret);
+	io_cqring_add_event(req, ret);
 	if (ret < 0 && (req->flags & REQ_F_LINK))
 		req->flags |= REQ_F_FAIL_LINK;
 	io_put_req(req, nxt);
@@ -1789,7 +1795,7 @@
 	}
 	if (ret < 0 && (req->flags & REQ_F_LINK))
 		req->flags |= REQ_F_FAIL_LINK;
-	io_cqring_add_event(req->ctx, sqe->user_data, ret);
+	io_cqring_add_event(req, ret);
 	io_put_req(req, nxt);
 	return 0;
 #else
@@ -1850,7 +1856,7 @@
 	}
 	spin_unlock_irq(&ctx->completion_lock);
 
-	io_cqring_add_event(req->ctx, sqe->user_data, ret);
+	io_cqring_add_event(req, ret);
 	if (ret < 0 && (req->flags & REQ_F_LINK))
 		req->flags |= REQ_F_FAIL_LINK;
 	io_put_req(req, NULL);
@@ -1861,7 +1867,7 @@
 			     __poll_t mask)
 {
 	req->poll.done = true;
-	io_cqring_fill_event(ctx, req->user_data, mangle_poll(mask));
+	io_cqring_fill_event(req, mangle_poll(mask));
 	io_commit_cqring(ctx);
 }
 
@@ -2055,7 +2061,7 @@
 		list_del_init(&req->list);
 	}
 
-	io_cqring_fill_event(ctx, req->user_data, -ETIME);
+	io_cqring_fill_event(req, -ETIME);
 	io_commit_cqring(ctx);
 	spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
@@ -2099,7 +2105,7 @@
 	/* didn't find timeout */
 	if (ret) {
 fill_ev:
-		io_cqring_fill_event(ctx, req->user_data, ret);
+		io_cqring_fill_event(req, ret);
 		io_commit_cqring(ctx);
 		spin_unlock_irq(&ctx->completion_lock);
 		io_cqring_ev_posted(ctx);
@@ -2115,8 +2121,8 @@
 		goto fill_ev;
 	}
 
-	io_cqring_fill_event(ctx, req->user_data, 0);
-	io_cqring_fill_event(ctx, treq->user_data, -ECANCELED);
+	io_cqring_fill_event(req, 0);
+	io_cqring_fill_event(treq, -ECANCELED);
 	io_commit_cqring(ctx);
 	spin_unlock_irq(&ctx->completion_lock);
 	io_cqring_ev_posted(ctx);
@@ -2256,7 +2262,7 @@
 
 	if (ret < 0 && (req->flags & REQ_F_LINK))
 		req->flags |= REQ_F_FAIL_LINK;
-	io_cqring_add_event(req->ctx, sqe->user_data, ret);
+	io_cqring_add_event(req, ret);
 	io_put_req(req, nxt);
 	return 0;
 }
@@ -2295,12 +2301,10 @@
 	int ret, opcode;
 	struct sqe_submit *s = &req->submit;
 
-	req->user_data = READ_ONCE(s->sqe->user_data);
-
 	opcode = READ_ONCE(s->sqe->opcode);
 	switch (opcode) {
 	case IORING_OP_NOP:
-		ret = io_nop(req, req->user_data);
+		ret = io_nop(req);
 		break;
 	case IORING_OP_READV:
 		if (unlikely(s->sqe->buf_index))
@@ -2409,7 +2413,7 @@
 	if (ret) {
 		if (req->flags & REQ_F_LINK)
 			req->flags |= REQ_F_FAIL_LINK;
-		io_cqring_add_event(ctx, sqe->user_data, ret);
+		io_cqring_add_event(req, ret);
 		io_put_req(req, NULL);
 	}
 
@@ -2539,7 +2543,7 @@
 		ret = io_async_cancel_one(ctx, user_data);
 	}
 
-	io_cqring_add_event(ctx, req->user_data, ret);
+	io_cqring_add_event(req, ret);
 	io_put_req(req, NULL);
 	return HRTIMER_NORESTART;
 }
@@ -2582,7 +2586,7 @@
 		 * failed by the regular submission path.
 		 */
 		list_del(&nxt->list);
-		io_cqring_fill_event(ctx, nxt->user_data, ret);
+		io_cqring_fill_event(nxt, ret);
 		trace_io_uring_fail_link(req, nxt);
 		io_commit_cqring(ctx);
 		io_put_req(nxt, NULL);
@@ -2655,7 +2659,7 @@
 
 	/* and drop final reference, if we failed */
 	if (ret) {
-		io_cqring_add_event(ctx, req->user_data, ret);
+		io_cqring_add_event(req, ret);
 		if (req->flags & REQ_F_LINK)
 			req->flags |= REQ_F_FAIL_LINK;
 		io_put_req(req, NULL);
@@ -2671,8 +2675,8 @@
 	ret = io_req_defer(ctx, req);
 	if (ret) {
 		if (ret != -EIOCBQUEUED) {
-			io_cqring_add_event(ctx, req->submit.sqe->user_data, ret);
-			io_free_req(req, NULL);
+			io_cqring_add_event(req, ret);
+			io_double_put_req(req);
 		}
 		return 0;
 	}
@@ -2698,8 +2702,8 @@
 	ret = io_req_defer(ctx, req);
 	if (ret) {
 		if (ret != -EIOCBQUEUED) {
-			io_cqring_add_event(ctx, req->submit.sqe->user_data, ret);
-			io_free_req(req, NULL);
+			io_cqring_add_event(req, ret);
+			io_double_put_req(req);
 			__io_free_req(shadow);
 			return 0;
 		}
@@ -2732,6 +2736,8 @@
 	struct sqe_submit *s = &req->submit;
 	int ret;
 
+	req->user_data = s->sqe->user_data;
+
 	/* enforce forwards compatibility on users */
 	if (unlikely(s->sqe->flags & ~SQE_VALID_FLAGS)) {
 		ret = -EINVAL;
@@ -2741,13 +2747,11 @@
 	ret = io_req_set_file(ctx, state, req);
 	if (unlikely(ret)) {
 err_req:
-		io_cqring_add_event(ctx, s->sqe->user_data, ret);
-		io_free_req(req, NULL);
+		io_cqring_add_event(req, ret);
+		io_double_put_req(req);
 		return;
 	}
 
-	req->user_data = s->sqe->user_data;
-
 	/*
 	 * If we already have a head request, queue this one for async
 	 * submittal once the head completes. If we don't have a head but