diff --git a/include/net/sock.h b/include/net/sock.h
index 657873e2d90fb..0063e8410a4e3 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1254,6 +1254,7 @@ struct proto {
 	void			(*enter_memory_pressure)(struct sock *sk);
 	void			(*leave_memory_pressure)(struct sock *sk);
 	atomic_long_t		*memory_allocated;	/* Current allocated memory. */
+	int  __percpu		*per_cpu_fw_alloc;
 	struct percpu_counter	*sockets_allocated;	/* Current number of sockets. */
 
 	/*
@@ -1396,22 +1397,48 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
 	return !!*sk->sk_prot->memory_pressure;
 }
 
+static inline long
+proto_memory_allocated(const struct proto *prot)
+{
+	return max(0L, atomic_long_read(prot->memory_allocated));
+}
+
 static inline long
 sk_memory_allocated(const struct sock *sk)
 {
-	return atomic_long_read(sk->sk_prot->memory_allocated);
+	return proto_memory_allocated(sk->sk_prot);
 }
 
+/* 1 MB per cpu, in page units */
+#define SK_MEMORY_PCPU_RESERVE (1 << (20 - PAGE_SHIFT))
+
 static inline long
 sk_memory_allocated_add(struct sock *sk, int amt)
 {
-	return atomic_long_add_return(amt, sk->sk_prot->memory_allocated);
+	int local_reserve;
+
+	preempt_disable();
+	local_reserve = __this_cpu_add_return(*sk->sk_prot->per_cpu_fw_alloc, amt);
+	if (local_reserve >= SK_MEMORY_PCPU_RESERVE) {
+		__this_cpu_sub(*sk->sk_prot->per_cpu_fw_alloc, local_reserve);
+		atomic_long_add(local_reserve, sk->sk_prot->memory_allocated);
+	}
+	preempt_enable();
+	return sk_memory_allocated(sk);
 }
 
 static inline void
 sk_memory_allocated_sub(struct sock *sk, int amt)
 {
-	atomic_long_sub(amt, sk->sk_prot->memory_allocated);
+	int local_reserve;
+
+	preempt_disable();
+	local_reserve = __this_cpu_sub_return(*sk->sk_prot->per_cpu_fw_alloc, amt);
+	if (local_reserve <= -SK_MEMORY_PCPU_RESERVE) {
+		__this_cpu_sub(*sk->sk_prot->per_cpu_fw_alloc, local_reserve);
+		atomic_long_add(local_reserve, sk->sk_prot->memory_allocated);
+	}
+	preempt_enable();
 }
 
 #define SK_ALLOC_PERCPU_COUNTER_BATCH 16
@@ -1440,12 +1467,6 @@ proto_sockets_allocated_sum_positive(struct proto *prot)
 	return percpu_counter_sum_positive(prot->sockets_allocated);
 }
 
-static inline long
-proto_memory_allocated(struct proto *prot)
-{
-	return atomic_long_read(prot->memory_allocated);
-}
-
 static inline bool
 proto_memory_pressure(struct proto *prot)
 {
@@ -1532,30 +1553,18 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind);
 void __sk_mem_reduce_allocated(struct sock *sk, int amount);
 void __sk_mem_reclaim(struct sock *sk, int amount);
 
-/* We used to have PAGE_SIZE here, but systems with 64KB pages
- * do not necessarily have 16x time more memory than 4KB ones.
- */
-#define SK_MEM_QUANTUM 4096
-#define SK_MEM_QUANTUM_SHIFT ilog2(SK_MEM_QUANTUM)
 #define SK_MEM_SEND	0
 #define SK_MEM_RECV	1
 
-/* sysctl_mem values are in pages, we convert them in SK_MEM_QUANTUM units */
+/* sysctl_mem values are in pages */
 static inline long sk_prot_mem_limits(const struct sock *sk, int index)
 {
-	long val = sk->sk_prot->sysctl_mem[index];
-
-#if PAGE_SIZE > SK_MEM_QUANTUM
-	val <<= PAGE_SHIFT - SK_MEM_QUANTUM_SHIFT;
-#elif PAGE_SIZE < SK_MEM_QUANTUM
-	val >>= SK_MEM_QUANTUM_SHIFT - PAGE_SHIFT;
-#endif
-	return val;
+	return sk->sk_prot->sysctl_mem[index];
 }
 
 static inline int sk_mem_pages(int amt)
 {
-	return (amt + SK_MEM_QUANTUM - 1) >> SK_MEM_QUANTUM_SHIFT;
+	return (amt + PAGE_SIZE - 1) >> PAGE_SHIFT;
 }
 
 static inline bool sk_has_account(struct sock *sk)
