diff --git a/block/blk-mq.c b/block/blk-mq.c
index 041f7b7fa0d6d..211ef367345f2 100644
--- a/block/blk-mq.c
+++ b/block/blk-mq.c
@@ -301,11 +301,12 @@ static struct request *blk_mq_get_request(struct request_queue *q,
 	struct elevator_queue *e = q->elevator;
 	struct request *rq;
 	unsigned int tag;
+	struct blk_mq_ctx *local_ctx = NULL;
 
 	blk_queue_enter_live(q);
 	data->q = q;
 	if (likely(!data->ctx))
-		data->ctx = blk_mq_get_ctx(q);
+		data->ctx = local_ctx = blk_mq_get_ctx(q);
 	if (likely(!data->hctx))
 		data->hctx = blk_mq_map_queue(q, data->ctx->cpu);
 	if (op & REQ_NOWAIT)
@@ -324,6 +325,10 @@ static struct request *blk_mq_get_request(struct request_queue *q,
 
 	tag = blk_mq_get_tag(data);
 	if (tag == BLK_MQ_TAG_FAIL) {
+		if (local_ctx) {
+			blk_mq_put_ctx(local_ctx);
+			data->ctx = NULL;
+		}
 		blk_queue_exit(q);
 		return NULL;
 	}
@@ -356,12 +361,12 @@ struct request *blk_mq_alloc_request(struct request_queue *q, unsigned int op,
 
 	rq = blk_mq_get_request(q, NULL, op, &alloc_data);
 
-	blk_mq_put_ctx(alloc_data.ctx);
-	blk_queue_exit(q);
-
 	if (!rq)
 		return ERR_PTR(-EWOULDBLOCK);
 
+	blk_mq_put_ctx(alloc_data.ctx);
+	blk_queue_exit(q);
+
 	rq->__data_len = 0;
 	rq->__sector = (sector_t) -1;
 	rq->bio = rq->biotail = NULL;
@@ -407,11 +412,11 @@ struct request *blk_mq_alloc_request_hctx(struct request_queue *q,
 
 	rq = blk_mq_get_request(q, NULL, op, &alloc_data);
 
-	blk_queue_exit(q);
-
 	if (!rq)
 		return ERR_PTR(-EWOULDBLOCK);
 
+	blk_queue_exit(q);
+
 	return rq;
 }
 EXPORT_SYMBOL_GPL(blk_mq_alloc_request_hctx);