Fix blockable memory allocation under rcu_read_lock in psi monitors

psi_fop_poll calls poll_wait which makes a blockable allocation. This
generates the following warning:

[   22.339969] c0    898 BUG: sleeping function called from invalid context at mm/slab.h:393
[   22.339972] c0    898 in_atomic(): 0, irqs_disabled(): 0, pid: 898, name: lmkd
[   22.339979] c0    898 ------------[ cut here ]------------
[   22.339981] c0    898 kernel BUG at kernel/sched/core.c:8448!
[   22.339986] c0    898 ------------[ cut here ]------------
[   22.339987] c0    898 kernel BUG at kernel/sched/core.c:8448!
[   22.339989] c0    898 Internal error: Oops - BUG: 0 [#1] PREEMPT SMP
[   22.339996] c0    898 Modules linked in:
[   22.340001] c0    898 CPU: 0 PID: 898 Comm: lmkd Not tainted 4.9.153-gae6c134c27aa-dirty_audio-g3dce958 #1
[   22.340002] c0    898 Hardware name: <redacted>
[   22.340004] c0    898 task: fffffff3a6696000 task.stack: fffffff3977b4000
[   22.340016] c0    898 PC is at ___might_sleep+0x104/0x10c
[   22.340019] c0    898 LR is at ___might_sleep+0xdc/0x10c
...
[   22.340201] c0    898 [<ffffff9f4bceaf2c>] ___might_sleep+0x104/0x10c
[   22.340203] c0    898 [<ffffff9f4bceae14>] __might_sleep+0x24/0x38
[   22.340209] c0    898 [<ffffff9f4be26d1c>] kmem_cache_alloc+0x60/0x338
[   22.340213] c0    898 [<ffffff9f4be9684c>] ep_ptable_queue_proc+0x38/0xc8
[   22.340219] c0    898 [<ffffff9f4bd18218>] psi_fop_poll+0x7c/0xbc

Fix this by using refcount during poll_wait and trigger destruction. This prevents trigger
from being destroyed during poll_wait call.

Bug: 124796537
Change-Id: Ib7feb96990f1cd96f1ab0e0720010ffa442c25d9
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
diff --git a/include/linux/psi.h b/include/linux/psi.h
index c85e284..b1f0c2c 100644
--- a/include/linux/psi.h
+++ b/include/linux/psi.h
@@ -30,9 +30,9 @@
 
 struct psi_trigger *psi_trigger_create(struct psi_group *group,
 			char *buf, size_t nbytes, enum psi_res res);
-void psi_trigger_destroy(struct psi_trigger *t);
+void psi_trigger_replace(void **trigger_ptr, struct psi_trigger *t);
 
-unsigned int psi_trigger_poll(struct psi_trigger *t, struct file *file,
+unsigned int psi_trigger_poll(void **trigger_ptr, struct file *file,
 			poll_table *wait);
 #endif
 
diff --git a/include/linux/psi_types.h b/include/linux/psi_types.h
index 9275052..23396ee 100644
--- a/include/linux/psi_types.h
+++ b/include/linux/psi_types.h
@@ -3,6 +3,7 @@
 
 #include <linux/seqlock.h>
 #include <linux/types.h>
+#include <linux/kref.h>
 #include <linux/wait.h>
 
 #ifdef CONFIG_PSI
@@ -111,6 +112,9 @@
 	 * events to one per window
 	 */
 	u64 last_event_time;
+
+	/* Refcounting to prevent premature destruction */
+	struct kref refcount;
 };
 
 struct psi_group {
diff --git a/kernel/cgroup.c b/kernel/cgroup.c
index 7c5fe44..ecb34d52 100644
--- a/kernel/cgroup.c
+++ b/kernel/cgroup.c
@@ -3523,7 +3523,6 @@
 static ssize_t cgroup_pressure_write(struct kernfs_open_file *of, char *buf,
 					  size_t nbytes, enum psi_res res)
 {
-	struct psi_trigger *old;
 	struct psi_trigger *new;
 	struct cgroup *cgrp;
 
@@ -3540,12 +3539,7 @@
 		return PTR_ERR(new);
 	}
 
-	old = of->priv;
-	rcu_assign_pointer(of->priv, new);
-	if (old) {
-		synchronize_rcu();
-		psi_trigger_destroy(old);
-	}
+	psi_trigger_replace(&of->priv, new);
 
 	cgroup_put(cgrp);
 
@@ -3576,30 +3570,12 @@
 static unsigned int cgroup_pressure_poll(struct kernfs_open_file *of,
 					  poll_table *pt)
 {
-	struct psi_trigger *t;
-	unsigned int ret;
-
-	rcu_read_lock();
-	t = rcu_dereference(of->priv);
-	if (t)
-		ret = psi_trigger_poll(t, of->file, pt);
-	else
-		ret = DEFAULT_POLLMASK | POLLERR | POLLPRI;
-	rcu_read_unlock();
-
-	return ret;
+	return psi_trigger_poll(&of->priv, of->file, pt);
 }
 
 static void cgroup_pressure_release(struct kernfs_open_file *of)
 {
-	struct psi_trigger *t = of->priv;
-
-	if (!t)
-		return;
-
-	rcu_assign_pointer(of->priv, NULL);
-	synchronize_rcu();
-	psi_trigger_destroy(t);
+	psi_trigger_replace(&of->priv, NULL);
 }
 #endif /* CONFIG_PSI */
 
diff --git a/kernel/sched/psi.c b/kernel/sched/psi.c
index a6562baa..b6428aa 100644
--- a/kernel/sched/psi.c
+++ b/kernel/sched/psi.c
@@ -1103,6 +1103,7 @@
 	t->event = 0;
 	t->last_event_time = 0;
 	init_waitqueue_head(&t->event_wait);
+	kref_init(&t->refcount);
 
 	mutex_lock(&group->update_lock);
 
@@ -1117,14 +1118,22 @@
 	return t;
 }
 
-void psi_trigger_destroy(struct psi_trigger *t)
+static void psi_trigger_destroy(struct kref *ref)
 {
+	struct psi_trigger *t = container_of(ref, struct psi_trigger, refcount);
 	struct psi_group *group = t->group;
 
 	if (static_branch_likely(&psi_disabled))
 		return;
 
+	/*
+	 * Wakeup waiters to stop polling. Can happen if cgroup is deleted
+	 * from under a polling process.
+	 */
+	wake_up_interruptible(&t->event_wait);
+
 	mutex_lock(&group->update_lock);
+
 	if (!list_empty(&t->node)) {
 		struct psi_trigger *tmp;
 		u64 period = ULLONG_MAX;
@@ -1139,30 +1148,57 @@
 					UPDATES_PER_WINDOW));
 		}
 		group->trigger_min_period = period;
