diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index 3c8e9d8795cdb..81ac5a970e433 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -1270,6 +1270,7 @@ static int __ip6_append_data(struct sock *sk,
 	struct rt6_info *rt = (struct rt6_info *)cork->dst;
 	struct ipv6_txoptions *opt = v6_cork->opt;
 	int csummode = CHECKSUM_NONE;
+	unsigned int maxnonfragsize, headersize;
 
 	skb = skb_peek_tail(queue);
 	if (!skb) {
@@ -1287,38 +1288,43 @@ static int __ip6_append_data(struct sock *sk,
 	maxfraglen = ((mtu - fragheaderlen) & ~7) + fragheaderlen -
 		     sizeof(struct frag_hdr);
 
-	if (mtu <= sizeof(struct ipv6hdr) + IPV6_MAXPLEN) {
-		unsigned int maxnonfragsize, headersize;
-
-		headersize = sizeof(struct ipv6hdr) +
-			     (opt ? opt->opt_flen + opt->opt_nflen : 0) +
-			     (dst_allfrag(&rt->dst) ?
-			      sizeof(struct frag_hdr) : 0) +
-			     rt->rt6i_nfheader_len;
-
-		if (ip6_sk_ignore_df(sk))
-			maxnonfragsize = sizeof(struct ipv6hdr) + IPV6_MAXPLEN;
-		else
-			maxnonfragsize = mtu;
+	headersize = sizeof(struct ipv6hdr) +
+		     (opt ? opt->opt_flen + opt->opt_nflen : 0) +
+		     (dst_allfrag(&rt->dst) ?
+		      sizeof(struct frag_hdr) : 0) +
+		     rt->rt6i_nfheader_len;
+
+	if (cork->length + length > mtu - headersize && dontfrag &&
+	    (sk->sk_protocol == IPPROTO_UDP ||
+	     sk->sk_protocol == IPPROTO_RAW)) {
+		ipv6_local_rxpmtu(sk, fl6, mtu - headersize +
+				sizeof(struct ipv6hdr));
+		goto emsgsize;
+	}
 
-		/* dontfrag active */
-		if ((cork->length + length > mtu - headersize) && dontfrag &&
-		    (sk->sk_protocol == IPPROTO_UDP ||
-		     sk->sk_protocol == IPPROTO_RAW)) {
-			ipv6_local_rxpmtu(sk, fl6, mtu - headersize +
-						   sizeof(struct ipv6hdr));
-			goto emsgsize;
-		}
+	if (ip6_sk_ignore_df(sk))
+		maxnonfragsize = sizeof(struct ipv6hdr) + IPV6_MAXPLEN;
+	else
+		maxnonfragsize = mtu;
 
-		if (cork->length + length > maxnonfragsize - headersize) {
+	if (cork->length + length > maxnonfragsize - headersize) {
 emsgsize:
-			ipv6_local_error(sk, EMSGSIZE, fl6,
-					 mtu - headersize +
-					 sizeof(struct ipv6hdr));
-			return -EMSGSIZE;
-		}
+		ipv6_local_error(sk, EMSGSIZE, fl6,
+				 mtu - headersize +
+				 sizeof(struct ipv6hdr));
+		return -EMSGSIZE;
 	}
 
+	/* CHECKSUM_PARTIAL only with no extension headers and when
+	 * we are not going to fragment
+	 */
+	if (transhdrlen && sk->sk_protocol == IPPROTO_UDP &&
+	    headersize == sizeof(struct ipv6hdr) &&
+	    length < mtu - headersize &&
+	    !(flags & MSG_MORE) &&
+	    rt->dst.dev->features & NETIF_F_V6_CSUM)
+		csummode = CHECKSUM_PARTIAL;
+
 	if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_RAW) {
 		sock_tx_timestamp(sk, &tx_flags);
 		if (tx_flags & SKBTX_ANY_SW_TSTAMP &&
@@ -1326,16 +1332,6 @@ static int __ip6_append_data(struct sock *sk,
 			tskey = sk->sk_tskey++;
 	}
 
-	/* If this is the first and only packet and device
-	 * supports checksum offloading, let's use it.
-	 * Use transhdrlen, same as IPv4, because partial
-	 * sums only work when transhdrlen is set.
-	 */
-	if (transhdrlen && sk->sk_protocol == IPPROTO_UDP &&
-	    length + fragheaderlen < mtu &&
-	    rt->dst.dev->features & NETIF_F_V6_CSUM &&
-	    !exthdrlen)
-		csummode = CHECKSUM_PARTIAL;
 	/*
 	 * Let's try using as much space as possible.
 	 * Use MTU if total length of the message fits into the MTU.