@@ -1566,19 +1575,23 @@ static inline bool sk_has_account(struct sock *sk)
 
 static inline bool sk_wmem_schedule(struct sock *sk, int size)
 {
+	int delta;
+
 	if (!sk_has_account(sk))
 		return true;
-	return size <= sk->sk_forward_alloc ||
-		__sk_mem_schedule(sk, size, SK_MEM_SEND);
+	delta = size - sk->sk_forward_alloc;
+	return delta <= 0 || __sk_mem_schedule(sk, delta, SK_MEM_SEND);
 }
 
 static inline bool
 sk_rmem_schedule(struct sock *sk, struct sk_buff *skb, int size)
 {
+	int delta;
+
 	if (!sk_has_account(sk))
 		return true;
-	return size <= sk->sk_forward_alloc ||
-		__sk_mem_schedule(sk, size, SK_MEM_RECV) ||
+	delta = size - sk->sk_forward_alloc;
+	return delta <= 0 || __sk_mem_schedule(sk, delta, SK_MEM_RECV) ||
 		skb_pfmemalloc(skb);
 }
 
@@ -1604,7 +1617,7 @@ static inline void sk_mem_reclaim(struct sock *sk)
 
 	reclaimable = sk->sk_forward_alloc - sk_unused_reserved_mem(sk);
 
-	if (reclaimable >= SK_MEM_QUANTUM)
+	if (reclaimable >= (int)PAGE_SIZE)
 		__sk_mem_reclaim(sk, reclaimable);
 }
 
@@ -1614,19 +1627,6 @@ static inline void sk_mem_reclaim_final(struct sock *sk)
 	sk_mem_reclaim(sk);
 }
 
-static inline void sk_mem_reclaim_partial(struct sock *sk)
-{
-	int reclaimable;
-
-	if (!sk_has_account(sk))
-		return;
-
-	reclaimable = sk->sk_forward_alloc - sk_unused_reserved_mem(sk);
-
-	if (reclaimable > SK_MEM_QUANTUM)
-		__sk_mem_reclaim(sk, reclaimable - 1);
-}
-
 static inline void sk_mem_charge(struct sock *sk, int size)
 {
 	if (!sk_has_account(sk))
@@ -1634,29 +1634,17 @@ static inline void sk_mem_charge(struct sock *sk, int size)
 	sk->sk_forward_alloc -= size;
 }
 
-/* the following macros control memory reclaiming in sk_mem_uncharge()
+/* the following macros control memory reclaiming in mptcp_rmem_uncharge()
  */
 #define SK_RECLAIM_THRESHOLD	(1 << 21)
 #define SK_RECLAIM_CHUNK	(1 << 20)
 
 static inline void sk_mem_uncharge(struct sock *sk, int size)
 {
-	int reclaimable;
-
 	if (!sk_has_account(sk))
 		return;
 	sk->sk_forward_alloc += size;
-	reclaimable = sk->sk_forward_alloc - sk_unused_reserved_mem(sk);
-
-	/* Avoid a possible overflow.
-	 * TCP send queues can make this happen, if sk_mem_reclaim()
-	 * is not called and more than 2 GBytes are released at once.
-	 *
-	 * If we reach 2 MBytes, reclaim 1 MBytes right now, there is
-	 * no need to hold that much forward allocation anyway.
-	 */
-	if (unlikely(reclaimable >= SK_RECLAIM_THRESHOLD))
-		__sk_mem_reclaim(sk, SK_RECLAIM_CHUNK);
+	sk_mem_reclaim(sk);
 }
 
 /*
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 1e99f5c61f849..4794cae4577e4 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -253,6 +253,8 @@ extern long sysctl_tcp_mem[3];
 #define TCP_RACK_NO_DUPTHRESH    0x4 /* Do not use DUPACK threshold in RACK */
 
 extern atomic_long_t tcp_memory_allocated;
+DECLARE_PER_CPU(int, tcp_memory_per_cpu_fw_alloc);
+
 extern struct percpu_counter tcp_sockets_allocated;
 extern unsigned long tcp_memory_pressure;
 
diff --git a/include/net/udp.h b/include/net/udp.h
index b83a003305667..b60eea2e3fae2 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -95,6 +95,7 @@ static inline struct udp_hslot *udp_hashslot2(struct udp_table *table,
 extern struct proto udp_prot;
 
 extern atomic_long_t udp_memory_allocated;
+DECLARE_PER_CPU(int, udp_memory_per_cpu_fw_alloc);
 
 /* sysctl variables for udp */
 extern long sysctl_udp_mem[3];
diff --git a/net/core/datagram.c b/net/core/datagram.c
index 50f4faeea76cc..35791f86bd1a0 100644
--- a/net/core/datagram.c
+++ b/net/core/datagram.c
@@ -320,7 +320,6 @@ EXPORT_SYMBOL(skb_recv_datagram);
 void skb_free_datagram(struct sock *sk, struct sk_buff *skb)
 {
 	consume_skb(skb);
-	sk_mem_reclaim_partial(sk);
 }
 EXPORT_SYMBOL(skb_free_datagram);
 
@@ -336,7 +335,6 @@ void __skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb, int len)
 	slow = lock_sock_fast(sk);
 	sk_peek_offset_bwd(sk, len);
 	skb_orphan(skb);
