diff --git a/drivers/mtd/maps/sa1100-flash.c b/drivers/mtd/maps/sa1100-flash.c
index 1920bcbc05d7a..50a1d434906a7 100644
--- a/drivers/mtd/maps/sa1100-flash.c
+++ b/drivers/mtd/maps/sa1100-flash.c
@@ -223,7 +223,7 @@ static int sa1100_probe_subdev(struct sa_subdev_info *subdev, struct resource *r
 	return ret;
 }
 
-static void sa1100_destroy(struct sa_info *info)
+static void sa1100_destroy(struct sa_info *info, struct flash_platform_data *plat)
 {
 	int i;
 
@@ -242,6 +242,9 @@ static void sa1100_destroy(struct sa_info *info)
 	for (i = info->num_subdev - 1; i >= 0; i--)
 		sa1100_destroy_subdev(&info->subdev[i]);
 	kfree(info);
+
+	if (plat->exit)
+		plat->exit();
 }
 
 static struct sa_info *__init
@@ -275,6 +278,12 @@ sa1100_setup_mtd(struct platform_device *pdev, struct flash_platform_data *plat)
 
 	memset(info, 0, size);
 
+	if (plat->init) {
+		ret = plat->init();
+		if (ret)
+			goto err;
+	}
+
 	/*
 	 * Claim and then map the memory regions.
 	 */
@@ -336,7 +345,7 @@ sa1100_setup_mtd(struct platform_device *pdev, struct flash_platform_data *plat)
 		return info;
 
  err:
-	sa1100_destroy(info);
+	sa1100_destroy(info, plat);
  out:
 	return ERR_PTR(ret);
 }
@@ -397,8 +406,11 @@ static int __init sa1100_mtd_probe(struct device *dev)
 static int __exit sa1100_mtd_remove(struct device *dev)
 {
 	struct sa_info *info = dev_get_drvdata(dev);
+	struct flash_platform_data *plat = dev->platform_data;
+
 	dev_set_drvdata(dev, NULL);
-	sa1100_destroy(info);
+	sa1100_destroy(info, plat);
+
 	return 0;
 }