diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
index 14ddc40..6eb3421 100644
--- a/net/smc/af_smc.c
+++ b/net/smc/af_smc.c
@@ -243,11 +243,27 @@ struct proto smc_proto6 = {
 };
 EXPORT_SYMBOL_GPL(smc_proto6);
 
+static void smc_fback_restore_callbacks(struct smc_sock *smc)
+{
+	struct sock *clcsk = smc->clcsock->sk;
+
+	write_lock_bh(&clcsk->sk_callback_lock);
+	clcsk->sk_user_data = NULL;
+	/* reinstall the saved originals; unsaved (NULL) slots are skipped */
+	smc_clcsock_restore_cb(&clcsk->sk_state_change, &smc->clcsk_state_change);
+	smc_clcsock_restore_cb(&clcsk->sk_data_ready, &smc->clcsk_data_ready);
+	smc_clcsock_restore_cb(&clcsk->sk_write_space, &smc->clcsk_write_space);
+	smc_clcsock_restore_cb(&clcsk->sk_error_report, &smc->clcsk_error_report);
+
+	write_unlock_bh(&clcsk->sk_callback_lock);
+}
+
 static void smc_restore_fallback_changes(struct smc_sock *smc)
 {
 	if (smc->clcsock->file) { /* non-accepted sockets have no file yet */
 		smc->clcsock->file->private_data = smc->sk.sk_socket;
 		smc->clcsock->file = NULL;
+		smc_fback_restore_callbacks(smc);
 	}
 }
 
@@ -373,6 +389,7 @@ static struct sock *smc_sock_alloc(struct net *net, struct socket *sock,
 	sk->sk_prot->hash(sk);
 	sk_refcnt_debug_inc(sk);
 	mutex_init(&smc->clcsock_release_lock);
+	smc_init_saved_callbacks(smc); /* no clcsock callbacks saved yet */
 
 	return sk;
 }
@@ -744,47 +761,73 @@ static void smc_fback_forward_wakeup(struct smc_sock *smc, struct sock *clcsk,
 
 static void smc_fback_state_change(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_state_change);
+	read_lock_bh(&clcsk->sk_callback_lock); /* vs. cb/user_data swap */
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_state_change);
+	read_unlock_bh(&clcsk->sk_callback_lock);
 }
 
 static void smc_fback_data_ready(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_data_ready);
+	read_lock_bh(&clcsk->sk_callback_lock); /* vs. cb/user_data swap */
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_data_ready);
+	read_unlock_bh(&clcsk->sk_callback_lock);
 }
 
 static void smc_fback_write_space(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_write_space);
+	read_lock_bh(&clcsk->sk_callback_lock); /* vs. cb/user_data swap */
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_write_space);
+	read_unlock_bh(&clcsk->sk_callback_lock);
 }
 
 static void smc_fback_error_report(struct sock *clcsk)
 {
-	struct smc_sock *smc =
-		smc_clcsock_user_data(clcsk);
+	struct smc_sock *smc;
 
-	if (!smc)
-		return;
-	smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_error_report);
+	read_lock_bh(&clcsk->sk_callback_lock); /* vs. cb/user_data swap */
+	smc = smc_clcsock_user_data(clcsk);
+	if (smc)
+		smc_fback_forward_wakeup(smc, clcsk,
+					 smc->clcsk_error_report);
+	read_unlock_bh(&clcsk->sk_callback_lock);
+}
+
+static void smc_fback_replace_callbacks(struct smc_sock *smc)
+{
+	struct sock *clcsk = smc->clcsock->sk;
+
+	write_lock_bh(&clcsk->sk_callback_lock);
+	clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+	/* install fallback forwarders; originals are saved only once */
+	smc_clcsock_replace_cb(&clcsk->sk_state_change, smc_fback_state_change,
+			       &smc->clcsk_state_change);
+	smc_clcsock_replace_cb(&clcsk->sk_data_ready, smc_fback_data_ready,
+			       &smc->clcsk_data_ready);
+	smc_clcsock_replace_cb(&clcsk->sk_write_space, smc_fback_write_space,
+			       &smc->clcsk_write_space);
+	smc_clcsock_replace_cb(&clcsk->sk_error_report, smc_fback_error_report,
+			       &smc->clcsk_error_report);
+
+	write_unlock_bh(&clcsk->sk_callback_lock);
 }
 
 static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
 {
-	struct sock *clcsk;
 	int rc = 0;
 
 	mutex_lock(&smc->clcsock_release_lock);
@@ -792,10 +835,7 @@ static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
 		rc = -EBADF;
 		goto out;
 	}
