diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
index ce91d9b2acb9f..f0f219271daf4 100644
--- a/include/linux/bpf-cgroup.h
+++ b/include/linux/bpf-cgroup.h
@@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
 	int __ret = 0;							       \
 	if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {		       \
 		typeof(sk) __sk = sk_to_full_sk(sk);			       \
-		if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&	       \
+		if (__sk && __sk == skb_to_full_sk(skb) &&		       \
 		    cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))	       \
 			__ret = __cgroup_bpf_run_filter_skb(__sk, skb,	       \
 						      CGROUP_INET_EGRESS); \
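
This hunk depends on the new sk_to_full_sk() contract introduced in the inet_sock.h change below: the helper now returns either NULL or a full socket, never a request or timewait socket, so the old sk_fullsock(__sk) test collapses into a plain NULL check. A minimal sketch of that invariant, written as a hypothetical helper (cgroup_egress_sk_valid() is not in the tree):

#include <net/inet_sock.h>

/* Hypothetical, for illustration only: with the new contract, a NULL
 * test on sk_to_full_sk()'s result is equivalent to the old
 * sk_fullsock() test.
 */
static inline bool cgroup_egress_sk_valid(struct sock *sk,
					  const struct sk_buff *skb)
{
	struct sock *full_sk = sk_to_full_sk(sk);

	return full_sk && full_sk == skb_to_full_sk(skb);
}
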
diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h
index f01dd273bea69..56d8bc5593d3d 100644
--- a/include/net/inet_sock.h
+++ b/include/net/inet_sock.h
@@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
 static inline struct sock *sk_to_full_sk(struct sock *sk)
 {
 #ifdef CONFIG_INET
-	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+	if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
 		sk = inet_reqsk(sk)->rsk_listener;
+	if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+		sk = NULL;
 #endif
 	return sk;
 }
@@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
 static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
 {
 #ifdef CONFIG_INET
-	if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+	if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
 		sk = ((const struct request_sock *)sk)->rsk_listener;
+	if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+		sk = NULL;
 #endif
 	return sk;
 }
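
The behavioral change: sk_to_full_sk() used to map a request socket to its listener and pass everything else through, including timewait sockets. It now returns NULL for TIME_WAIT, since a timewait socket is not a full socket and has no listener to fall back to, which means every caller must tolerate a NULL result. A sketch of the expected caller pattern (example_consumer() is hypothetical):

#include <net/inet_sock.h>

/* Hypothetical caller showing the NULL handling this change requires
 * of every sk_to_full_sk() user.
 */
static void example_consumer(struct sock *sk)
{
	struct sock *full_sk = sk_to_full_sk(sk);

	if (!full_sk)
		return;		/* sk was NULL or TIME_WAIT */

	/* full_sk is sk itself or the request's listener; either way
	 * it is a full socket here.
	 */
	pr_debug("owner state: %u\n", READ_ONCE(full_sk->sk_state));
}
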
diff --git a/include/net/ip.h b/include/net/ip.h
index bab084df15677..4be0a6a603b2b 100644
--- a/include/net/ip.h
+++ b/include/net/ip.h
@@ -288,7 +288,8 @@ static inline __u8 ip_reply_arg_flowi_flags(const struct ip_reply_arg *arg)
 	return (arg->flags & IP_REPLY_ARG_NOSRCCHECK) ? FLOWI_FLAG_ANYSRC : 0;
 }
 
-void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb,
+void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk,
+			   struct sk_buff *skb,
 			   const struct ip_options *sopt,
 			   __be32 daddr, __be32 saddr,
 			   const struct ip_reply_arg *arg,
diff --git a/include/net/sock.h b/include/net/sock.h
index 6da420ab1ee12..bf7fa3db10ae7 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1760,6 +1760,15 @@ void sock_efree(struct sk_buff *skb);
 #ifdef CONFIG_INET
 void sock_edemux(struct sk_buff *skb);
 void sock_pfree(struct sk_buff *skb);
+
+static inline void skb_set_owner_edemux(struct sk_buff *skb, struct sock *sk)
+{
+	skb_orphan(skb);
+	if (refcount_inc_not_zero(&sk->sk_refcnt)) {
+		skb->sk = sk;
+		skb->destructor = sock_edemux;
+	}
+}
 #else
 #define sock_edemux sock_efree
 #endif
@@ -2802,6 +2811,16 @@ static inline bool sk_listener(const struct sock *sk)
 	return (1 << sk->sk_state) & (TCPF_LISTEN | TCPF_NEW_SYN_RECV);
 }
 
+/* This helper checks if a socket is LISTEN, NEW_SYN_RECV or TIME_WAIT.
+ * TCP SYNACK messages can be attached to LISTEN or NEW_SYN_RECV sockets
+ * (depending on SYNCOOKIE); TCP RST and ACK can be attached to TIME_WAIT.
+ */
+static inline bool sk_listener_or_tw(const struct sock *sk)
+{
+	return (1 << READ_ONCE(sk->sk_state)) &
+	       (TCPF_LISTEN | TCPF_NEW_SYN_RECV | TCPF_TIME_WAIT);
+}
+
 void sock_enable_timestamp(struct sock *sk, enum sock_flags flag);
 int sock_recv_errqueue(struct sock *sk, struct msghdr *msg, int len, int level,
 		       int type);
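
Two additions here. skb_set_owner_edemux() attaches a socket to an skb with a bare refcount and no sk_wmem_alloc charge; because it uses refcount_inc_not_zero() rather than sock_hold(), an skb aimed at a socket that is already releasing its last reference simply stays orphaned. The paired destructor, sock_edemux(), only drops that reference. sk_listener_or_tw() extends sk_listener()'s one-bit-per-state mask test with TCPF_TIME_WAIT. A hedged usage sketch (build_ctl_skb() is made up for illustration):

#include <linux/skbuff.h>
#include <net/sock.h>

/* Hypothetical: build a control packet and pin, but do not charge,
 * the originating socket, which may be a request or timewait sock.
 */
static struct sk_buff *build_ctl_skb(struct sock *sk, unsigned int len)
{
	struct sk_buff *skb = alloc_skb(len, GFP_ATOMIC);

	if (!skb)
		return NULL;

	/* Takes a reference only if sk_refcnt is still nonzero;
	 * otherwise skb->sk stays NULL and the skb remains orphaned.
	 */
	skb_set_owner_edemux(skb, sk);
	return skb;
}
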
diff --git a/net/core/filter.c b/net/core/filter.c
index bd0d08bf76bb8..202c1d386e195 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
 		/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
 		 * sock refcnt is decremented to prevent a request_sock leak.
 		 */
-		if (!sk_fullsock(sk2))
-			sk2 = NULL;
 		if (sk2 != sk) {
 			sock_gen_put(sk);
 			/* Ensure there is no need to bump sk2 refcnt */
@@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
 		/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
 		 * sock refcnt is decremented to prevent a request_sock leak.
 		 */
-		if (!sk_fullsock(sk2))
-			sk2 = NULL;
 		if (sk2 != sk) {
 			sock_gen_put(sk);
 			/* Ensure there is no need to bump sk2 refcnt */
@@ -7276,7 +7272,7 @@ BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
 {
 	sk = sk_to_full_sk(sk);
 
-	if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+	if (sk && sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
 		return (unsigned long)sk;
 
 	return (unsigned long)NULL;
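
The two deleted branches were normalizing a non-full sk2 back to NULL by hand; after the inet_sock.h change, sk_to_full_sk() returns NULL for TIME_WAIT itself, so the manual test is dead code. The flip side shows up in bpf_get_listener_sock(), which could previously assume a non-NULL return and now has to check. A condensed sketch of the lookup tail after the change (resolve_full_sk() is a hypothetical extraction with error handling trimmed):

#include <net/inet_sock.h>
#include <net/sock.h>

/* Hypothetical condensation of the __bpf_sk_lookup() tail. */
static struct sock *resolve_full_sk(struct sock *sk)
{
	struct sock *sk2 = sk_to_full_sk(sk);	/* may now be NULL */

	if (sk2 != sk) {
		sock_gen_put(sk);	/* drop the request_sock reference */

		/* A listener handed back to the caller must be
		 * RCU-freed, since no extra reference is taken here.
		 */
		if (sk2 && !sock_flag(sk2, SOCK_RCU_FREE))
			sk2 = NULL;
	}
	return sk2;
}
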
diff --git a/net/core/sock.c b/net/core/sock.c
index 083d438d8b6fa..f8c0d4eda888c 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -2592,14 +2592,11 @@ void __sock_wfree(struct sk_buff *skb)
 void skb_set_owner_w(struct sk_buff *skb, struct sock *sk)
 {
 	skb_orphan(skb);
-	skb->sk = sk;
 #ifdef CONFIG_INET
-	if (unlikely(!sk_fullsock(sk))) {
-		skb->destructor = sock_edemux;
-		sock_hold(sk);
-		return;
-	}
+	if (unlikely(!sk_fullsock(sk)))
+		return skb_set_owner_edemux(skb, sk);
 #endif
+	skb->sk = sk;
 	skb->destructor = sock_wfree;
 	skb_set_hash_from_sk(skb, sk);
 	/*
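
skb_set_owner_w() keeps its fast path for full sockets and delegates the rest to the new helper. One subtlety: the deleted inline path took an unconditional sock_hold(), whereas skb_set_owner_edemux() uses refcount_inc_not_zero(), so an skb aimed at a dying socket now stays orphaned instead of pinning a zero-refcount sock. The resulting destructor pairing, sketched as an illustrative dispatcher (choose_owner() is not in the tree):

#include <net/sock.h>

/* Illustrative mapping of owner-setters to destructors:
 *   full socket -> sock_wfree()  (charged to sk_wmem_alloc)
 *   req/tw sock -> sock_edemux() (refcount only, no charge)
 */
static void choose_owner(struct sk_buff *skb, struct sock *sk)
{
	if (sk_fullsock(sk))
		skb_set_owner_w(skb, sk);
	else
		skb_set_owner_edemux(skb, sk);
}
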
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index e5c55a95063dd..0065b1996c947 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -1596,7 +1596,8 @@ static int ip_reply_glue_bits(void *dptr, char *to, int offset,
  *	Generic function to send a packet as reply to another packet.
  *	Used to send some TCP resets/acks so far.
  */
-void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb,
+void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk,
+			   struct sk_buff *skb,
 			   const struct ip_options *sopt,
 			   __be32 daddr, __be32 saddr,
 			   const struct ip_reply_arg *arg,
@@ -1662,6 +1663,8 @@ void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb,
 			  arg->csumoffset) = csum_fold(csum_add(nskb->csum,
 								arg->csum));
 		nskb->ip_summed = CHECKSUM_NONE;
+		if (orig_sk)
+			skb_set_owner_edemux(nskb, (struct sock *)orig_sk);
 		if (transmit_time)
 			nskb->tstamp_type = SKB_CLOCK_MONOTONIC;
 		if (txhash)
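
ip_send_unicast_reply() now receives the socket the reply is generated for; the parameter is const-qualified because replies are also sent on behalf of timewait and request sockets. Attaching it with edemux semantics lets a qdisc observe the association without any write-memory accounting. The const cast is tolerable because skb_set_owner_edemux() only touches sk_refcnt, which is safe to modify concurrently. A hypothetical wrapper spelling that out (attach_reply_owner() is not in the tree):

#include <net/sock.h>

/* Hypothetical wrapper around the new ownership step above. */
static void attach_reply_owner(struct sk_buff *nskb,
			       const struct sock *orig_sk)
{
	if (!orig_sk)
		return;

	/* Cast away const: the only mutation is the refcount bump. */
	skb_set_owner_edemux(nskb, (struct sock *)orig_sk);
}
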
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 985028434f644..9d3dd101ea713 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -907,7 +907,7 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb,
 		ctl_sk->sk_mark = 0;
 		ctl_sk->sk_priority = 0;
 	}
-	ip_send_unicast_reply(ctl_sk,
+	ip_send_unicast_reply(ctl_sk, sk,
 			      skb, &TCP_SKB_CB(skb)->header.h4.opt,
 			      ip_hdr(skb)->saddr, ip_hdr(skb)->daddr,
 			      &arg, arg.iov[0].iov_len,
@@ -1021,7 +1021,7 @@ static void tcp_v4_send_ack(const struct sock *sk,
 	ctl_sk->sk_priority = (sk->sk_state == TCP_TIME_WAIT) ?
 			   inet_twsk(sk)->tw_priority : READ_ONCE(sk->sk_priority);
 	transmit_time = tcp_transmit_time(sk);
-	ip_send_unicast_reply(ctl_sk,
+	ip_send_unicast_reply(ctl_sk, sk,
 			      skb, &TCP_SKB_CB(skb)->header.h4.opt,
 			      ip_hdr(skb)->saddr, ip_hdr(skb)->daddr,
 			      &arg, arg.iov[0].iov_len,
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 06200bb111f8b..054244ce51175 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -3728,7 +3728,7 @@ struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst,
 
 	switch (synack_type) {
 	case TCP_SYNACK_NORMAL:
-		skb_set_owner_w(skb, req_to_sk(req));
+		skb_set_owner_edemux(skb, req_to_sk(req));
 		break;
 	case TCP_SYNACK_COOKIE:
 		/* Under synflood, we do not attach skb to a socket,
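
For TCP_SYNACK_NORMAL, the skb used to go through skb_set_owner_w(); a request socket is not a full socket, so after the net/core/sock.c change that call would fall through to the edemux path anyway. Calling skb_set_owner_edemux() directly makes the intent explicit and skips the sk_fullsock() test. A small illustrative sketch (own_synack() is hypothetical):

#include <net/request_sock.h>

/* Illustrative: req_to_sk() is a plain cast (request socks share the
 * leading struct sock_common), and edemux ownership keeps the request
 * socket alive until the SYNACK skb is consumed.
 */
static void own_synack(struct sk_buff *skb, struct request_sock *req)
{
	skb_set_owner_edemux(skb, req_to_sk(req));
}
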
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 7634c0be6acbd..597920061a3a0 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -967,6 +967,9 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
 	}
 
 	if (sk) {
+		/* Unconstify the socket only to attach it to buff as its owner. */
+		skb_set_owner_edemux(buff, (struct sock *)sk);
+
 		if (sk->sk_state == TCP_TIME_WAIT)
 			mark = inet_twsk(sk)->tw_mark;
 		else
diff --git a/net/sched/sch_fq.c b/net/sched/sch_fq.c
index aeabf45c9200c..a97638bef6da5 100644
--- a/net/sched/sch_fq.c
+++ b/net/sched/sch_fq.c
@@ -362,8 +362,9 @@ static struct fq_flow *fq_classify(struct Qdisc *sch, struct sk_buff *skb,
 	 * 3) We do not want to rate limit them (eg SYNFLOOD attack),
 	 *    especially if the listener set SO_MAX_PACING_RATE
 	 * 4) We pretend they are orphaned
+	 * TCP can also associate TIME_WAIT sockets with RST or ACK packets.
 	 */
-	if (!sk || sk_listener(sk)) {
+	if (!sk || sk_listener_or_tw(sk)) {
 		unsigned long hash = skb_get_hash(skb) & q->orphan_mask;
 
 		/* By forcing low order bit to 1, we make sure to not
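
For context on the orphan path above: fq buckets listener-, request- and now timewait-associated packets into a reduced hash space and forces the low-order bit of the stand-in pointer to 1, so it can never collide with a real, word-aligned struct sock pointer. A sketch of that trick (fq_orphan_sk() is a hypothetical extraction, assumed from the surrounding sch_fq code):

#include <net/sock.h>

/* Hypothetical: derive the stand-in socket pointer used to bucket
 * orphaned packets.
 */
static struct sock *fq_orphan_sk(struct sk_buff *skb,
				 unsigned long orphan_mask)
{
	unsigned long hash = skb_get_hash(skb) & orphan_mask;

	/* Low bit forced to 1: real struct sock pointers are at least
	 * word-aligned, so their low bit is always 0.
	 */
	return (struct sock *)((hash << 1) | 1UL);
}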