diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c
index 16985ca3248d7..624eed345b5d2 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c
@@ -724,9 +724,9 @@ static u32 mlx5e_get_fcs(const struct sk_buff *skb)
 	return __get_unaligned_cpu32(fcs_bytes);
 }
 
-static u8 get_ip_proto(struct sk_buff *skb, __be16 proto)
+static u8 get_ip_proto(struct sk_buff *skb, int network_depth, __be16 proto)
 {
-	void *ip_p = skb->data + sizeof(struct ethhdr);
+	void *ip_p = skb->data + network_depth;
 
 	return (proto == htons(ETH_P_IP)) ? ((struct iphdr *)ip_p)->protocol :
 					    ((struct ipv6hdr *)ip_p)->nexthdr;
@@ -755,7 +755,7 @@ static inline void mlx5e_handle_csum(struct net_device *netdev,
 		goto csum_unnecessary;
 
 	if (likely(is_last_ethertype_ip(skb, &network_depth, &proto))) {
-		if (unlikely(get_ip_proto(skb, proto) == IPPROTO_SCTP))
+		if (unlikely(get_ip_proto(skb, network_depth, proto) == IPPROTO_SCTP))
 			goto csum_unnecessary;
 
 		skb->ip_summed = CHECKSUM_COMPLETE;