diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h
index f40a2167935f2..10a05062e4a0a 100644
--- a/net/mac80211/ieee80211_i.h
+++ b/net/mac80211/ieee80211_i.h
@@ -377,6 +377,7 @@ struct ieee80211_mgd_auth_data {
 	u8 key[WLAN_KEY_LEN_WEP104];
 	u8 key_len, key_idx;
 	bool done;
+	bool peer_confirmed;
 	bool timeout_started;
 
 	u16 sae_trans, sae_status;
diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c
index 1818dbc5622d6..d2bc8d57c87eb 100644
--- a/net/mac80211/mlme.c
+++ b/net/mac80211/mlme.c
@@ -2764,8 +2764,15 @@ static void ieee80211_auth_challenge(struct ieee80211_sub_if_data *sdata,
 static bool ieee80211_mark_sta_auth(struct ieee80211_sub_if_data *sdata,
 				    const u8 *bssid)
 {
+	struct ieee80211_if_managed *ifmgd = &sdata->u.mgd;
 	struct sta_info *sta;
 
+	sdata_info(sdata, "authenticated\n");
+	ifmgd->auth_data->done = true;
+	ifmgd->auth_data->timeout = jiffies + IEEE80211_AUTH_WAIT_ASSOC;
+	ifmgd->auth_data->timeout_started = true;
+	run_again(sdata, ifmgd->auth_data->timeout);
+
 	/* move station state to auth */
 	mutex_lock(&sdata->local->sta_mtx);
 	sta = sta_info_get(sdata, bssid);
@@ -2811,7 +2818,11 @@ static void ieee80211_rx_mgmt_auth(struct ieee80211_sub_if_data *sdata,
 	status_code = le16_to_cpu(mgmt->u.auth.status_code);
 
 	if (auth_alg != ifmgd->auth_data->algorithm ||
-	    auth_transaction != ifmgd->auth_data->expected_transaction) {
+	    (auth_alg != WLAN_AUTH_SAE &&
+	     auth_transaction != ifmgd->auth_data->expected_transaction) ||
+	    (auth_alg == WLAN_AUTH_SAE &&
+	     (auth_transaction < ifmgd->auth_data->expected_transaction ||
+	      auth_transaction > 2))) {
 		sdata_info(sdata, "%pM unexpected authentication state: alg %d (expected %d) transact %d (expected %d)\n",
 			   mgmt->sa, auth_alg, ifmgd->auth_data->algorithm,
 			   auth_transaction,
@@ -2854,25 +2865,17 @@ static void ieee80211_rx_mgmt_auth(struct ieee80211_sub_if_data *sdata,
 
 	event.u.mlme.status = MLME_SUCCESS;
 	drv_event_callback(sdata->local, sdata, &event);
-	sdata_info(sdata, "authenticated\n");
-	ifmgd->auth_data->done = true;
-	ifmgd->auth_data->timeout = jiffies + IEEE80211_AUTH_WAIT_ASSOC;
-	ifmgd->auth_data->timeout_started = true;
-	run_again(sdata, ifmgd->auth_data->timeout);
-
-	if (ifmgd->auth_data->algorithm == WLAN_AUTH_SAE &&
-	    ifmgd->auth_data->expected_transaction != 2) {
-		/*
-		 * Report auth frame to user space for processing since another
-		 * round of Authentication frames is still needed.
-		 */
-		cfg80211_rx_mlme_mgmt(sdata->dev, (u8 *)mgmt, len);
-		return;
+	if (ifmgd->auth_data->algorithm != WLAN_AUTH_SAE ||
+	    (auth_transaction == 2 &&
+	     ifmgd->auth_data->expected_transaction == 2)) {
+		if (!ieee80211_mark_sta_auth(sdata, bssid))
+			goto out_err;
+	} else if (ifmgd->auth_data->algorithm == WLAN_AUTH_SAE &&
+		   auth_transaction == 2) {
+		sdata_info(sdata, "SAE peer confirmed\n");
+		ifmgd->auth_data->peer_confirmed = true;
 	}
 
-	if (!ieee80211_mark_sta_auth(sdata, bssid))
-		goto out_err;
-
 	cfg80211_rx_mlme_mgmt(sdata->dev, (u8 *)mgmt, len);
 	return;
  out_err:
@@ -4888,6 +4891,7 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 	struct ieee80211_mgd_auth_data *auth_data;
 	u16 auth_alg;
 	int err;
+	bool cont_auth;
 
 	/* prepare auth data structure */
 
@@ -4922,8 +4926,7 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 		return -EOPNOTSUPP;
 	}
 
-	if ((ifmgd->auth_data && !ifmgd->auth_data->done) ||
-	    ifmgd->assoc_data)
+	if (ifmgd->assoc_data)
 		return -EBUSY;
 
 	auth_data = kzalloc(sizeof(*auth_data) + req->auth_data_len +
@@ -4945,6 +4948,13 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 		auth_data->data_len += req->auth_data_len - 4;
 	}
 
+	/* Check if continuing authentication or trying to authenticate with the
+	 * same BSS that we were in the process of authenticating with and avoid
+	 * removal and re-addition of the STA entry in
+	 * ieee80211_prep_connection().
+	 */
+	cont_auth = ifmgd->auth_data && req->bss == ifmgd->auth_data->bss;
+
 	if (req->ie && req->ie_len) {
 		memcpy(&auth_data->data[auth_data->data_len],
 		       req->ie, req->ie_len);
@@ -4961,12 +4971,26 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 
 	/* try to authenticate/probe */
 
-	if (ifmgd->auth_data)
-		ieee80211_destroy_auth_data(sdata, false);
+	if (ifmgd->auth_data) {
+		if (cont_auth && req->auth_type == NL80211_AUTHTYPE_SAE) {
+			auth_data->peer_confirmed =
+				ifmgd->auth_data->peer_confirmed;
+		}
+		ieee80211_destroy_auth_data(sdata, cont_auth);
+	}
 
 	/* prep auth_data so we don't go into idle on disassoc */
 	ifmgd->auth_data = auth_data;
 
+	/* If this is continuation of an ongoing SAE authentication exchange
+	 * (i.e., request to send SAE Confirm) and the peer has already
+	 * confirmed, mark authentication completed since we are about to send
+	 * out SAE Confirm.
+	 */
+	if (cont_auth && req->auth_type == NL80211_AUTHTYPE_SAE &&
+	    auth_data->peer_confirmed && auth_data->sae_trans == 2)
+		ieee80211_mark_sta_auth(sdata, req->bss->bssid);
+
 	if (ifmgd->associated) {
 		u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN];
 
@@ -4984,7 +5008,7 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 
 	sdata_info(sdata, "authenticate with %pM\n", req->bss->bssid);
 
-	err = ieee80211_prep_connection(sdata, req->bss, false, false);
+	err = ieee80211_prep_connection(sdata, req->bss, cont_auth, false);
 	if (err)
 		goto err_clear;