diff --git a/drivers/gpu/drm/drm_dp_mst_topology.c b/drivers/gpu/drm/drm_dp_mst_topology.c
index 112972031a84f..ed0fea2ac3223 100644
--- a/drivers/gpu/drm/drm_dp_mst_topology.c
+++ b/drivers/gpu/drm/drm_dp_mst_topology.c
@@ -4830,41 +4830,102 @@ static bool drm_dp_mst_port_downstream_of_branch(struct drm_dp_mst_port *port,
 	return false;
 }
 
-static inline
-int drm_dp_mst_atomic_check_bw_limit(struct drm_dp_mst_branch *branch,
-				     struct drm_dp_mst_topology_state *mst_state)
+static int
+drm_dp_mst_atomic_check_port_bw_limit(struct drm_dp_mst_port *port,
+				      struct drm_dp_mst_topology_state *state);
+
+static int
+drm_dp_mst_atomic_check_mstb_bw_limit(struct drm_dp_mst_branch *mstb,
+				      struct drm_dp_mst_topology_state *state)
 {
-	struct drm_dp_mst_port *port;
 	struct drm_dp_vcpi_allocation *vcpi;
-	int pbn_limit = 0, pbn_used = 0;
+	struct drm_dp_mst_port *port;
+	int pbn_used = 0, ret;
+	bool found = false;
 
-	list_for_each_entry(port, &branch->ports, next) {
-		if (port->mstb)
-			if (drm_dp_mst_atomic_check_bw_limit(port->mstb, mst_state))
-				return -ENOSPC;
+	/* Check that we have at least one port in our state that's downstream
+	 * of this branch, otherwise we can skip this branch
+	 */
+	list_for_each_entry(vcpi, &state->vcpis, next) {
+		if (!vcpi->pbn ||
+		    !drm_dp_mst_port_downstream_of_branch(vcpi->port, mstb))
+			continue;
 
-		if (port->full_pbn > 0)
-			pbn_limit = port->full_pbn;
+		found = true;
+		break;
 	}
-	DRM_DEBUG_ATOMIC("[MST BRANCH:%p] branch has %d PBN available\n",
-			 branch, pbn_limit);
+	if (!found)
+		return 0;
 
-	list_for_each_entry(vcpi, &mst_state->vcpis, next) {
-		if (!vcpi->pbn)
-			continue;
+	if (mstb->port_parent)
+		DRM_DEBUG_ATOMIC("[MSTB:%p] [MST PORT:%p] Checking bandwidth limits on [MSTB:%p]\n",
+				 mstb->port_parent->parent, mstb->port_parent,
+				 mstb);
+	else
+		DRM_DEBUG_ATOMIC("[MSTB:%p] Checking bandwidth limits\n",
+				 mstb);
+
+	list_for_each_entry(port, &mstb->ports, next) {
+		ret = drm_dp_mst_atomic_check_port_bw_limit(port, state);
+		if (ret < 0)
+			return ret;
 
-		if (drm_dp_mst_port_downstream_of_branch(vcpi->port, branch))
-			pbn_used += vcpi->pbn;
+		pbn_used += ret;
 	}
-	DRM_DEBUG_ATOMIC("[MST BRANCH:%p] branch used %d PBN\n",
-			 branch, pbn_used);
 
-	if (pbn_used > pbn_limit) {
-		DRM_DEBUG_ATOMIC("[MST BRANCH:%p] No available bandwidth\n",
-				 branch);
+	return pbn_used;
+}
+
+static int
+drm_dp_mst_atomic_check_port_bw_limit(struct drm_dp_mst_port *port,
+				      struct drm_dp_mst_topology_state *state)
+{
+	struct drm_dp_vcpi_allocation *vcpi;
+	int pbn_used = 0;
+
+	if (port->pdt == DP_PEER_DEVICE_NONE)
+		return 0;
+
+	if (drm_dp_mst_is_end_device(port->pdt, port->mcs)) {
+		bool found = false;
+
+		list_for_each_entry(vcpi, &state->vcpis, next) {
+			if (vcpi->port != port)
+				continue;
+			if (!vcpi->pbn)
+				return 0;
+
+			found = true;
+			break;
+		}
+		if (!found)
+			return 0;
+
+		/* This should never happen, as it means we tried to
+		 * set a mode before querying the full_pbn
+		 */
+		if (WARN_ON(!port->full_pbn))
+			return -EINVAL;
+
+		pbn_used = vcpi->pbn;
+	} else {
+		pbn_used = drm_dp_mst_atomic_check_mstb_bw_limit(port->mstb,
+								 state);
+		if (pbn_used <= 0)
+			return pbn_used;
+	}
+
+	if (pbn_used > port->full_pbn) {
+		DRM_DEBUG_ATOMIC("[MSTB:%p] [MST PORT:%p] required PBN of %d exceeds port limit of %d\n",
+				 port->parent, port, pbn_used,
+				 port->full_pbn);
 		return -ENOSPC;
 	}
-	return 0;
+
+	DRM_DEBUG_ATOMIC("[MSTB:%p] [MST PORT:%p] uses %d out of %d PBN\n",
+			 port->parent, port, pbn_used, port->full_pbn);
+
+	return pbn_used;
 }
 
 static inline int
@@ -5062,9 +5123,15 @@ int drm_dp_mst_atomic_check(struct drm_atomic_state *state)
 		ret = drm_dp_mst_atomic_check_vcpi_alloc_limit(mgr, mst_state);
 		if (ret)
 			break;
-		ret = drm_dp_mst_atomic_check_bw_limit(mgr->mst_primary, mst_state);
-		if (ret)
+
+		mutex_lock(&mgr->lock);
+		ret = drm_dp_mst_atomic_check_mstb_bw_limit(mgr->mst_primary,
+							    mst_state);
+		mutex_unlock(&mgr->lock);
+		if (ret < 0)
 			break;
+		else
+			ret = 0;
 	}
 
 	return ret;