psi: split aggregator and monitor into separate workers

Running the aggregator and the monitor in separate workers lets us
simplify the lockless synchronization between them and run the monitor
as an RT kthread. This prevents psi notifications from being delayed by
other RT threads, which could otherwise starve the psi monitor thread.

Test: hog every CPU with "chrt -f 50 dd if=/dev/zero of=/dev/null"
running on each core, and modify the psi monitor to generate an event
with a trace every 500ms. Without this patch, missed events are
observed:

psimon-901   [003] ....   109.146856: psi_trigger_work: psimon: trigger work <-- 2 events per second
psimon-901   [003] ....   109.653626: psi_trigger_work: psimon: trigger work
psimon-901   [003] ....   110.160327: psi_trigger_work: psimon: trigger work
psimon-901   [003] ....   110.667015: psi_trigger_work: psimon: trigger work
psimon-901   [003] ....   111.173563: psi_trigger_work: psimon: trigger work
psimon-901   [003] ....   111.680597: psi_trigger_work: psimon: trigger work <-- test start, event rate drops
psimon-901   [004] ....   113.400104: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   114.396731: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   115.396821: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   116.396717: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   117.396734: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   118.396837: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   119.396703: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   120.396735: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   121.396701: psi_trigger_work: psimon: trigger work
psimon-901   [006] ....   122.396743: psi_trigger_work: psimon: trigger work
psimon-901   [004] ....   123.396706: psi_trigger_work: psimon: trigger work
psimon-901   [007] ....   124.396703: psi_trigger_work: psimon: trigger work
psimon-901   [005] ....   125.396704: psi_trigger_work: psimon: trigger work
psimon-901   [005] ....   126.396703: psi_trigger_work: psimon: trigger work <-- test stop, events resume at normal rate
psimon-901   [005] ....   126.906763: psi_trigger_work: psimon: trigger work
psimon-901   [005] ....   127.414235: psi_trigger_work: psimon: trigger work
psimon-901   [005] ....   127.921292: psi_trigger_work: psimon: trigger work

With this patch, the event rate is unaffected by the CPU hoggers.
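
For context, the 500ms events in the test above were generated by
instrumenting the kernel directly; a real consumer instead registers a
psi trigger and waits for POLLPRI wakeups delivered by the psimon
kthread. A minimal sketch of such a consumer, assuming the documented
/proc/pressure write-then-poll interface (the memory resource and the
150ms/1s threshold/window values are arbitrary examples):

  #include <fcntl.h>
  #include <poll.h>
  #include <stdio.h>
  #include <string.h>
  #include <unistd.h>

  int main(void)
  {
  	/* event when "some" memory stall exceeds 150ms in a 1s window */
  	const char trig[] = "some 150000 1000000";
  	struct pollfd fds;
  	int fd;

  	fd = open("/proc/pressure/memory", O_RDWR | O_NONBLOCK);
  	if (fd < 0) {
  		perror("open /proc/pressure/memory");
  		return 1;
  	}
  	/* register the trigger; the kernel parses the written string */
  	if (write(fd, trig, strlen(trig) + 1) < 0) {
  		perror("write trigger");
  		return 1;
  	}
  	fds.fd = fd;
  	fds.events = POLLPRI;

  	while (1) {
  		if (poll(&fds, 1, -1) < 0) {
  			perror("poll");
  			return 1;
  		}
  		if (fds.revents & POLLERR) {
  			fprintf(stderr, "event source is gone\n");
  			return 0;
  		}
  		if (fds.revents & POLLPRI)
  			printf("event triggered\n"); /* psimon woke us up */
  	}
  }

With this patch the psimon kthread servicing such triggers runs at
SCHED_FIFO priority, so these wakeups keep arriving on time even while
SCHED_FIFO CPU hoggers are running.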

Change-Id: I7ef85deeab28ffe4cfa8bc658fadd33fe5d194de
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
diff --git a/include/linux/psi_types.h b/include/linux/psi_types.h
index 23396ee..9b34e51 100644
--- a/include/linux/psi_types.h
+++ b/include/linux/psi_types.h
@@ -1,6 +1,7 @@
 #ifndef _LINUX_PSI_TYPES_H
 #define _LINUX_PSI_TYPES_H
 
+#include <linux/kthread.h>
 #include <linux/seqlock.h>
 #include <linux/types.h>
 #include <linux/kref.h>
@@ -46,6 +47,12 @@
 	NR_PSI_STATES = 6,
 };
 
+enum psi_aggregators {
+	PSI_AVGS = 0,
+	PSI_POLL,
+	NR_PSI_AGGREGATORS,
+};
+
 struct psi_group_cpu {
 	/* 1st cacheline updated by the scheduler */
 
@@ -67,7 +74,8 @@
 	/* 2nd cacheline updated by the aggregator */
 
 	/* Delta detection against the sampling buckets */
-	u32 times_prev[NR_PSI_STATES] ____cacheline_aligned_in_smp;
+	u32 times_prev[NR_PSI_AGGREGATORS][NR_PSI_STATES]
+			____cacheline_aligned_in_smp;
 };
 
 /* PSI growth tracking window */
@@ -119,17 +127,16 @@
 
 struct psi_group {
 	/* Protects data used by the aggregator */
-	struct mutex update_lock;
+	struct mutex avgs_lock;
 
 	/* Per-cpu task state & time tracking */
 	struct psi_group_cpu __percpu *pcpu;
 
-	/* Periodic work control */
-	atomic_t polling;
-	struct delayed_work clock_work;
-
 	/* Total stall times observed */
-	u64 total[NR_PSI_STATES - 1];
+	u64 total[NR_PSI_AGGREGATORS][NR_PSI_STATES - 1];
+
+	/* Aggregator work control */
+	struct delayed_work avgs_work;
 
 	/* Running pressure averages */
 	u64 avg_total[NR_PSI_STATES - 1];
@@ -137,13 +144,20 @@
 	u64 avg_next_update;
 	unsigned long avg[NR_PSI_STATES - 1][3];
 
+	/* Monitor work control */
+	atomic_t poll_scheduled;
+	struct kthread_worker __rcu *poll_kworker;
+	struct kthread_delayed_work poll_work;
+
+	/* Protects data used by the monitor */
+	struct mutex trigger_lock;
+
 	/* Configured polling triggers */
 	struct list_head triggers;
 	u32 nr_triggers[NR_PSI_STATES - 1];
 	u32 trigger_states;
-	u64 trigger_min_period;
+	u64 poll_min_period;
 
-	/* Polling state */
 	/* Total stall times at the start of monitor activation */
 	u64 polling_total[NR_PSI_STATES - 1];
 	u64 polling_next_update;
diff --git a/kernel/sched/psi.c b/kernel/sched/psi.c
index b6428aa..8f34c5be 100644
--- a/kernel/sched/psi.c
+++ b/kernel/sched/psi.c
@@ -148,9 +148,9 @@
 DEFINE_STATIC_KEY_FALSE(psi_disabled);
 
 #ifdef CONFIG_PSI_DEFAULT_DISABLED
-bool psi_enable;
+static bool psi_enable;
 #else
-bool psi_enable = true;
+static bool psi_enable = true;
 #endif
 static int __init setup_psi(char *str)
 {
@@ -178,7 +178,7 @@
 	.pcpu = &system_group_pcpu,
 };
 
-static void psi_update_work(struct work_struct *work);
+static void psi_avgs_work(struct work_struct *work);
 
 static void group_init(struct psi_group *group)
 {
@@ -187,17 +187,19 @@
 	for_each_possible_cpu(cpu)
 		seqcount_init(&per_cpu_ptr(group->pcpu, cpu)->seq);
 	group->avg_next_update = sched_clock() + psi_period;
-	atomic_set(&group->polling, 0);
-	INIT_DELAYED_WORK(&group->clock_work, psi_update_work);
-	mutex_init(&group->update_lock);
+	atomic_set(&group->poll_scheduled, 0);
+	INIT_DELAYED_WORK(&group->avgs_work, psi_avgs_work);
+	mutex_init(&group->avgs_lock);
 	/* Init trigger-related members */
+	mutex_init(&group->trigger_lock);
 	INIT_LIST_HEAD(&group->triggers);
 	memset(group->nr_triggers, 0, sizeof(group->nr_triggers));
 	group->trigger_states = 0;
-	group->trigger_min_period = U32_MAX;
+	group->poll_min_period = U32_MAX;
 	memset(group->polling_total, 0, sizeof(group->polling_total));
 	group->polling_next_update = ULLONG_MAX;
 	group->polling_until = 0;
+	rcu_assign_pointer(group->poll_kworker, NULL);
 }
 
 void __init psi_init(void)
@@ -232,8 +234,9 @@
 	}
 }
 
-static void get_recent_times(struct psi_group *group, int cpu, u32 *times,
-							 u32 *pchanged_states)
+static void get_recent_times(struct psi_group *group, int cpu,
+			     enum psi_aggregators aggregator, u32 *times,
+			     u32 *pchanged_states)
 {
 	struct psi_group_cpu *groupc = per_cpu_ptr(group->pcpu, cpu);
 	u64 now, state_start;
@@ -267,8 +270,8 @@
 		if (state_mask & (1 << s))
 			times[s] += now - state_start;
 
-		delta = times[s] - groupc->times_prev[s];
-		groupc->times_prev[s] = times[s];
+		delta = times[s] - groupc->times_prev[aggregator][s];
+		groupc->times_prev[aggregator][s] = times[s];
 
 		times[s] = delta;
 		if (delta)
@@ -296,7 +299,9 @@
 	avg[2] = calc_load(avg[2], EXP_300s, pct);
 }
 
-static void collect_percpu_times(struct psi_group *group, u32 *pchanged_states)
+static void collect_percpu_times(struct psi_group *group,
+				 enum psi_aggregators aggregator,
+				 u32 *pchanged_states)
 {
 	u64 deltas[NR_PSI_STATES - 1] = { 0, };
 	unsigned long nonidle_total = 0;
@@ -317,7 +322,8 @@
 		u32 nonidle;
 		u32 cpu_changed_states;
 
-		get_recent_times(group, cpu, times, &cpu_changed_states);
+		get_recent_times(group, cpu, aggregator, times,
+				&cpu_changed_states);
 		changed_states |= cpu_changed_states;
 
 		nonidle = nsecs_to_jiffies(times[PSI_NONIDLE]);
@@ -341,7 +347,8 @@
 
 	/* total= */
 	for (s = 0; s < NR_PSI_STATES - 1; s++)
-		group->total[s] += div_u64(deltas[s], max(nonidle_total, 1UL));
+		group->total[aggregator][s] +=
+				div_u64(deltas[s], max(nonidle_total, 1UL));
 
 	if (pchanged_states)
 		*pchanged_states = changed_states;
@@ -373,7 +380,7 @@
 	for (s = 0; s < NR_PSI_STATES - 1; s++) {
 		u32 sample;
 
-		sample = group->total[s] - group->avg_total[s];
+		sample = group->total[PSI_AVGS][s] - group->avg_total[s];
 		/*
 		 * Due to the lockless sampling of the time buckets,
 		 * recorded time deltas can slip into the next period,
@@ -402,7 +409,7 @@
 
 /* Trigger tracking window manupulations */
 static void window_reset(struct psi_window *win, u64 now, u64 value,
-						 u64 prev_growth)
+			 u64 prev_growth)
 {
 	win->start_time = now;
 	win->start_value = value;
@@ -451,16 +458,18 @@
 	struct psi_trigger *t;
 
 	list_for_each_entry(t, &group->triggers, node)
-		window_reset(&t->win, now, group->total[t->state], 0);
-	memcpy(group->polling_total, group->total,
+		window_reset(&t->win, now,
+				group->total[PSI_POLL][t->state], 0);
+	memcpy(group->polling_total, group->total[PSI_POLL],
 		   sizeof(group->polling_total));
-	group->polling_next_update = now + group->trigger_min_period;
+	group->polling_next_update = now + group->poll_min_period;
 }
 
 static u64 update_triggers(struct psi_group *group, u64 now)
 {
 	struct psi_trigger *t;
 	bool new_stall = false;
+	u64 *total = group->total[PSI_POLL];
 
 	/*
 	 * On subsequent updates, calculate growth deltas and let
@@ -470,7 +479,7 @@
 		u64 growth;
 
 		/* Check for stall activity */
-		if (group->polling_total[t->state] == group->total[t->state])
+		if (group->polling_total[t->state] == total[t->state])
 			continue;
 
 		/*
@@ -482,7 +491,7 @@
 		new_stall = true;
 
 		/* Calculate growth since last update */
-		growth = window_update(&t->win, now, group->total[t->state]);
+		growth = window_update(&t->win, now, total[t->state]);
 		if (growth < t->threshold)
 			continue;
 
@@ -496,87 +505,94 @@
 		t->last_event_time = now;
 	}
 
-	if (new_stall) {
-		memcpy(group->polling_total, group->total,
-			   sizeof(group->polling_total));
-	}
+	if (new_stall)
+		memcpy(group->polling_total, total,
+				sizeof(group->polling_total));
 
-	return now + group->trigger_min_period;
+	return now + group->poll_min_period;
 }
 
-/*
- * psi_update_work represents slowpath accounting part while psi_group_change
- * represents hotpath part. There are two potential races between them:
- * 1. Changes to group->polling when slowpath checks for new stall, then hotpath
- *    records new stall and then slowpath resets group->polling flag. This leads
- *    to the exit from the polling mode while monitored state is still changing.
- * 2. Slowpath overwriting an immediate update scheduled from the hotpath with
- *    a regular update further in the future and missing the immediate update.
- * Both races are handled with a retry cycle in the slowpath:
- *
- *    HOTPATH:                         |    SLOWPATH:
- *                                     |
- * A) times[cpu] += delta              | E) delta = times[*]
- * B) start_poll = (delta[poll_mask] &&|    polling = g->polling
- *      cmpxchg(g->polling, 0, 1) == 0)|    if delta[poll_mask]:
- *    if start_poll:                   | F)   polling_until = now + grace_period
- * C)   mod_delayed_work(1)            |    if now > polling_until:
- *     else if !delayed_work_pending():|      if polling:
- * D)   schedule_delayed_work(PSI_FREQ)| G)     g->polling = polling = 0
- *                                     |        smp_mb
- *                                     | H)     goto SLOWPATH
- *                                     |    else:
- *                                     |      if !polling:
- *                                     | I)     g->polling = polling = 1
- *                                     | J) if delta && first_pass:
- *                                     |      next_avg = update_averages()
- *                                     |      if polling:
- *                                     |        next_poll = update_triggers()
- *                                     |    if (delta && first_pass) || polling:
- *                                     | K)   mod_delayed_work(
- *                                     |          min(next_avg, next_poll))
- *                                     |      if !polling:
- *                                     |        first_pass = false
- *                                     | L)     goto SLOWPATH
- *
- * Race #1 is represented by (EABGD) sequence in which case slowpath deactivates
- * polling mode because it misses new monitored stall and hotpath doesn't
- * activate it because at (B) g->polling is not yet reset by slowpath in (G).
- * This race is handled by the (H) retry, which in the race described above
- * results in the new sequence of (EABGDHEIK) that reactivates polling mode.
- *
- * Race #2 is represented by polling==false && (JABCK) sequence which overwrites
- * immediate update scheduled at (C) with a later (next_avg) update scheduled at
- * (K). This race is handled by the (L) retry which results in the new sequence
- * of polling==false && (JABCKLEIK) that reactivates polling mode and
- * reschedules the next polling update (next_poll).
- *
- * Note that retries can't result in an infinite loop because retry #1 happens
- * only during polling reactivation and retry #2 happens only on the first pass.
- * Constant reactivations are impossible because polling will stay active for at
- * least grace_period. Worst case scenario involves two retries (HEJKLE)
- */
-static void psi_update_work(struct work_struct *work)
+static void psi_avgs_work(struct work_struct *work)
 {
 	struct delayed_work *dwork;
 	struct psi_group *group;
-	bool first_pass = true;
-	u64 next_update;
 	u32 changed_states;
-	int polling;
 	bool nonidle;
 	u64 now;
 
 	dwork = to_delayed_work(work);
-	group = container_of(dwork, struct psi_group, clock_work);
+	group = container_of(dwork, struct psi_group, avgs_work);
 
-	mutex_lock(&group->update_lock);
+	mutex_lock(&group->avgs_lock);
 
 	now = sched_clock();
 
-retry:
-	collect_percpu_times(group, &changed_states);
-	polling = atomic_read(&group->polling);
+	collect_percpu_times(group, PSI_AVGS, &changed_states);
+	nonidle = changed_states & (1 << PSI_NONIDLE);
+	/*
+	 * If there is task activity, periodically fold the per-cpu
+	 * times and feed samples into the running averages. If things
+	 * are idle and there is no data to process, stop the clock.
+	 * Once restarted, we'll catch up the running averages in one
+	 * go - see calc_avgs() and missed_periods.
+	 */
+	if (nonidle) {
+		if (now >= group->avg_next_update)
+			group->avg_next_update = update_averages(group, now);
+
+		schedule_delayed_work(dwork, nsecs_to_jiffies(
+				group->avg_next_update - now) + 1);
+	}
+
+	mutex_unlock(&group->avgs_lock);
+}
+
+/*
+ * Schedule polling if it's not already scheduled. It's safe to call even from
+ * hotpath because even though kthread_queue_delayed_work takes worker->lock
+ * spinlock that spinlock is never contended due to poll_scheduled atomic
+ * preventing such competition.
+ */
+static void psi_schedule_poll_work(struct psi_group *group, unsigned long delay)
+{
+	struct kthread_worker *kworker;
+
+	/* Do not reschedule if already scheduled */
+	if (atomic_cmpxchg(&group->poll_scheduled, 0, 1) != 0)
+		return;
+
+	rcu_read_lock();
+
+	kworker = rcu_dereference(group->poll_kworker);
+	/*
+	 * kworker might be NULL in case psi_trigger_destroy races with
+	 * psi_task_change (hotpath) which can't use locks
+	 */
+	if (likely(kworker))
+		kthread_queue_delayed_work(kworker, &group->poll_work, delay);
+	else
+		atomic_set(&group->poll_scheduled, 0);
+
+	rcu_read_unlock();
+}
+
+static void psi_poll_work(struct kthread_work *work)
+{
+	struct kthread_delayed_work *dwork;
+	struct psi_group *group;
+	u32 changed_states;
+	u64 now;
+
+	dwork = container_of(work, struct kthread_delayed_work, work);
+	group = container_of(dwork, struct psi_group, poll_work);
+
+	atomic_set(&group->poll_scheduled, 0);
+
+	mutex_lock(&group->trigger_lock);
+
+	now = sched_clock();
+
+	collect_percpu_times(group, PSI_POLL, &changed_states);
 
 	if (changed_states & group->trigger_states) {
 		/* Initialize trigger windows when entering polling mode */
@@ -586,86 +602,25 @@
 		/*
 		 * Keep the monitor active for at least the duration of the
 		 * minimum tracking window as long as monitor states are
-		 * changing. This prevents frequent changes to polling flag
-		 * when system bounces in and out of stall states.
+		 * changing.
 		 */
 		group->polling_until = now +
-			group->trigger_min_period * UPDATES_PER_WINDOW;
+			group->poll_min_period * UPDATES_PER_WINDOW;
 	}
 
-	/* Handle polling flag transitions */
 	if (now > group->polling_until) {
-		if (polling) {
-			group->polling_next_update = ULLONG_MAX;
-			polling = 0;
-			atomic_set(&group->polling, polling);
-			/*
-			 * Memory barrier is needed to order group->polling=0
-			 * write before times[] reads in collect_percpu_times()
-			 * to detect possible race with hotpath that modifies
-			 * times[] before it sets group->polling=1 (see Race #1
-			 * description in the comments at the top).
-			 */
-			smp_mb();
-			/*
-			 * Check if we missed stall recorded by hotpath while
-			 * polling flag was set to 1 causing hotpath to skip
-			 * entering polling mode
-			 */
-			goto retry;
-		}
-	} else {
-		if (!polling) {
-			/*
-			 * This can happen as a fixup in the retry cycle after
-			 * new stall is discovered
-			 */
-			polling = 1;
-			atomic_set(&group->polling, polling);
-		}
-	}
-	/*
-	 * At this point group->polling race with hotpath is resolved and
-	 * we rely on local polling flag ignoring possible further changes
-	 * to group->polling
-	 */
-
-	nonidle = (changed_states & (1 << PSI_NONIDLE));
-	/*
-	 * If there is task activity, periodically fold the per-cpu
-	 * times and feed samples into the running averages. If things
-	 * are idle and there is no data to process, stop the clock.
-	 * Once restarted, we'll catch up the running averages in one
-	 * go - see calc_avgs() and missed_periods.
-	 */
-	if (nonidle && first_pass) {
-		if (now >= group->avg_next_update)
-			group->avg_next_update = update_averages(group, now);
-
-		if (now >= group->polling_next_update) {
-			group->polling_next_update = update_triggers(
-					group, now);
-		}
-	}
-	if ((nonidle && first_pass) || polling) {
-		/* Calculate closest update time */
-		next_update = min(group->polling_next_update,
-					group->avg_next_update);
-		mod_delayed_work(system_wq, dwork, nsecs_to_jiffies(
-				next_update - now) + 1);
-		if (!polling) {
-			/*
-			 * We might have overwritten an immediate update
-			 * scheduled from the hotpath with a longer regular
-			 * update (group->avg_next_update). Execute second pass
-			 * retry to discover that and resume polling.
-			 */
-			first_pass = false;
-			goto retry;
-		}
+		group->polling_next_update = ULLONG_MAX;
+		goto out;
 	}
 
-	mutex_unlock(&group->update_lock);
+	if (now >= group->polling_next_update)
+		group->polling_next_update = update_triggers(group, now);
+
+	psi_schedule_poll_work(group,
+		nsecs_to_jiffies(group->polling_next_update - now) + 1);
+
+out:
+	mutex_unlock(&group->trigger_lock);
 }
 
 static void record_times(struct psi_group_cpu *groupc, int cpu,
@@ -715,7 +670,7 @@
 }
 
 static u32 psi_group_change(struct psi_group *group, int cpu,
-			     unsigned int clear, unsigned int set)
+			    unsigned int clear, unsigned int set)
 {
 	struct psi_group_cpu *groupc;
 	unsigned int t, m;
@@ -819,33 +774,22 @@
 	 */
 	if (unlikely((clear & TSK_RUNNING) &&
 		     (task->flags & PF_WQ_WORKER) &&
-		     wq_worker_last_func(task) == psi_update_work))
+		     wq_worker_last_func(task) == psi_avgs_work))
 		wake_clock = false;
 
 	while ((group = iterate_groups(task, &iter))) {
 		u32 state_mask = psi_group_change(group, cpu, clear, set);
 
 		/*
-		 * Polling flag resets to 0 at the max rate of once per update
-		 * window (at least 500ms interval). smp_wmb is required after
-		 * group->polling 0-to-1 transition to order groupc->times and
-		 * group->polling writes because stall detection logic in the
-		 * slowpath relies on groupc->times changing before
-		 * group->polling. Explicit smp_wmb is missing because cmpxchg()
-		 * implies smp_mb.
+		 * poll_scheduled flag changes at the max rate of once per min
+		 * polling period (50ms), so should not invalidate cache too
+		 * often.
 		 */
-		if ((state_mask & group->trigger_states) &&
-			atomic_cmpxchg(&group->polling, 0, 1) == 0) {
-			/*
-			 * Start polling immediately even if the work is already
-			 * scheduled
-			 */
-			mod_delayed_work(system_wq, &group->clock_work, 1);
-			continue;
-		}
+		if (state_mask & group->trigger_states)
+			psi_schedule_poll_work(group, 1);
 
-		if (wake_clock && !delayed_work_pending(&group->clock_work))
-			schedule_delayed_work(&group->clock_work, PSI_FREQ);
+		if (wake_clock && !delayed_work_pending(&group->avgs_work))
+			schedule_delayed_work(&group->avgs_work, PSI_FREQ);
 	}
 }
 
@@ -942,9 +886,9 @@
 	if (static_branch_likely(&psi_disabled))
 		return;
 
-	cancel_delayed_work_sync(&cgroup->psi.clock_work);
+	cancel_delayed_work_sync(&cgroup->psi.avgs_work);
 	free_percpu(cgroup->psi.pcpu);
-	/* All triggers must be removed by now by psi_trigger_destroy */
+	/* All triggers must be removed by now */
 	WARN_ONCE(cgroup->psi.trigger_states, "psi: trigger leak\n");
 }
 
@@ -964,7 +908,7 @@
 {
 	unsigned int task_flags = 0;
 	struct rq_flags rf;
-	struct rq *rq = NULL;
+	struct rq *rq;
 
 	if (static_branch_likely(&psi_disabled)) {
 		/*
@@ -1006,10 +950,10 @@
 		return -EOPNOTSUPP;
 
 	/* Update averages before reporting them */
-	mutex_lock(&group->update_lock);
-	collect_percpu_times(group, NULL);
+	mutex_lock(&group->avgs_lock);
+	collect_percpu_times(group, PSI_AVGS, NULL);
 	update_averages(group, sched_clock());
-	mutex_unlock(&group->update_lock);
+	mutex_unlock(&group->avgs_lock);
 
 	for (full = 0; full < 2 - (res == PSI_CPU); full++) {
 		unsigned long avg[3];
@@ -1018,7 +962,8 @@
 
 		for (w = 0; w < 3; w++)
 			avg[w] = group->avg[res * 2 + full][w];
-		total = div_u64(group->total[res * 2 + full], NSEC_PER_USEC);
+		total = div_u64(group->total[PSI_AVGS][res * 2 + full],
+				NSEC_PER_USEC);
 
 		seq_printf(m, "%s avg10=%lu.%02lu avg60=%lu.%02lu avg300=%lu.%02lu total=%llu\n",
 			   full ? "full" : "some",
@@ -1105,15 +1050,33 @@
 	init_waitqueue_head(&t->event_wait);
 	kref_init(&t->refcount);
 
-	mutex_lock(&group->update_lock);
+	mutex_lock(&group->trigger_lock);
+
+	if (!rcu_access_pointer(group->poll_kworker)) {
+		struct sched_param param = {
+			.sched_priority = MAX_RT_PRIO - 1,
+		};
+		struct kthread_worker *kworker;
+
+		kworker = kthread_create_worker(0, "psimon");
+		if (IS_ERR(kworker)) {
+			kfree(t);
+			mutex_unlock(&group->trigger_lock);
+			return ERR_CAST(kworker);
+		}
+		sched_setscheduler(kworker->task, SCHED_FIFO, &param);
+		kthread_init_delayed_work(&group->poll_work,
+				psi_poll_work);
+		rcu_assign_pointer(group->poll_kworker, kworker);
+	}
 
 	list_add(&t->node, &group->triggers);
-	group->trigger_min_period = min(group->trigger_min_period,
+	group->poll_min_period = min(group->poll_min_period,
 		div_u64(t->win.size, UPDATES_PER_WINDOW));
 	group->nr_triggers[t->state]++;
 	group->trigger_states |= (1 << t->state);
 
-	mutex_unlock(&group->update_lock);
+	mutex_unlock(&group->trigger_lock);
 
 	return t;
 }
@@ -1122,6 +1085,7 @@
 {
 	struct psi_trigger *t = container_of(ref, struct psi_trigger, refcount);
 	struct psi_group *group = t->group;
+	struct kthread_worker *kworker_to_destroy = NULL;
 
 	if (static_branch_likely(&psi_disabled))
 		return;
@@ -1132,7 +1096,7 @@
 	 */
 	wake_up_interruptible(&t->event_wait);
 
-	mutex_lock(&group->update_lock);
+	mutex_lock(&group->trigger_lock);
 
 	if (!list_empty(&t->node)) {
 		struct psi_trigger *tmp;
@@ -1143,19 +1107,36 @@
 		if (!group->nr_triggers[t->state])
 			group->trigger_states &= ~(1 << t->state);
 		/* reset min update period for the remaining triggers */
-		list_for_each_entry(tmp, &group->triggers, node) {
+		list_for_each_entry(tmp, &group->triggers, node)
 			period = min(period, div_u64(tmp->win.size,
 					UPDATES_PER_WINDOW));
+		group->poll_min_period = period;
+		/* Destroy poll_kworker when the last trigger is destroyed */
+		if (group->trigger_states == 0) {
+			group->polling_until = 0;
+			kworker_to_destroy = rcu_dereference_protected(
+					group->poll_kworker,
+					lockdep_is_held(&group->trigger_lock));
+			rcu_assign_pointer(group->poll_kworker, NULL);
 		}
-		group->trigger_min_period = period;
 	}
 
-	mutex_unlock(&group->update_lock);
+	mutex_unlock(&group->trigger_lock);
+
 	/*
-	 * Wait for RCU to complete its read-side critical section before
-	 * destroying the trigger
+	 * Wait for both *trigger_ptr from psi_trigger_replace and
+	 * poll_kworker RCUs to complete their read-side critical sections
+	 * before destroying the trigger and optionally the poll_kworker
 	 */
 	synchronize_rcu();
+	/*
+	 * Destroy the kworker after releasing trigger_lock to prevent a
+	 * deadlock while waiting for psi_poll_work to acquire trigger_lock
+	 */
+	if (kworker_to_destroy) {
+		kthread_cancel_delayed_work_sync(&group->poll_work);
+		kthread_destroy_worker(kworker_to_destroy);
+	}
 	kfree(t);
 }
 
@@ -1202,7 +1183,7 @@
 }
 
 static ssize_t psi_write(struct file *file, const char __user *user_buf,
-				size_t nbytes, enum psi_res res)
+			 size_t nbytes, enum psi_res res)
 {
 	char buf[32];
 	size_t buf_size;
@@ -1231,20 +1212,20 @@
 	return nbytes;
 }
 
-static ssize_t psi_io_write(struct file *file,
-		const char __user *user_buf, size_t nbytes, loff_t *ppos)
+static ssize_t psi_io_write(struct file *file, const char __user *user_buf,
+			    size_t nbytes, loff_t *ppos)
 {
 	return psi_write(file, user_buf, nbytes, PSI_IO);
 }
 
-static ssize_t psi_memory_write(struct file *file,
-		const char __user *user_buf, size_t nbytes, loff_t *ppos)
+static ssize_t psi_memory_write(struct file *file, const char __user *user_buf,
+				size_t nbytes, loff_t *ppos)
 {
 	return psi_write(file, user_buf, nbytes, PSI_MEM);
 }
 
-static ssize_t psi_cpu_write(struct file *file,
-		const char __user *user_buf, size_t nbytes, loff_t *ppos)
+static ssize_t psi_cpu_write(struct file *file, const char __user *user_buf,
+			     size_t nbytes, loff_t *ppos)
 {
 	return psi_write(file, user_buf, nbytes, PSI_CPU);
 }