diff --git a/include/net/tcp.h b/include/net/tcp.h
index 349204130d84b..1aa9628ae608a 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -917,6 +917,8 @@ struct tcp_congestion_ops {
 	void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample);
 	/* suggest number of segments for each skb to transmit (optional) */
 	u32 (*tso_segs_goal)(struct sock *sk);
+	/* returns the multiplier used in tcp_sndbuf_expand (optional) */
+	u32 (*sndbuf_expand)(struct sock *sk);
 	/* get info for inet_diag (optional) */
 	size_t (*get_info)(struct sock *sk, u32 ext, int *attr,
 			   union tcp_cc_info *info);
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index d9ed4bb96f74a..13a2e70141f56 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -289,6 +289,7 @@ static bool tcp_ecn_rcv_ecn_echo(const struct tcp_sock *tp, const struct tcphdr
 static void tcp_sndbuf_expand(struct sock *sk)
 {
 	const struct tcp_sock *tp = tcp_sk(sk);
+	const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops;
 	int sndmem, per_mss;
 	u32 nr_segs;
 
@@ -309,7 +310,8 @@ static void tcp_sndbuf_expand(struct sock *sk)
 	 * Cubic needs 1.7 factor, rounded to 2 to include
 	 * extra cushion (application might react slowly to POLLOUT)
 	 */
-	sndmem = 2 * nr_segs * per_mss;
+	sndmem = ca_ops->sndbuf_expand ? ca_ops->sndbuf_expand(sk) : 2;
+	sndmem *= nr_segs * per_mss;
 
 	if (sk->sk_sndbuf < sndmem)
 		sk->sk_sndbuf = min(sndmem, sysctl_tcp_wmem[2]);