-		/*
-		 * Wakeup waiters to stop polling.
-		 * Can happen if cgroup is deleted from under
-		 * a polling process.
-		 */
-		wake_up_interruptible(&t->event_wait);
-		kfree(t);
 	}
+
 	mutex_unlock(&group->update_lock);
+	/*
+	 * Wait for RCU to complete its read-side critical section before
+	 * destroying the trigger
+	 */
+	synchronize_rcu();
+	kfree(t);
 }
 
-unsigned int psi_trigger_poll(struct psi_trigger *t,
-				struct file *file, poll_table *wait)
+void psi_trigger_replace(void **trigger_ptr, struct psi_trigger *new)
 {
+	struct psi_trigger *old = *trigger_ptr;
+
+	if (static_branch_likely(&psi_disabled))
+		return;
+
+	rcu_assign_pointer(*trigger_ptr, new);
+	if (old)
+		kref_put(&old->refcount, psi_trigger_destroy);
+}
+
+unsigned int psi_trigger_poll(void **trigger_ptr, struct file *file,
+			      poll_table *wait)
+{
+	unsigned int ret = DEFAULT_POLLMASK;
+	struct psi_trigger *t;
+
 	if (static_branch_likely(&psi_disabled))
 		return DEFAULT_POLLMASK | POLLERR | POLLPRI;
 
+	rcu_read_lock();
+
+	t = rcu_dereference(*(void __rcu __force **)trigger_ptr);
+	if (!t) {
+		rcu_read_unlock();
+		return DEFAULT_POLLMASK | POLLERR | POLLPRI;
+	}
+	kref_get(&t->refcount);
+
+	rcu_read_unlock();
+
 	poll_wait(file, &t->event_wait, wait);
 
 	if (cmpxchg(&t->event, 1, 0) == 1)
-		return DEFAULT_POLLMASK | POLLPRI;
+		ret |= POLLPRI;
 
-	/* Wait */
-	return DEFAULT_POLLMASK;
+	kref_put(&t->refcount, psi_trigger_destroy);
+
+	return ret;
 }
 
 static ssize_t psi_write(struct file *file, const char __user *user_buf,
@@ -1171,7 +1207,6 @@
 	char buf[32];
 	size_t buf_size;
 	struct seq_file *seq;
-	struct psi_trigger *old;
 	struct psi_trigger *new;
 
 	if (static_branch_likely(&psi_disabled))
@@ -1190,15 +1225,9 @@
 	seq = file->private_data;
 	/* Take seq->lock to protect seq->private from concurrent writes */
 	mutex_lock(&seq->lock);
-	old = seq->private;
-	rcu_assign_pointer(seq->private, new);
+	psi_trigger_replace(&seq->private, new);
 	mutex_unlock(&seq->lock);
 
-	if (old) {
-		synchronize_rcu();
-		psi_trigger_destroy(old);
-	}
-
 	return nbytes;
 }
 
@@ -1223,33 +1252,15 @@
 static unsigned int psi_fop_poll(struct file *file, poll_table *wait)
 {
 	struct seq_file *seq = file->private_data;
-	struct psi_trigger *t;
-	unsigned int ret;
 
-	rcu_read_lock();
-	t = rcu_dereference(seq->private);
-	if (t)
-		ret = psi_trigger_poll(t, file, wait);
-	else
-		ret = DEFAULT_POLLMASK | POLLERR | POLLPRI;
-	rcu_read_unlock();
-
-	return ret;
-
+	return psi_trigger_poll(&seq->private, file, wait);
 }
 
 static int psi_fop_release(struct inode *inode, struct file *file)
 {
 	struct seq_file *seq = file->private_data;
-	struct psi_trigger *t = seq->private;
 
-	if (static_branch_likely(&psi_disabled) || !t)
-		goto out;
-
-	rcu_assign_pointer(seq->private, NULL);
-	synchronize_rcu();
-	psi_trigger_destroy(t);
-out:
+	psi_trigger_replace(&seq->private, NULL);
 	return single_release(inode, file);
 }