diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 08ddf9d837ae..bab00411ac7a 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -305,10 +305,11 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 	long timeo = sock_sndtimeo(sk, 0);
 	bool free_ctx;
 
+	lock_sock(sk);
+	sk->sk_shutdown = SHUTDOWN_MASK;
 	if (ctx->tx_conf == TLS_SW)
 		tls_sw_cancel_work_tx(ctx);
 
-	lock_sock(sk);
 	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
 
 	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index fe27241cd13f..ca161a455ccb 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -2298,7 +2298,7 @@ void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
 
 	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
 	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
-	cancel_delayed_work_sync(&ctx->tx_work.work);
+	cancel_delayed_work(&ctx->tx_work.work);
 }
 
 void tls_sw_release_resources_tx(struct sock *sk)
@@ -2346,6 +2346,7 @@ void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
 {
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 
+	flush_delayed_work(&ctx->tx_work.work);
 	kfree(ctx);
 }