-	sk_mem_reclaim_partial(sk);
 	unlock_sock_fast(sk, slow);
 
 	/* skb is now orphaned, can be freed outside of locked section */
@@ -396,7 +394,6 @@ int skb_kill_datagram(struct sock *sk, struct sk_buff *skb, unsigned int flags)
 				      NULL);
 
 	kfree_skb(skb);
-	sk_mem_reclaim_partial(sk);
 	return err;
 }
 EXPORT_SYMBOL(skb_kill_datagram);
diff --git a/net/core/sock.c b/net/core/sock.c
index f5062d9e12227..697d5c8e2f0de 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -991,7 +991,7 @@ EXPORT_SYMBOL(sock_set_mark);
 static void sock_release_reserved_memory(struct sock *sk, int bytes)
 {
 	/* Round down bytes to multiple of pages */
-	bytes &= ~(SK_MEM_QUANTUM - 1);
+	bytes = round_down(bytes, PAGE_SIZE);
 
 	WARN_ON(bytes > sk->sk_reserved_mem);
 	sk->sk_reserved_mem -= bytes;
@@ -1028,9 +1028,9 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
 		mem_cgroup_uncharge_skmem(sk->sk_memcg, pages);
 		return -ENOMEM;
 	}
-	sk->sk_forward_alloc += pages << SK_MEM_QUANTUM_SHIFT;
+	sk->sk_forward_alloc += pages << PAGE_SHIFT;
 
-	sk->sk_reserved_mem += pages << SK_MEM_QUANTUM_SHIFT;
+	sk->sk_reserved_mem += pages << PAGE_SHIFT;
 
 	return 0;
 }
@@ -2987,7 +2987,6 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
 
 	return 0;
 }
-EXPORT_SYMBOL(__sk_mem_raise_allocated);
 
 /**
  *	__sk_mem_schedule - increase sk_forward_alloc and memory_allocated
@@ -3003,10 +3002,10 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind)
 {
 	int ret, amt = sk_mem_pages(size);
 
-	sk->sk_forward_alloc += amt << SK_MEM_QUANTUM_SHIFT;
+	sk->sk_forward_alloc += amt << PAGE_SHIFT;
 	ret = __sk_mem_raise_allocated(sk, size, amt, kind);
 	if (!ret)
-		sk->sk_forward_alloc -= amt << SK_MEM_QUANTUM_SHIFT;
+		sk->sk_forward_alloc -= amt << PAGE_SHIFT;
 	return ret;
 }
 EXPORT_SYMBOL(__sk_mem_schedule);
@@ -3029,17 +3028,16 @@ void __sk_mem_reduce_allocated(struct sock *sk, int amount)
 	    (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
 		sk_leave_memory_pressure(sk);
 }
-EXPORT_SYMBOL(__sk_mem_reduce_allocated);
 
 /**
  *	__sk_mem_reclaim - reclaim sk_forward_alloc and memory_allocated
  *	@sk: socket
- *	@amount: number of bytes (rounded down to a SK_MEM_QUANTUM multiple)
+ *	@amount: number of bytes (rounded down to a PAGE_SIZE multiple)
  */
 void __sk_mem_reclaim(struct sock *sk, int amount)
 {
-	amount >>= SK_MEM_QUANTUM_SHIFT;
-	sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
+	amount >>= PAGE_SHIFT;
+	sk->sk_forward_alloc -= amount << PAGE_SHIFT;
 	__sk_mem_reduce_allocated(sk, amount);
 }
 EXPORT_SYMBOL(__sk_mem_reclaim);
@@ -3798,6 +3796,10 @@ int proto_register(struct proto *prot, int alloc_slab)
 		pr_err("%s: missing sysctl_mem\n", prot->name);
 		return -EINVAL;
 	}
