diff --git a/drivers/block/ublk_drv.c b/drivers/block/ublk_drv.c index afc07fa170400..353ccdb60729c 100644 --- a/drivers/block/ublk_drv.c +++ b/drivers/block/ublk_drv.c @@ -43,6 +43,7 @@ #include #include #include +#include #include #define UBLK_MINORS (1U << MINORBITS) @@ -62,6 +63,8 @@ struct ublk_rq_data { struct llist_node node; + + struct kref ref; }; struct ublk_uring_cmd_pdu { @@ -181,6 +184,9 @@ struct ublk_params_header { __u32 types; }; +static inline void __ublk_complete_rq(struct request *req); +static void ublk_complete_rq(struct kref *ref); + static dev_t ublk_chr_devt; static struct class *ublk_chr_class; @@ -289,6 +295,45 @@ static int ublk_apply_params(struct ublk_device *ub) return 0; } +static inline bool ublk_need_req_ref(const struct ublk_queue *ubq) +{ + return false; +} + +static inline void ublk_init_req_ref(const struct ublk_queue *ubq, + struct request *req) +{ + if (ublk_need_req_ref(ubq)) { + struct ublk_rq_data *data = blk_mq_rq_to_pdu(req); + + kref_init(&data->ref); + } +} + +static inline bool ublk_get_req_ref(const struct ublk_queue *ubq, + struct request *req) +{ + if (ublk_need_req_ref(ubq)) { + struct ublk_rq_data *data = blk_mq_rq_to_pdu(req); + + return kref_get_unless_zero(&data->ref); + } + + return true; +} + +static inline void ublk_put_req_ref(const struct ublk_queue *ubq, + struct request *req) +{ + if (ublk_need_req_ref(ubq)) { + struct ublk_rq_data *data = blk_mq_rq_to_pdu(req); + + kref_put(&data->ref, ublk_complete_rq); + } else { + __ublk_complete_rq(req); + } +} + static inline bool ublk_need_get_data(const struct ublk_queue *ubq) { return ubq->flags & UBLK_F_NEED_GET_DATA; @@ -625,13 +670,19 @@ static inline bool ubq_daemon_is_dying(struct ublk_queue *ubq) } /* todo: handle partial completion */ -static void ublk_complete_rq(struct request *req) +static inline void __ublk_complete_rq(struct request *req) { struct ublk_queue *ubq = req->mq_hctx->driver_data; struct ublk_io *io = &ubq->ios[req->tag]; unsigned int unmapped_bytes; blk_status_t res = BLK_STS_OK; + /* called from ublk_abort_queue() code path */ + if (io->flags & UBLK_IO_FLAG_ABORTED) { + res = BLK_STS_IOERR; + goto exit; + } + /* failed read IO if nothing is read */ if (!io->res && req_op(req) == REQ_OP_READ) io->res = -EIO; @@ -671,6 +722,15 @@ static void ublk_complete_rq(struct request *req) blk_mq_end_request(req, res); } +static void ublk_complete_rq(struct kref *ref) +{ + struct ublk_rq_data *data = container_of(ref, struct ublk_rq_data, + ref); + struct request *req = blk_mq_rq_from_pdu(data); + + __ublk_complete_rq(req); +} + /* * Since __ublk_rq_task_work always fails requests immediately during * exiting, __ublk_fail_req() is only called from abort context during @@ -689,7 +749,7 @@ static void __ublk_fail_req(struct ublk_queue *ubq, struct ublk_io *io, if (ublk_queue_can_use_recovery_reissue(ubq)) blk_mq_requeue_request(req, false); else - blk_mq_end_request(req, BLK_STS_IOERR); + ublk_put_req_ref(ubq, req); } } @@ -798,6 +858,7 @@ static inline void __ublk_rq_task_work(struct request *req, mapped_bytes >> 9; } + ublk_init_req_ref(ubq, req); ubq_complete_io_cmd(io, UBLK_IO_RES_OK, issue_flags); } @@ -1002,7 +1063,7 @@ static void ublk_commit_completion(struct ublk_device *ub, req = blk_mq_tag_to_rq(ub->tag_set.tags[qid], tag); if (req && likely(!blk_should_fake_timeout(req->q))) - ublk_complete_rq(req); + ublk_put_req_ref(ubq, req); } /*