-	clcsk = smc->clcsock->sk;
 
-	if (smc->use_fallback)
-		goto out;
 	smc->use_fallback = true;
 	smc->fallback_rsn = reason_code;
 	smc_stat_fallback(smc);
@@ -810,18 +850,7 @@ static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
 		 * in smc sk->sk_wq and they should be woken up
 		 * as clcsock's wait queue is woken up.
 		 */
-		smc->clcsk_state_change = clcsk->sk_state_change;
-		smc->clcsk_data_ready = clcsk->sk_data_ready;
-		smc->clcsk_write_space = clcsk->sk_write_space;
-		smc->clcsk_error_report = clcsk->sk_error_report;
-
-		clcsk->sk_state_change = smc_fback_state_change;
-		clcsk->sk_data_ready = smc_fback_data_ready;
-		clcsk->sk_write_space = smc_fback_write_space;
-		clcsk->sk_error_report = smc_fback_error_report;
-
-		smc->clcsock->sk->sk_user_data =
-			(void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+		smc_fback_replace_callbacks(smc); /* originals saved once */
 	}
 out:
 	mutex_unlock(&smc->clcsock_release_lock);
@@ -1594,6 +1623,19 @@ static int smc_clcsock_accept(struct smc_sock *lsmc, struct smc_sock **new_smc)
 	 * function; switch it back to the original sk_data_ready function
 	 */
 	new_clcsock->sk->sk_data_ready = lsmc->clcsk_data_ready;
+
+	/* if new clcsock has also inherited the fallback-specific callbacks,
+	 * restore the saved originals; a NULL slot means nothing to undo.
+	 */
+	if (lsmc->use_fallback) {
+		if (lsmc->clcsk_state_change)
+			new_clcsock->sk->sk_state_change = lsmc->clcsk_state_change;
+		if (lsmc->clcsk_write_space)
+			new_clcsock->sk->sk_write_space = lsmc->clcsk_write_space;
+		if (lsmc->clcsk_error_report)
+			new_clcsock->sk->sk_error_report = lsmc->clcsk_error_report;
+	}
+
 	(*new_smc)->clcsock = new_clcsock;
 out:
 	return rc;
@@ -2353,17 +2395,20 @@ static void smc_tcp_listen_work(struct work_struct *work)
 
 static void smc_clcsock_data_ready(struct sock *listen_clcsock)
 {
-	struct smc_sock *lsmc =
-		smc_clcsock_user_data(listen_clcsock);
+	struct smc_sock *lsmc;
 
+	read_lock_bh(&listen_clcsock->sk_callback_lock);
+	lsmc = smc_clcsock_user_data(listen_clcsock);
 	if (!lsmc)
-		return;
+		goto out; /* read lock must be dropped */
 	lsmc->clcsk_data_ready(listen_clcsock);
 	if (lsmc->sk.sk_state == SMC_LISTEN) {
 		sock_hold(&lsmc->sk); /* sock_put in smc_tcp_listen_work() */
 		if (!queue_work(smc_tcp_ls_wq, &lsmc->tcp_listen_work))
 			sock_put(&lsmc->sk);
 	}
+out:
+	read_unlock_bh(&listen_clcsock->sk_callback_lock);
 }
 
 static int smc_listen(struct socket *sock, int backlog)
