diff --git a/drivers/net/phy/phy_device.c b/drivers/net/phy/phy_device.c
index 0d8f4d3847f66..8c8e15b8739de 100644
--- a/drivers/net/phy/phy_device.c
+++ b/drivers/net/phy/phy_device.c
@@ -908,6 +908,7 @@ int phy_attach_direct(struct net_device *dev, struct phy_device *phydev,
 	struct module *ndev_owner = dev->dev.parent->driver->owner;
 	struct mii_bus *bus = phydev->mdio.bus;
 	struct device *d = &phydev->mdio.dev;
+	bool using_genphy = false;
 	int err;
 
 	/* For Ethernet device drivers that register their own MDIO bus, we
@@ -920,11 +921,6 @@ int phy_attach_direct(struct net_device *dev, struct phy_device *phydev,
 		return -EIO;
 	}
 
-	if (!try_module_get(d->driver->owner)) {
-		dev_err(&dev->dev, "failed to get the device driver module\n");
-		return -EIO;
-	}
-
 	get_device(d);
 
 	/* Assume that if there is no driver, that it doesn't
@@ -938,12 +934,22 @@ int phy_attach_direct(struct net_device *dev, struct phy_device *phydev,
 			d->driver =
 				&genphy_driver[GENPHY_DRV_1G].mdiodrv.driver;
 
+		using_genphy = true;
+	}
+
+	if (!try_module_get(d->driver->owner)) {
+		dev_err(&dev->dev, "failed to get the device driver module\n");
+		err = -EIO;
+		goto error_put_device;
+	}
+
+	if (using_genphy) {
 		err = d->driver->probe(d);
 		if (err >= 0)
 			err = device_bind_driver(d);
 
 		if (err)
-			goto error;
+			goto error_module_put;
 	}
 
 	if (phydev->attached_dev) {
@@ -980,9 +986,14 @@ int phy_attach_direct(struct net_device *dev, struct phy_device *phydev,
 	return err;
 
 error:
+	/* phy_detach() does all of the cleanup below */
 	phy_detach(phydev);
-	put_device(d);
+	return err;
+
+error_module_put:
 	module_put(d->driver->owner);
+error_put_device:
+	put_device(d);
 	if (ndev_owner != bus->owner)
 		module_put(bus->owner);
 	return err;
@@ -1045,6 +1056,8 @@ void phy_detach(struct phy_device *phydev)
 
 	phy_led_triggers_unregister(phydev);
 
+	module_put(phydev->mdio.dev.driver->owner);
+
 	/* If the device had no specific driver before (i.e. - it
 	 * was using the generic driver), we unbind the device
 	 * from the generic driver so that there's a chance a
@@ -1065,7 +1078,6 @@ void phy_detach(struct phy_device *phydev)
 	bus = phydev->mdio.bus;
 
 	put_device(&phydev->mdio.dev);
-	module_put(phydev->mdio.dev.driver->owner);
 	if (ndev_owner != bus->owner)
 		module_put(bus->owner);
 }