+	if (prot->memory_allocated && !prot->per_cpu_fw_alloc) {
+		pr_err("%s: missing per_cpu_fw_alloc\n", prot->name);
+		return -EINVAL;
+	}
 	if (alloc_slab) {
 		prot->slab = kmem_cache_create_usercopy(prot->name,
 					prot->obj_size, 0,
diff --git a/net/decnet/af_decnet.c b/net/decnet/af_decnet.c
index dc92a67baea39..aa4f43f52499b 100644
--- a/net/decnet/af_decnet.c
+++ b/net/decnet/af_decnet.c
@@ -149,6 +149,7 @@ static DEFINE_RWLOCK(dn_hash_lock);
 static struct hlist_head dn_sk_hash[DN_SK_HASH_SIZE];
 static struct hlist_head dn_wild_sk;
 static atomic_long_t decnet_memory_allocated;
+static DEFINE_PER_CPU(int, decnet_memory_per_cpu_fw_alloc);
 
 static int __dn_setsockopt(struct socket *sock, int level, int optname,
 		sockptr_t optval, unsigned int optlen, int flags);
@@ -454,7 +455,10 @@ static struct proto dn_proto = {
 	.owner			= THIS_MODULE,
 	.enter_memory_pressure	= dn_enter_memory_pressure,
 	.memory_pressure	= &dn_memory_pressure,
+
 	.memory_allocated	= &decnet_memory_allocated,
+	.per_cpu_fw_alloc	= &decnet_memory_per_cpu_fw_alloc,
+
 	.sysctl_mem		= sysctl_decnet_mem,
 	.sysctl_wmem		= sysctl_decnet_wmem,
 	.sysctl_rmem		= sysctl_decnet_rmem,
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 9984d23a7f3e1..14ebb4ec4a51f 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -294,6 +294,8 @@ EXPORT_SYMBOL(sysctl_tcp_mem);
 
 atomic_long_t tcp_memory_allocated ____cacheline_aligned_in_smp;	/* Current allocated memory. */
 EXPORT_SYMBOL(tcp_memory_allocated);
+DEFINE_PER_CPU(int, tcp_memory_per_cpu_fw_alloc);
+EXPORT_PER_CPU_SYMBOL_GPL(tcp_memory_per_cpu_fw_alloc);
 
 #if IS_ENABLED(CONFIG_SMC)
 DEFINE_STATIC_KEY_FALSE(tcp_have_smc);
@@ -856,9 +858,6 @@ struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, int size, gfp_t gfp,
 {
 	struct sk_buff *skb;
 
-	if (unlikely(tcp_under_memory_pressure(sk)))
-		sk_mem_reclaim_partial(sk);
-
 	skb = alloc_skb_fclone(size + MAX_TCP_HEADER, gfp);
 	if (likely(skb)) {
 		bool mem_scheduled;
@@ -2762,8 +2761,6 @@ void __tcp_close(struct sock *sk, long timeout)
 		__kfree_skb(skb);
 	}
 
-	sk_mem_reclaim(sk);
-
 	/* If socket has been already reset (e.g. in tcp_reset()) - kill it. */
 	if (sk->sk_state == TCP_CLOSE)
 		goto adjudge_to_death;
@@ -2871,7 +2868,6 @@ void __tcp_close(struct sock *sk, long timeout)
 		}
 	}
 	if (sk->sk_state != TCP_CLOSE) {
-		sk_mem_reclaim(sk);
 		if (tcp_check_oom(sk, 0)) {
 			tcp_set_state(sk, TCP_CLOSE);
 			tcp_send_active_reset(sk, GFP_ATOMIC);
@@ -2949,7 +2945,6 @@ void tcp_write_queue_purge(struct sock *sk)
 	}
 	tcp_rtx_queue_purge(sk);
 	INIT_LIST_HEAD(&tcp_sk(sk)->tsorted_sent_queue);
-	sk_mem_reclaim(sk);
 	tcp_clear_all_retrans_hints(tcp_sk(sk));
 	tcp_sk(sk)->packets_out = 0;
 	inet_csk(sk)->icsk_backoff = 0;
@@ -4661,11 +4656,11 @@ void __init tcp_init(void)
 	max_wshare = min(4UL*1024*1024, limit);
 	max_rshare = min(6UL*1024*1024, limit);
 
-	init_net.ipv4.sysctl_tcp_wmem[0] = SK_MEM_QUANTUM;
+	init_net.ipv4.sysctl_tcp_wmem[0] = PAGE_SIZE;
 	init_net.ipv4.sysctl_tcp_wmem[1] = 16*1024;
 	init_net.ipv4.sysctl_tcp_wmem[2] = max(64*1024, max_wshare);
 
-	init_net.ipv4.sysctl_tcp_rmem[0] = SK_MEM_QUANTUM;
+	init_net.ipv4.sysctl_tcp_rmem[0] = PAGE_SIZE;
 	init_net.ipv4.sysctl_tcp_rmem[1] = 131072;
 	init_net.ipv4.sysctl_tcp_rmem[2] = max(131072, max_rshare);
 
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index 2e2a9ece9af27..fdc7beb81b684 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -805,7 +805,6 @@ static void tcp_event_data_recv(struct sock *sk, struct sk_buff *skb)
 			 * restart window, so that we send ACKs quickly.
 			 */
 			tcp_incr_quickack(sk, TCP_MAX_QUICKACKS);
-			sk_mem_reclaim(sk);
 		}
 	}
 	icsk->icsk_ack.lrcvtime = now;
@@ -4390,7 +4389,6 @@ void tcp_fin(struct sock *sk)
 	skb_rbtree_purge(&tp->out_of_order_queue);
 	if (tcp_is_sack(tp))
 		tcp_sack_reset(&tp->rx_opt);
-	sk_mem_reclaim(sk);
 
 	if (!sock_flag(sk, SOCK_DEAD)) {
 		sk->sk_state_change(sk);
@@ -5287,7 +5285,7 @@ static void tcp_collapse_ofo_queue(struct sock *sk)
 		    before(TCP_SKB_CB(skb)->end_seq, start)) {
 			/* Do not attempt collapsing tiny skbs */
 			if (range_truesize != head->truesize ||
-			    end - start >= SKB_WITH_OVERHEAD(SK_MEM_QUANTUM)) {
+			    end - start >= SKB_WITH_OVERHEAD(PAGE_SIZE)) {
 				tcp_collapse(sk, NULL, &tp->out_of_order_queue,
 					     head, skb, start, end);
 			} else {
@@ -5336,7 +5334,6 @@ static bool tcp_prune_ofo_queue(struct sock *sk)
 		tcp_drop_reason(sk, rb_to_skb(node),
 				SKB_DROP_REASON_TCP_OFO_QUEUE_PRUNE);
 		if (!prev || goal <= 0) {
-			sk_mem_reclaim(sk);
 			if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf &&
 			    !tcp_under_memory_pressure(sk))
 				break;
@@ -5383,7 +5380,6 @@ static int tcp_prune_queue(struct sock *sk)
 			     skb_peek(&sk->sk_receive_queue),
 			     NULL,
 			     tp->copied_seq, tp->rcv_nxt);
-	sk_mem_reclaim(sk);
 
 	if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf)
 		return 0;
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index fe8f23b95d32c..fda811a5251f2 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -3045,7 +3045,10 @@ struct proto tcp_prot = {
 	.stream_memory_free	= tcp_stream_memory_free,
 	.sockets_allocated	= &tcp_sockets_allocated,
 	.orphan_count		= &tcp_orphan_count,
+
 	.memory_allocated	= &tcp_memory_allocated,
+	.per_cpu_fw_alloc	= &tcp_memory_per_cpu_fw_alloc,
+
 	.memory_pressure	= &tcp_memory_pressure,
 	.sysctl_mem		= sysctl_tcp_mem,
 	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_tcp_wmem),
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 1c054431e3583..8ab98e1aca679 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -3367,7 +3367,7 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
 	if (size <= sk->sk_forward_alloc)
 		return;
 	amt = sk_mem_pages(size);
-	sk->sk_forward_alloc += amt * SK_MEM_QUANTUM;
+	sk->sk_forward_alloc += amt << PAGE_SHIFT;
 	sk_memory_allocated_add(sk, amt);
 
 	if (mem_cgroup_sockets_enabled && sk->sk_memcg)
diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c
index 20cf4a98c69d8..2208755e8efc6 100644
--- a/net/ipv4/tcp_timer.c
+++ b/net/ipv4/tcp_timer.c
@@ -290,15 +290,13 @@ void tcp_delack_timer_handler(struct sock *sk)
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);
 
