diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 6edef95fecf4b..bec451da7def4 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -123,6 +123,13 @@ struct mem_cgroup_stat_cpu {
 	unsigned long targets[MEM_CGROUP_NTARGETS];
 };
 
+struct mem_cgroup_reclaim_iter {
+	/* css_id of the last scanned hierarchy member */
+	int position;
+	/* scan generation, increased every round-trip */
+	unsigned int generation;
+};
+
 /*
  * per-zone information in memory controller.
  */
@@ -133,6 +140,8 @@ struct mem_cgroup_per_zone {
 	struct list_head	lists[NR_LRU_LISTS];
 	unsigned long		count[NR_LRU_LISTS];
 
+	struct mem_cgroup_reclaim_iter reclaim_iter[DEF_PRIORITY + 1];
+
 	struct zone_reclaim_stat reclaim_stat;
 	struct rb_node		tree_node;	/* RB tree node */
 	unsigned long long	usage_in_excess;/* Set to the value by which */
@@ -233,11 +242,6 @@ struct mem_cgroup {
 	 * per zone LRU lists.
 	 */
 	struct mem_cgroup_lru_info info;
-	/*
-	 * While reclaiming in a hierarchy, we cache the last child we
-	 * reclaimed from.
-	 */
-	int last_scanned_child;
 	int last_scanned_node;
 #if MAX_NUMNODES > 1
 	nodemask_t	scan_nodes;
@@ -853,9 +857,16 @@ struct mem_cgroup *try_get_mem_cgroup_from_mm(struct mm_struct *mm)
 	return memcg;
 }
 
-static struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
-					  struct mem_cgroup *prev,
-					  bool reclaim)
+struct mem_cgroup_reclaim_cookie {
+	struct zone *zone;
+	int priority;
+	unsigned int generation;
+};
+
+static struct mem_cgroup *
+mem_cgroup_iter(struct mem_cgroup *root,
+		struct mem_cgroup *prev,
+		struct mem_cgroup_reclaim_cookie *reclaim)
 {
 	struct mem_cgroup *memcg = NULL;
 	int id = 0;
@@ -876,10 +887,20 @@ static struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 	}
 
 	while (!memcg) {
+		struct mem_cgroup_reclaim_iter *uninitialized_var(iter);
 		struct cgroup_subsys_state *css;
 
-		if (reclaim)
-			id = root->last_scanned_child;
+		if (reclaim) {
+			int nid = zone_to_nid(reclaim->zone);
+			int zid = zone_idx(reclaim->zone);
+			struct mem_cgroup_per_zone *mz;
+
+			mz = mem_cgroup_zoneinfo(root, nid, zid);
+			iter = &mz->reclaim_iter[reclaim->priority];
+			if (prev && reclaim->generation != iter->generation)
+				return NULL;
+			id = iter->position;
+		}
 
 		rcu_read_lock();
 		css = css_get_next(&mem_cgroup_subsys, id + 1, &root->css, &id);
@@ -891,8 +912,13 @@ static struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 			id = 0;
 		rcu_read_unlock();
 
-		if (reclaim)
-			root->last_scanned_child = id;
+		if (reclaim) {
+			iter->position = id;
+			if (!css)
+				iter->generation++;
+			else if (!prev && memcg)
+				reclaim->generation = iter->generation;
+		}
 
 		if (prev && !css)
 			return NULL;
@@ -915,14 +941,14 @@ static void mem_cgroup_iter_break(struct mem_cgroup *root,
  * be used for reference counting.
  */
 #define for_each_mem_cgroup_tree(iter, root)		\
-	for (iter = mem_cgroup_iter(root, NULL, false);	\
+	for (iter = mem_cgroup_iter(root, NULL, NULL);	\
 	     iter != NULL;				\
-	     iter = mem_cgroup_iter(root, iter, false))
+	     iter = mem_cgroup_iter(root, iter, NULL))
 
 #define for_each_mem_cgroup(iter)			\
-	for (iter = mem_cgroup_iter(NULL, NULL, false);	\
+	for (iter = mem_cgroup_iter(NULL, NULL, NULL);	\
 	     iter != NULL;				\
-	     iter = mem_cgroup_iter(NULL, iter, false))
+	     iter = mem_cgroup_iter(NULL, iter, NULL))
 
 static inline bool mem_cgroup_is_root(struct mem_cgroup *memcg)
 {
@@ -1692,6 +1718,10 @@ static int mem_cgroup_hierarchical_reclaim(struct mem_cgroup *root_memcg,
 	bool check_soft = reclaim_options & MEM_CGROUP_RECLAIM_SOFT;
 	unsigned long excess;
 	unsigned long nr_scanned;
+	struct mem_cgroup_reclaim_cookie reclaim = {
+		.zone = zone,
+		.priority = 0,
+	};
 
 	excess = res_counter_soft_limit_excess(&root_memcg->res) >> PAGE_SHIFT;
 
@@ -1700,7 +1730,7 @@ static int mem_cgroup_hierarchical_reclaim(struct mem_cgroup *root_memcg,
 		noswap = true;
 
 	while (1) {
-		victim = mem_cgroup_iter(root_memcg, victim, true);
+		victim = mem_cgroup_iter(root_memcg, victim, &reclaim);
 		if (!victim) {
 			loop++;
 			/*
@@ -5028,7 +5058,6 @@ mem_cgroup_create(struct cgroup_subsys *ss, struct cgroup *cont)
 		res_counter_init(&memcg->res, NULL);
 		res_counter_init(&memcg->memsw, NULL);
 	}
-	memcg->last_scanned_child = 0;
 	memcg->last_scanned_node = MAX_NUMNODES;
 	INIT_LIST_HEAD(&memcg->oom_notify);