diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 311cec8e533de..acff12999c06e 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -55,6 +55,8 @@ enum {
 
 static struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
+static struct proto *saved_tcpv4_prot;
+static DEFINE_MUTEX(tcpv4_prot_mutex);
 static LIST_HEAD(device_list);
 static DEFINE_MUTEX(device_mutex);
 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
@@ -690,6 +692,16 @@ static int tls_init(struct sock *sk)
 		mutex_unlock(&tcpv6_prot_mutex);
 	}
 
+	if (ip_ver == TLSV4 &&
+	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+		mutex_lock(&tcpv4_prot_mutex);
+		if (likely(sk->sk_prot != saved_tcpv4_prot)) {
+			build_protos(tls_prots[TLSV4], sk->sk_prot);
+			smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+		}
+		mutex_unlock(&tcpv4_prot_mutex);
+	}
+
 	ctx->tx_conf = TLS_BASE;
 	ctx->rx_conf = TLS_BASE;
 	update_sk_prot(sk, ctx);
@@ -721,8 +733,6 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 
 static int __init tls_register(void)
 {
-	build_protos(tls_prots[TLSV4], &tcp_prot);
-
 	tls_sw_proto_ops = inet_stream_ops;
 	tls_sw_proto_ops.splice_read = tls_sw_splice_read;