@@ -2395,10 +2440,12 @@ static int smc_listen(struct socket *sock, int backlog)
 	/* save original sk_data_ready function and establish
 	 * smc-specific sk_data_ready function
 	 */
-	smc->clcsk_data_ready = smc->clcsock->sk->sk_data_ready;
-	smc->clcsock->sk->sk_data_ready = smc_clcsock_data_ready;
+	write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
 	smc->clcsock->sk->sk_user_data =
 		(void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
+	smc_clcsock_replace_cb(&smc->clcsock->sk->sk_data_ready,
+			       smc_clcsock_data_ready, &smc->clcsk_data_ready);
+	write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
 
 	/* save original ops */
 	smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops;
@@ -2413,7 +2460,11 @@ static int smc_listen(struct socket *sock, int backlog)
 
 	rc = kernel_listen(smc->clcsock, backlog);
 	if (rc) {
-		smc->clcsock->sk->sk_data_ready = smc->clcsk_data_ready;
+		write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
+		smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready,
+				       &smc->clcsk_data_ready);
+		smc->clcsock->sk->sk_user_data = NULL;
+		write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
 		goto out;
 	}
 	sk->sk_max_ack_backlog = backlog;
diff --git a/net/smc/smc.h b/net/smc/smc.h
index ea06205..5ed765e 100644
--- a/net/smc/smc.h
+++ b/net/smc/smc.h
@@ -288,12 +288,41 @@ static inline struct smc_sock *smc_sk(const struct sock *sk)
 	return (struct smc_sock *)sk;
 }
 
+static inline void smc_init_saved_callbacks(struct smc_sock *smc)
+{
+	smc->clcsk_state_change	= NULL;
+	smc->clcsk_data_ready	= NULL;
+	smc->clcsk_write_space	= NULL;
+	smc->clcsk_error_report	= NULL;
+}
+
 static inline struct smc_sock *smc_clcsock_user_data(const struct sock *clcsk)
 {
 	return (struct smc_sock *)
 	       ((uintptr_t)clcsk->sk_user_data & ~SK_USER_DATA_NOCOPY);
 }
 
+/* set *target_cb = new_cb; save the old cb only if *saved_cb is NULL */
+static inline void smc_clcsock_replace_cb(void (**target_cb)(struct sock *),
+					  void (*new_cb)(struct sock *),
+					  void (**saved_cb)(struct sock *))
+{
+	/* save only once, so a 2nd replace keeps the true original */
+	if (!*saved_cb)
+		*saved_cb = *target_cb;
+	*target_cb = new_cb;
+}
+
+/* put *saved_cb back into *target_cb and clear it; no-op if none saved */
+static inline void smc_clcsock_restore_cb(void (**target_cb)(struct sock *),
+					  void (**saved_cb)(struct sock *))
+{
+	if (!*saved_cb)
+		return;
+	*target_cb = *saved_cb;
+	*saved_cb = NULL;
+}
+
 extern struct workqueue_struct	*smc_hs_wq;	/* wq for handshake work */
 extern struct workqueue_struct	*smc_close_wq;	/* wq for close work */
 
diff --git a/net/smc/smc_close.c b/net/smc/smc_close.c
index 676cb23..31db743 100644
--- a/net/smc/smc_close.c
+++ b/net/smc/smc_close.c
@@ -214,8 +214,11 @@ int smc_close_active(struct smc_sock *smc)
 		sk->sk_state = SMC_CLOSED;
 		sk->sk_state_change(sk); /* wake up accept */
 		if (smc->clcsock && smc->clcsock->sk) {
-			smc->clcsock->sk->sk_data_ready = smc->clcsk_data_ready;
+			write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
+			smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready,
+					       &smc->clcsk_data_ready);
 			smc->clcsock->sk->sk_user_data = NULL;
+			write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
 			rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR);
 		}
 		smc_close_cleanup_listen(sk);