diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 0e13e566581c2..9188eae61e929 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -1718,19 +1718,6 @@ struct iommu_group *fsl_mc_device_group(struct device *dev)
 }
 EXPORT_SYMBOL_GPL(fsl_mc_device_group);
 
-static int iommu_get_def_domain_type(struct device *dev)
-{
-	const struct iommu_ops *ops = dev_iommu_ops(dev);
-
-	if (dev_is_pci(dev) && to_pci_dev(dev)->untrusted)
-		return IOMMU_DOMAIN_DMA;
-
-	if (ops->def_domain_type)
-		return ops->def_domain_type(dev);
-
-	return 0;
-}
-
 static struct iommu_domain *
 __iommu_group_alloc_default_domain(const struct bus_type *bus,
 				   struct iommu_group *group, int req_type)
@@ -1740,6 +1727,23 @@ __iommu_group_alloc_default_domain(const struct bus_type *bus,
 	return __iommu_domain_alloc(bus, req_type);
 }
 
+/*
+ * Returns the iommu_ops for the devices in an iommu group.
+ *
+ * It is assumed that all devices in an iommu group are managed by a single
+ * IOMMU unit. Therefore, this returns the dev_iommu_ops of the first device
+ * in the group.
+ */
+static const struct iommu_ops *group_iommu_ops(struct iommu_group *group)
+{
+	struct group_device *device =
+		list_first_entry(&group->devices, struct group_device, list);
+
+	lockdep_assert_held(&group->mutex);
+
+	return dev_iommu_ops(device->dev);
+}
+
 /*
  * req_type of 0 means "auto" which means to select a domain based on
  * iommu_def_domain_type or what the driver actually supports.
@@ -1820,40 +1824,77 @@ static int iommu_bus_notifier(struct notifier_block *nb,
 	return 0;
 }
 
-/* A target_type of 0 will select the best domain type and cannot fail */
+/*
+ * Combine the driver's chosen def_domain_type across all the devices in a
+ * group. Drivers must give a consistent result.
+ */
+static int iommu_get_def_domain_type(struct iommu_group *group,
+				     struct device *dev, int cur_type)
+{
+	const struct iommu_ops *ops = group_iommu_ops(group);
+	int type;
+
+	if (!ops->def_domain_type)
+		return cur_type;
+
+	type = ops->def_domain_type(dev);
+	if (!type || cur_type == type)
+		return cur_type;
+	if (!cur_type)
+		return type;
+
+	dev_err_ratelimited(
+		dev,
+		"IOMMU driver error, requesting conflicting def_domain_type, %s and %s, for devices in group %u.\n",
+		iommu_domain_type_str(cur_type), iommu_domain_type_str(type),
+		group->id);
+
+	/*
+	 * Try to recover, drivers are allowed to force IDENITY or DMA, IDENTITY
+	 * takes precedence.
+	 */
+	if (type == IOMMU_DOMAIN_IDENTITY)
+		return type;
+	return cur_type;
+}
+
+/*
+ * A target_type of 0 will select the best domain type. 0 can be returned in
+ * this case meaning the global default should be used.
+ */
 static int iommu_get_default_domain_type(struct iommu_group *group,
 					 int target_type)
 {
-	int best_type = target_type;
+	struct device *untrusted = NULL;
 	struct group_device *gdev;
-	struct device *last_dev;
+	int driver_type = 0;
 
 	lockdep_assert_held(&group->mutex);
-
 	for_each_group_device(group, gdev) {
-		unsigned int type = iommu_get_def_domain_type(gdev->dev);
-
-		if (best_type && type && best_type != type) {
-			if (target_type) {
-				dev_err_ratelimited(
-					gdev->dev,
-					"Device cannot be in %s domain\n",
-					iommu_domain_type_str(target_type));
-				return -1;
-			}
+		driver_type = iommu_get_def_domain_type(group, gdev->dev,
+							driver_type);
 
-			dev_warn(
-				gdev->dev,
-				"Device needs domain type %s, but device %s in the same iommu group requires type %s - using default\n",
-				iommu_domain_type_str(type), dev_name(last_dev),
-				iommu_domain_type_str(best_type));
-			return 0;
+		if (dev_is_pci(gdev->dev) && to_pci_dev(gdev->dev)->untrusted)
+			untrusted = gdev->dev;
+	}
+
+	if (untrusted) {
+		if (driver_type && driver_type != IOMMU_DOMAIN_DMA) {
+			dev_err_ratelimited(
+				untrusted,
+				"Device is not trusted, but driver is overriding group %u to %s, refusing to probe.\n",
+				group->id, iommu_domain_type_str(driver_type));
+			return -1;
 		}
-		if (!best_type)
-			best_type = type;
-		last_dev = gdev->dev;
+		driver_type = IOMMU_DOMAIN_DMA;
+	}
+
+	if (target_type) {
+		if (driver_type && target_type != driver_type)
+			return -1;
+		return target_type;
 	}
-	return best_type;
+	return driver_type;
 }
 
 static void iommu_group_do_probe_finalize(struct device *dev)