diff --git a/net/ipv4/ip_gre.c b/net/ipv4/ip_gre.c
index f502d34bcb408..205a2b8a5a845 100644
--- a/net/ipv4/ip_gre.c
+++ b/net/ipv4/ip_gre.c
@@ -179,6 +179,7 @@ static __be16 tnl_flags_to_gre_flags(__be16 tflags)
 	return flags;
 }
 
+/* Fills in tpi and returns header length to be pulled. */
 static int parse_gre_header(struct sk_buff *skb, struct tnl_ptk_info *tpi,
 			    bool *csum_err)
 {
@@ -238,7 +239,7 @@ static int parse_gre_header(struct sk_buff *skb, struct tnl_ptk_info *tpi,
 				return -EINVAL;
 		}
 	}
-	return iptunnel_pull_header(skb, hdr_len, tpi->proto, false);
+	return hdr_len;
 }
 
 static void ipgre_err(struct sk_buff *skb, u32 info,
@@ -341,7 +342,7 @@ static void gre_err(struct sk_buff *skb, u32 info)
 	struct tnl_ptk_info tpi;
 	bool csum_err = false;
 
-	if (parse_gre_header(skb, &tpi, &csum_err)) {
+	if (parse_gre_header(skb, &tpi, &csum_err) < 0) {
 		if (!csum_err)		/* ignore csum errors. */
 			return;
 	}
@@ -419,6 +420,7 @@ static int gre_rcv(struct sk_buff *skb)
 {
 	struct tnl_ptk_info tpi;
 	bool csum_err = false;
+	int hdr_len;
 
 #ifdef CONFIG_NET_IPGRE_BROADCAST
 	if (ipv4_is_multicast(ip_hdr(skb)->daddr)) {
@@ -428,7 +430,10 @@ static int gre_rcv(struct sk_buff *skb)
 	}
 #endif
 
-	if (parse_gre_header(skb, &tpi, &csum_err) < 0)
+	hdr_len = parse_gre_header(skb, &tpi, &csum_err);
+	if (hdr_len < 0)
+		goto drop;
+	if (iptunnel_pull_header(skb, hdr_len, tpi.proto, false) < 0)
 		goto drop;
 
 	if (ipgre_rcv(skb, &tpi) == PACKET_RCVD)