--- a/kernel/events/core.c
+++ b/kernel/events/core.c
@@ -4816,6 +4816,7 @@ find_get_pmu_context(struct pmu *pmu, struct perf_event_context *ctx,
 			raw_spin_lock_irq(&ctx->lock);
 			list_add(&epc->pmu_ctx_entry, &ctx->pmu_ctx_list);
 			epc->ctx = ctx;
+			get_ctx(ctx);
 			raw_spin_unlock_irq(&ctx->lock);
 		} else {
 			WARN_ON_ONCE(epc->ctx != ctx);
@@ -4862,6 +4863,7 @@ find_get_pmu_context(struct pmu *pmu, struct perf_event_context *ctx,
 
 	list_add(&epc->pmu_ctx_entry, &ctx->pmu_ctx_list);
 	epc->ctx = ctx;
+	get_ctx(ctx);
 
 found_epc:
 	if (task_ctx_data && !epc->task_ctx_data) {
@@ -4914,6 +4916,7 @@ static void put_pmu_ctx(struct perf_event_pmu_context *epc)
 		list_del_init(&epc->pmu_ctx_entry);
 		epc->ctx = NULL;
 		raw_spin_unlock_irqrestore(&ctx->lock, flags);
+		put_ctx(ctx);
 	}
 
 	WARN_ON_ONCE(!list_empty(&epc->pinned_active));
@@ -13021,7 +13024,6 @@ static void perf_event_exit_task_context(struct task_struct *child)
 	 * and mark the context dead.
 	 */
 	RCU_INIT_POINTER(child->perf_event_ctxp, NULL);
-	put_ctx(child_ctx); /* cannot be last */
 	WRITE_ONCE(child_ctx->task, TASK_TOMBSTONE);
 	put_task_struct(current); /* cannot be last */