-	sk_mem_reclaim_partial(sk);
-
 	if (((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) ||
 	    !(icsk->icsk_ack.pending & ICSK_ACK_TIMER))
-		goto out;
+		return;
 
 	if (time_after(icsk->icsk_ack.timeout, jiffies)) {
 		sk_reset_timer(sk, &icsk->icsk_delack_timer, icsk->icsk_ack.timeout);
-		goto out;
+		return;
 	}
 	icsk->icsk_ack.pending &= ~ICSK_ACK_TIMER;
 
@@ -317,10 +315,6 @@ void tcp_delack_timer_handler(struct sock *sk)
 		tcp_send_ack(sk);
 		__NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKS);
 	}
-
-out:
-	if (tcp_under_memory_pressure(sk))
-		sk_mem_reclaim(sk);
 }
 
 
@@ -600,11 +594,11 @@ void tcp_write_timer_handler(struct sock *sk)
 
 	if (((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) ||
 	    !icsk->icsk_pending)
-		goto out;
+		return;
 
 	if (time_after(icsk->icsk_timeout, jiffies)) {
 		sk_reset_timer(sk, &icsk->icsk_retransmit_timer, icsk->icsk_timeout);
-		goto out;
+		return;
 	}
 
 	tcp_mstamp_refresh(tcp_sk(sk));
@@ -626,9 +620,6 @@ void tcp_write_timer_handler(struct sock *sk)
 		tcp_probe_timer(sk);
 		break;
 	}
-
-out:
-	sk_mem_reclaim(sk);
 }
 
 static void tcp_write_timer(struct timer_list *t)
@@ -743,8 +734,6 @@ static void tcp_keepalive_timer (struct timer_list *t)
 		elapsed = keepalive_time_when(tp) - elapsed;
 	}
 
-	sk_mem_reclaim(sk);
-
 resched:
 	inet_csk_reset_keepalive_timer (sk, elapsed);
 	goto out;
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index aa9f2ec3dc468..6172b4750a888 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -125,6 +125,8 @@ EXPORT_SYMBOL(sysctl_udp_mem);
 
 atomic_long_t udp_memory_allocated ____cacheline_aligned_in_smp;
 EXPORT_SYMBOL(udp_memory_allocated);
