diff --git a/include/linux/cgroup.h b/include/linux/cgroup.h
index 5c6018fef5aa6..c9fdf6f57913c 100644
--- a/include/linux/cgroup.h
+++ b/include/linux/cgroup.h
@@ -696,6 +696,7 @@ static inline void cgroup_path_from_kernfs_id(const union kernfs_node_id *id,
  */
 void cgroup_rstat_updated(struct cgroup *cgrp, int cpu);
 void cgroup_rstat_flush(struct cgroup *cgrp);
+void cgroup_rstat_flush_irqsafe(struct cgroup *cgrp);
 void cgroup_rstat_flush_hold(struct cgroup *cgrp);
 void cgroup_rstat_flush_release(void);
 
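For orientation, here is a minimal sketch (not part of the patch) of how the two flush entry points declared above are meant to be split between callers; both caller functions below are hypothetical names used purely for illustration.

#include <linux/cgroup.h>

/* Hypothetical process-context reader: sleeping is fine here, so the plain
 * flush is preferred.  It may drop the rstat lock and reschedule mid-walk.
 */
static void refresh_stats_from_file_read(struct cgroup *cgrp)
{
	cgroup_rstat_flush(cgrp);
	/* ... report the now up-to-date counters ... */
}

/* Hypothetical atomic-context caller (IRQ handler, under another spinlock,
 * etc.): the irqsafe variant never sleeps and preserves the caller's IRQ
 * state.
 */
static void refresh_stats_from_atomic(struct cgroup *cgrp)
{
	cgroup_rstat_flush_irqsafe(cgrp);
}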
diff --git a/kernel/cgroup/rstat.c b/kernel/cgroup/rstat.c
index d49bf92ac3d47..3386fb251a9ee 100644
--- a/kernel/cgroup/rstat.c
+++ b/kernel/cgroup/rstat.c
@@ -2,7 +2,7 @@
 
 #include <linux/sched/cputime.h>
 
-static DEFINE_MUTEX(cgroup_rstat_mutex);
+static DEFINE_SPINLOCK(cgroup_rstat_lock);
 static DEFINE_PER_CPU(raw_spinlock_t, cgroup_rstat_cpu_lock);
 
 static void cgroup_base_stat_flush(struct cgroup *cgrp, int cpu);
@@ -132,21 +132,31 @@ static struct cgroup *cgroup_rstat_cpu_pop_updated(struct cgroup *pos,
 }
 
 /* see cgroup_rstat_flush() */
-static void cgroup_rstat_flush_locked(struct cgroup *cgrp)
+static void cgroup_rstat_flush_locked(struct cgroup *cgrp, bool may_sleep)
+	__releases(&cgroup_rstat_lock) __acquires(&cgroup_rstat_lock)
 {
 	int cpu;
 
-	lockdep_assert_held(&cgroup_rstat_mutex);
+	lockdep_assert_held(&cgroup_rstat_lock);
 
 	for_each_possible_cpu(cpu) {
 		raw_spinlock_t *cpu_lock = per_cpu_ptr(&cgroup_rstat_cpu_lock,
 						       cpu);
 		struct cgroup *pos = NULL;
 
-		raw_spin_lock_irq(cpu_lock);
+		raw_spin_lock(cpu_lock);
 		while ((pos = cgroup_rstat_cpu_pop_updated(pos, cgrp, cpu)))
 			cgroup_base_stat_flush(pos, cpu);
-		raw_spin_unlock_irq(cpu_lock);
+		raw_spin_unlock(cpu_lock);
+
+		/* if @may_sleep, play nice and yield if necessary */
+		if (may_sleep && (need_resched() ||
+				  spin_needbreak(&cgroup_rstat_lock))) {
+			spin_unlock_irq(&cgroup_rstat_lock);
+			if (!cond_resched())
+				cpu_relax();
+			spin_lock_irq(&cgroup_rstat_lock);
+		}
 	}
 }
 
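The yield added above keeps a long flush from turning the new spinlock into a latency problem: after each per-CPU pass, the outer lock is dropped whenever someone needs the CPU or the lock. A standalone sketch of that pattern follows (not part of the patch; my_lock, process_one_unit and process_all_units are hypothetical names used only to isolate the idiom).

#include <linux/spinlock.h>
#include <linux/sched.h>

static DEFINE_SPINLOCK(my_lock);

/* Stand-in for the per-unit work: must run under my_lock and must not sleep. */
static void process_one_unit(int i)
{
}

static void process_all_units(int nr_units, bool may_sleep)
{
	int i;

	spin_lock_irq(&my_lock);
	for (i = 0; i < nr_units; i++) {
		process_one_unit(i);

		/* Between units, yield the lock and the CPU if anyone is
		 * waiting, mirroring the may_sleep path in the patch above.
		 */
		if (may_sleep && (need_resched() ||
				  spin_needbreak(&my_lock))) {
			spin_unlock_irq(&my_lock);
			if (!cond_resched())
				cpu_relax();	/* didn't reschedule; brief pause before retaking the lock */
			spin_lock_irq(&my_lock);
		}
	}
	spin_unlock_irq(&my_lock);
}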
@@ -160,12 +170,31 @@ static void cgroup_rstat_flush_locked(struct cgroup *cgrp)
  *
  * This also gets all cgroups in the subtree including @cgrp off the
  * ->updated_children lists.
+ *
+ * This function may block.
  */
 void cgroup_rstat_flush(struct cgroup *cgrp)
 {
-	mutex_lock(&cgroup_rstat_mutex);
-	cgroup_rstat_flush_locked(cgrp);
-	mutex_unlock(&cgroup_rstat_mutex);
+	might_sleep();
+
+	spin_lock_irq(&cgroup_rstat_lock);
+	cgroup_rstat_flush_locked(cgrp, true);
+	spin_unlock_irq(&cgroup_rstat_lock);
+}
+
+/**
+ * cgroup_rstat_flush_irqsafe - irqsafe version of cgroup_rstat_flush()
+ * @cgrp: target cgroup
+ *
+ * This function can be called from any context.
+ */
+void cgroup_rstat_flush_irqsafe(struct cgroup *cgrp)
+{
+	unsigned long flags;
+
+	spin_lock_irqsave(&cgroup_rstat_lock, flags);
+	cgroup_rstat_flush_locked(cgrp, false);
+	spin_unlock_irqrestore(&cgroup_rstat_lock, flags);
 }
 
 /**
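cgroup_rstat_flush_irqsafe() uses spin_lock_irqsave()/spin_unlock_irqrestore() rather than the _irq variants because its callers may already be running with interrupts disabled, and their IRQ state must be preserved across the flush. A hedged illustration (stats_update_tail is a hypothetical caller, not an existing kernel function):

#include <linux/cgroup.h>
#include <linux/irqflags.h>

/* Hypothetical atomic-context caller: IRQs may already be off here, so the
 * flush must not blindly re-enable them on unlock -- hence the irqsave /
 * irqrestore pair inside cgroup_rstat_flush_irqsafe().
 */
static void stats_update_tail(struct cgroup *cgrp)
{
	unsigned long flags;

	local_irq_save(flags);			/* simulate a caller with IRQs disabled */
	cgroup_rstat_flush_irqsafe(cgrp);	/* caller's IRQ state is preserved */
	local_irq_restore(flags);
}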
@@ -174,21 +203,24 @@ void cgroup_rstat_flush(struct cgroup *cgrp)
  *
  * Flush stats in @cgrp's subtree and prevent further flushes.  Must be
  * paired with cgroup_rstat_flush_release().
+ *
+ * This function may block.
  */
 void cgroup_rstat_flush_hold(struct cgroup *cgrp)
-	__acquires(&cgroup_rstat_mutex)
+	__acquires(&cgroup_rstat_lock)
 {
-	mutex_lock(&cgroup_rstat_mutex);
-	cgroup_rstat_flush_locked(cgrp);
+	might_sleep();
+	spin_lock_irq(&cgroup_rstat_lock);
+	cgroup_rstat_flush_locked(cgrp, true);
 }
 
 /**
  * cgroup_rstat_flush_release - release cgroup_rstat_flush_hold()
  */
 void cgroup_rstat_flush_release(void)
-	__releases(&cgroup_rstat_mutex)
+	__releases(&cgroup_rstat_lock)
 {
-	mutex_unlock(&cgroup_rstat_mutex);
+	spin_unlock_irq(&cgroup_rstat_lock);
 }
 
 int cgroup_rstat_init(struct cgroup *cgrp)
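Finally, a hedged sketch of the hold/release pairing documented above: flush once, then read several values that must be mutually consistent while further flushes are locked out. read_consistent_pair() and sample_counter() are hypothetical illustrations, not kernel code.

#include <linux/cgroup.h>

/* Hypothetical stand-in for reading a field of the freshly flushed cgroup. */
static u64 sample_counter(struct cgroup *cgrp, int which)
{
	return 0;
}

static void read_consistent_pair(struct cgroup *cgrp, u64 *a, u64 *b)
{
	cgroup_rstat_flush_hold(cgrp);	/* may sleep; returns with the rstat lock held, IRQs off */
	*a = sample_counter(cgrp, 0);
	*b = sample_counter(cgrp, 1);	/* consistent with *a: no flush can slip in between */
	cgroup_rstat_flush_release();	/* drops the lock and re-enables IRQs */
}

Because cgroup_rstat_flush_hold() returns with interrupts disabled, nothing between hold and release may sleep.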