+DEFINE_PER_CPU(int, udp_memory_per_cpu_fw_alloc);
+EXPORT_PER_CPU_SYMBOL_GPL(udp_memory_per_cpu_fw_alloc);
 
 #define MAX_UDP_PORTS 65536
 #define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN)
@@ -1461,11 +1463,11 @@ static void udp_rmem_release(struct sock *sk, int size, int partial,
 
 
 	sk->sk_forward_alloc += size;
-	amt = (sk->sk_forward_alloc - partial) & ~(SK_MEM_QUANTUM - 1);
+	amt = (sk->sk_forward_alloc - partial) & ~(PAGE_SIZE - 1);
 	sk->sk_forward_alloc -= amt;
 
 	if (amt)
-		__sk_mem_reduce_allocated(sk, amt >> SK_MEM_QUANTUM_SHIFT);
+		__sk_mem_reduce_allocated(sk, amt >> PAGE_SHIFT);
 
 	atomic_sub(size, &sk->sk_rmem_alloc);
 
@@ -1558,7 +1560,7 @@ int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb)
 	spin_lock(&list->lock);
 	if (size >= sk->sk_forward_alloc) {
 		amt = sk_mem_pages(size);
-		delta = amt << SK_MEM_QUANTUM_SHIFT;
+		delta = amt << PAGE_SHIFT;
 		if (!__sk_mem_raise_allocated(sk, delta, amt, SK_MEM_RECV)) {
 			err = -ENOBUFS;
 			spin_unlock(&list->lock);
@@ -2946,6 +2948,8 @@ struct proto udp_prot = {
 	.psock_update_sk_prot	= udp_bpf_update_proto,
 #endif
 	.memory_allocated	= &udp_memory_allocated,
+	.per_cpu_fw_alloc	= &udp_memory_per_cpu_fw_alloc,
+
 	.sysctl_mem		= sysctl_udp_mem,
 	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_udp_wmem_min),
 	.sysctl_rmem_offset	= offsetof(struct net, ipv4.sysctl_udp_rmem_min),
@@ -3263,8 +3267,8 @@ EXPORT_SYMBOL(udp_flow_hashrnd);
 
 static void __udp_sysctl_init(struct net *net)
 {
-	net->ipv4.sysctl_udp_rmem_min = SK_MEM_QUANTUM;
-	net->ipv4.sysctl_udp_wmem_min = SK_MEM_QUANTUM;
+	net->ipv4.sysctl_udp_rmem_min = PAGE_SIZE;
+	net->ipv4.sysctl_udp_wmem_min = PAGE_SIZE;
 
 #ifdef CONFIG_NET_L3_MASTER_DEV
 	net->ipv4.sysctl_udp_l3mdev_accept = 0;
diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c
index cd1cd68adeec8..6e08a76ae1e7e 100644
--- a/net/ipv4/udplite.c
+++ b/net/ipv4/udplite.c
@@ -51,7 +51,10 @@ struct proto 	udplite_prot = {
 	.unhash		   = udp_lib_unhash,
 	.rehash		   = udp_v4_rehash,
 	.get_port	   = udp_v4_get_port,
+
 	.memory_allocated  = &udp_memory_allocated,
+	.per_cpu_fw_alloc  = &udp_memory_per_cpu_fw_alloc,
+
 	.sysctl_mem	   = sysctl_udp_mem,
 	.obj_size	   = sizeof(struct udp_sock),
 	.h.udp_table	   = &udplite_table,
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index f37dd4aa91c6b..c72448ba6dc9c 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -2159,7 +2159,10 @@ struct proto tcpv6_prot = {
 	.leave_memory_pressure	= tcp_leave_memory_pressure,
 	.stream_memory_free	= tcp_stream_memory_free,
 	.sockets_allocated	= &tcp_sockets_allocated,
+
 	.memory_allocated	= &tcp_memory_allocated,
+	.per_cpu_fw_alloc	= &tcp_memory_per_cpu_fw_alloc,
+
 	.memory_pressure	= &tcp_memory_pressure,
 	.orphan_count		= &tcp_orphan_count,
 	.sysctl_mem		= sysctl_tcp_mem,
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 55afd7f39c045..be074f07073a5 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -1740,7 +1740,10 @@ struct proto udpv6_prot = {
 #ifdef CONFIG_BPF_SYSCALL
 	.psock_update_sk_prot	= udp_bpf_update_proto,
 #endif
+
 	.memory_allocated	= &udp_memory_allocated,
+	.per_cpu_fw_alloc	= &udp_memory_per_cpu_fw_alloc,
+
 	.sysctl_mem		= sysctl_udp_mem,
 	.sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
 	.sysctl_rmem_offset     = offsetof(struct net, ipv4.sysctl_udp_rmem_min),
diff --git a/net/ipv6/udplite.c b/net/ipv6/udplite.c
index fbb700d3f437e..b707258562597 100644
--- a/net/ipv6/udplite.c
+++ b/net/ipv6/udplite.c
@@ -48,7 +48,10 @@ struct proto udplitev6_prot = {
 	.unhash		   = udp_lib_unhash,
 	.rehash		   = udp_v6_rehash,
 	.get_port	   = udp_v6_get_port,
+
 	.memory_allocated  = &udp_memory_allocated,
+	.per_cpu_fw_alloc  = &udp_memory_per_cpu_fw_alloc,
+
 	.sysctl_mem	   = sysctl_udp_mem,
 	.obj_size	   = sizeof(struct udp6_sock),
 	.h.udp_table	   = &udplite_table,
diff --git a/net/iucv/af_iucv.c b/net/iucv/af_iucv.c
index a0385ddbffcfc..498a0c35b7bb2 100644
--- a/net/iucv/af_iucv.c
+++ b/net/iucv/af_iucv.c
@@ -278,8 +278,6 @@ static void iucv_sock_destruct(struct sock *sk)
 	skb_queue_purge(&sk->sk_receive_queue);
 	skb_queue_purge(&sk->sk_error_queue);
 
-	sk_mem_reclaim(sk);
-
 	if (!sock_flag(sk, SOCK_DEAD)) {
 		pr_err("Attempt to release alive iucv socket %p\n", sk);
 		return;
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 17e13396024ad..e0fb9f96c45c8 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -167,8 +167,8 @@ static bool mptcp_ooo_try_coalesce(struct mptcp_sock *msk, struct sk_buff *to,
 
 static void __mptcp_rmem_reclaim(struct sock *sk, int amount)
 {
-	amount >>= SK_MEM_QUANTUM_SHIFT;
-	mptcp_sk(sk)->rmem_fwd_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
+	amount >>= PAGE_SHIFT;
+	mptcp_sk(sk)->rmem_fwd_alloc -= amount << PAGE_SHIFT;
 	__sk_mem_reduce_allocated(sk, amount);
 }
 
@@ -327,7 +327,7 @@ static bool mptcp_rmem_schedule(struct sock *sk, struct sock *ssk, int size)
 		return true;
 
 	amt = sk_mem_pages(size);
-	amount = amt << SK_MEM_QUANTUM_SHIFT;
+	amount = amt << PAGE_SHIFT;
 	msk->rmem_fwd_alloc += amount;
 	if (!__sk_mem_raise_allocated(sk, size, amt, SK_MEM_RECV)) {
 		if (ssk->sk_forward_alloc < amount) {
@@ -972,10 +972,10 @@ static void __mptcp_mem_reclaim_partial(struct sock *sk)
 
 	lockdep_assert_held_once(&sk->sk_lock.slock);
 
-	if (reclaimable > SK_MEM_QUANTUM)
+	if (reclaimable > (int)PAGE_SIZE)
 		__mptcp_rmem_reclaim(sk, reclaimable - 1);
 
-	sk_mem_reclaim_partial(sk);
+	sk_mem_reclaim(sk);
 }
 
 static void mptcp_mem_reclaim_partial(struct sock *sk)
@@ -3437,7 +3437,10 @@ static struct proto mptcp_prot = {
 	.get_port	= mptcp_get_port,
 	.forward_alloc_get	= mptcp_forward_alloc_get,
 	.sockets_allocated	= &mptcp_sockets_allocated,
+
 	.memory_allocated	= &tcp_memory_allocated,
+	.per_cpu_fw_alloc	= &tcp_memory_per_cpu_fw_alloc,
+
 	.memory_pressure	= &tcp_memory_pressure,
 	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_tcp_wmem),
 	.sysctl_rmem_offset	= offsetof(struct net, ipv4.sysctl_tcp_rmem),
diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c
index 35928fefae332..fa500ea3a1f1b 100644
--- a/net/sctp/protocol.c
+++ b/net/sctp/protocol.c
@@ -1523,11 +1523,11 @@ static __init int sctp_init(void)
 	limit = (sysctl_sctp_mem[1]) << (PAGE_SHIFT - 7);
 	max_share = min(4UL*1024*1024, limit);
 
-	sysctl_sctp_rmem[0] = SK_MEM_QUANTUM; /* give each asoc 1 page min */
+	sysctl_sctp_rmem[0] = PAGE_SIZE; /* give each asoc 1 page min */
 	sysctl_sctp_rmem[1] = 1500 * SKB_TRUESIZE(1);
 	sysctl_sctp_rmem[2] = max(sysctl_sctp_rmem[1], max_share);
 
-	sysctl_sctp_wmem[0] = SK_MEM_QUANTUM;
+	sysctl_sctp_wmem[0] = PAGE_SIZE;
 	sysctl_sctp_wmem[1] = 16*1024;
 	sysctl_sctp_wmem[2] = max(64*1024, max_share);
 
diff --git a/net/sctp/sm_statefuns.c b/net/sctp/sm_statefuns.c
index 52edee1322fc3..f6ee7f4040c14 100644
--- a/net/sctp/sm_statefuns.c
+++ b/net/sctp/sm_statefuns.c
@@ -6590,8 +6590,6 @@ static int sctp_eat_data(const struct sctp_association *asoc,
 			pr_debug("%s: under pressure, reneging for tsn:%u\n",
 				 __func__, tsn);
 			deliver = SCTP_CMD_RENEGE;
-		} else {
-			sk_mem_reclaim(sk);
 		}
 	}
 
diff --git a/net/sctp/socket.c b/net/sctp/socket.c
index 6d37d2dfb3da8..171f1a35d2052 100644
--- a/net/sctp/socket.c
+++ b/net/sctp/socket.c
@@ -93,6 +93,7 @@ static int sctp_sock_migrate(struct sock *oldsk, struct sock *newsk,
 
 static unsigned long sctp_memory_pressure;
 static atomic_long_t sctp_memory_allocated;
+static DEFINE_PER_CPU(int, sctp_memory_per_cpu_fw_alloc);
 struct percpu_counter sctp_sockets_allocated;
 
 static void sctp_enter_memory_pressure(struct sock *sk)
@@ -1823,9 +1824,6 @@ static int sctp_sendmsg_to_asoc(struct sctp_association *asoc,
 	if (sctp_wspace(asoc) < (int)msg_len)
 		sctp_prsctp_prune(asoc, sinfo, msg_len - sctp_wspace(asoc));
 
-	if (sk_under_memory_pressure(sk))
-		sk_mem_reclaim(sk);
-
 	if (sctp_wspace(asoc) <= 0 || !sk_wmem_schedule(sk, msg_len)) {
 		timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 		err = sctp_wait_for_sndbuf(asoc, &timeo, msg_len);
@@ -9194,8 +9192,6 @@ static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p,
 			goto do_error;
 		if (signal_pending(current))
 			goto do_interrupted;
-		if (sk_under_memory_pressure(sk))
-			sk_mem_reclaim(sk);
 		if ((int)msg_len <= sctp_wspace(asoc) &&
 		    sk_wmem_schedule(sk, msg_len))
 			break;
@@ -9657,7 +9653,10 @@ struct proto sctp_prot = {
 	.sysctl_wmem =  sysctl_sctp_wmem,
 	.memory_pressure = &sctp_memory_pressure,
 	.enter_memory_pressure = sctp_enter_memory_pressure,
+
 	.memory_allocated = &sctp_memory_allocated,
+	.per_cpu_fw_alloc = &sctp_memory_per_cpu_fw_alloc,
+
 	.sockets_allocated = &sctp_sockets_allocated,
 };
 
@@ -9700,7 +9699,10 @@ struct proto sctpv6_prot = {
 	.sysctl_wmem	= sysctl_sctp_wmem,
 	.memory_pressure = &sctp_memory_pressure,
 	.enter_memory_pressure = sctp_enter_memory_pressure,
+
 	.memory_allocated = &sctp_memory_allocated,
+	.per_cpu_fw_alloc = &sctp_memory_per_cpu_fw_alloc,
+
 	.sockets_allocated = &sctp_sockets_allocated,
 };
 #endif /* IS_ENABLED(CONFIG_IPV6) */
diff --git a/net/sctp/stream_interleave.c b/net/sctp/stream_interleave.c
index 6b13f737ebf2e..bb22b71df7a34 100644
--- a/net/sctp/stream_interleave.c
+++ b/net/sctp/stream_interleave.c
@@ -979,8 +979,6 @@ static void sctp_renege_events(struct sctp_ulpq *ulpq, struct sctp_chunk *chunk,
 
 	if (freed >= needed && sctp_ulpevent_idata(ulpq, chunk, gfp) <= 0)
 		sctp_intl_start_pd(ulpq, gfp);
-
-	sk_mem_reclaim(asoc->base.sk);
 }
 
 static void sctp_intl_stream_abort_pd(struct sctp_ulpq *ulpq, __u16 sid,
diff --git a/net/sctp/ulpqueue.c b/net/sctp/ulpqueue.c
index 407fed46931b9..0a8510a0c5e69 100644
--- a/net/sctp/ulpqueue.c
+++ b/net/sctp/ulpqueue.c
@@ -1100,12 +1100,8 @@ void sctp_ulpq_renege(struct sctp_ulpq *ulpq, struct sctp_chunk *chunk,
 		else if (retval == 1)
 			sctp_ulpq_reasm_drain(ulpq);
 	}
-
-	sk_mem_reclaim(asoc->base.sk);
 }
 
-
-
 /* Notify the application if an association is aborted and in
  * partial delivery mode.  Send up any pending received messages.
  */