Snap for 11916023 from 893827fb61d05dc918c06dbdfaec3b138aca0d2b to 24Q3-release

Change-Id: I50078f710d318905b90107a7431b0ab59fe51533
diff --git a/Android.bp b/Android.bp
index 9d70642..bce1523 100644
--- a/Android.bp
+++ b/Android.bp
@@ -54,7 +54,9 @@
     static_libs: [
         "libstatslogc",
         "liblmkd_utils",
+        "liburing",
     ],
+    include_dirs: ["bionic/libc/kernel"],
     header_libs: [
         "bpf_headers",
     ],
diff --git a/include/liblmkd_utils.h b/include/liblmkd_utils.h
index 1ca5d89..b219b24 100644
--- a/include/liblmkd_utils.h
+++ b/include/liblmkd_utils.h
@@ -40,6 +40,14 @@
 int lmkd_register_proc(int sock, struct lmk_procprio *params);
 
 /*
+ * Registers a batch of processes with lmkd and sets their oomadj scores.
+ * On success returns 0.
+ * On error, -1 is returned.
+ * In the case of error, errno is set appropriately.
+ */
+int lmkd_register_procs(int sock, struct lmk_procs_prio* params, const int proc_count);
+
+/*
  * Unregisters a process previously registered with lmkd.
  * On success returns 0.
  * On error, -1 is returned.
@@ -85,6 +93,19 @@
  */
 enum boot_completed_notification_result lmkd_notify_boot_completed(int sock);
 
+enum get_kill_count_err_result {
+    GET_KILL_COUNT_SEND_ERR = -1,
+    GET_KILL_COUNT_RECV_ERR = -2,
+    GET_KILL_COUNT_FORMAT_ERR = -3,
+};
+
+/*
+ * Get the number of kills LMKD has performed.
+ * On success returns number of kills.
+ * On error, returns a get_kill_count_err_result integer value.
+ */
+int lmkd_get_kill_count(int sock, struct lmk_getkillcnt* params);
+
 __END_DECLS
 
 #endif /* _LIBLMKD_UTILS_H_ */
diff --git a/include/lmkd.h b/include/lmkd.h
index f57548a..7922d8c 100644
--- a/include/lmkd.h
+++ b/include/lmkd.h
@@ -38,6 +38,7 @@
     LMK_STAT_KILL_OCCURRED, /* Unsolicited msg to subscribed clients on proc kills for statsd log */
     LMK_START_MONITORING,   /* Start psi monitoring if it was skipped earlier */
     LMK_BOOT_COMPLETED,     /* Notify LMKD boot is completed */
+    LMK_PROCS_PRIO,         /* Register processes and set the same oom_adj_score */
 };
 
 /*
@@ -108,6 +109,8 @@
     int oomadj;
     enum proc_type ptype;
 };
+#define LMK_PROCPRIO_FIELD_COUNT 4
+#define LMK_PROCPRIO_SIZE (LMK_PROCPRIO_FIELD_COUNT * sizeof(int))
 
 /*
  * For LMK_PROCPRIO packet get its payload.
@@ -324,6 +327,54 @@
     params->result = ntohl(packet[1]);
 }
 
+#define PROCS_PRIO_MAX_RECORD_COUNT (CTRL_PACKET_MAX_SIZE / LMK_PROCPRIO_SIZE)
+
+struct lmk_procs_prio {
+    struct lmk_procprio procs[PROCS_PRIO_MAX_RECORD_COUNT];
+};
+
+/*
+ * For LMK_PROCS_PRIO packet get its payload.
+ * Warning: only field_count is validated; caller should ensure other parameters are valid.
+ */
+static inline int lmkd_pack_get_procs_prio(LMKD_CTRL_PACKET packet, struct lmk_procs_prio* params,
+                                           const int field_count) {
+    if (field_count < LMK_PROCPRIO_FIELD_COUNT || (field_count % LMK_PROCPRIO_FIELD_COUNT) != 0)
+        return -1;
+    const int procs_count = (field_count / LMK_PROCPRIO_FIELD_COUNT);
+
+    /* Start packet at 1 since 0 is cmd type */
+    int packetIdx = 1;
+    for (int procs_idx = 0; procs_idx < procs_count; procs_idx++) {
+        params->procs[procs_idx].pid = (pid_t)ntohl(packet[packetIdx++]);
+        params->procs[procs_idx].uid = (uid_t)ntohl(packet[packetIdx++]);
+        params->procs[procs_idx].oomadj = ntohl(packet[packetIdx++]);
+        params->procs[procs_idx].ptype = (enum proc_type)ntohl(packet[packetIdx++]);
+    }
+
+    return procs_count;
+}
+
+/*
+ * Prepare LMK_PROCS_PRIO packet and return packet size in bytes.
+ * Warning: no checks performed, caller should ensure valid parameters.
+ */
+static inline size_t lmkd_pack_set_procs_prio(LMKD_CTRL_PACKET packet,
+                                              struct lmk_procs_prio* params,
+                                              const int procs_count) {
+    packet[0] = htonl(LMK_PROCS_PRIO);
+    int packetIdx = 1;
+
+    for (int i = 0; i < procs_count; i++) {
+        packet[packetIdx++] = htonl(params->procs[i].pid);
+        packet[packetIdx++] = htonl(params->procs[i].uid);
+        packet[packetIdx++] = htonl(params->procs[i].oomadj);
+        packet[packetIdx++] = htonl((int)params->procs[i].ptype);
+    }
+
+    return packetIdx * sizeof(int);
+}
+
 __END_DECLS
 
 #endif /* _LMKD_H_ */
diff --git a/liblmkd_utils.cpp b/liblmkd_utils.cpp
index f517004..4b383c6 100644
--- a/liblmkd_utils.cpp
+++ b/liblmkd_utils.cpp
@@ -43,6 +43,17 @@
     return (ret < 0) ? -1 : 0;
 }
 
+int lmkd_register_procs(int sock, struct lmk_procs_prio* params, const int proc_count) {
+    LMKD_CTRL_PACKET packet;
+    size_t size;
+    int ret;
+
+    size = lmkd_pack_set_procs_prio(packet, params, proc_count);
+    ret = TEMP_FAILURE_RETRY(write(sock, packet, size));
+
+    return (ret < 0) ? -1 : 0;
+}
+
 int lmkd_unregister_proc(int sock, struct lmk_procremove *params) {
     LMKD_CTRL_PACKET packet;
     size_t size;
@@ -118,6 +129,27 @@
     return res;
 }
 
+int lmkd_get_kill_count(int sock, struct lmk_getkillcnt* params) {
+    LMKD_CTRL_PACKET packet;
+    int size;
+
+    size = lmkd_pack_set_getkillcnt(packet, params);
+    if (TEMP_FAILURE_RETRY(write(sock, packet, size)) < 0) {
+        return (int)GET_KILL_COUNT_SEND_ERR;
+    }
+
+    size = TEMP_FAILURE_RETRY(read(sock, packet, CTRL_PACKET_MAX_SIZE));
+    if (size < 0) {
+        return (int)GET_KILL_COUNT_RECV_ERR;
+    }
+
+    if (size != 2 * sizeof(int) || lmkd_pack_get_cmd(packet) != LMK_GETKILLCNT) {
+        return (int)GET_KILL_COUNT_FORMAT_ERR;
+    }
+
+    return packet[1];
+}
+
 int create_memcg(uid_t uid, pid_t pid) {
     return createProcessGroup(uid, pid, true) == 0 ? 0 : -1;
 }
diff --git a/lmkd.cpp b/lmkd.cpp
index 06ae9ce..2a0f69b 100644
--- a/lmkd.cpp
+++ b/lmkd.cpp
@@ -40,10 +40,12 @@
 #include <shared_mutex>
 #include <vector>
 
+#include <bpf/KernelUtils.h>
 #include <bpf/WaitForProgsLoaded.h>
 #include <cutils/properties.h>
 #include <cutils/sockets.h>
 #include <liblmkd_utils.h>
+#include <liburing.h>
 #include <lmkd.h>
 #include <lmkd_hooks.h>
 #include <log/log.h>
@@ -206,6 +208,11 @@
 static struct timespec direct_reclaim_start_tm;
 static struct timespec kswapd_start_tm;
 
+/* io_uring for LMK_PROCS_PRIO */
+static struct io_uring lmk_io_uring_ring;
+/* IORING_OP_READ/WRITE opcodes were introduced only in kernel 5.6 */
+static const bool isIoUringSupported = android::bpf::isAtLeastKernelVersion(5, 6, 0);
+
 static int level_oomadj[VMPRESS_LEVEL_COUNT];
 static int mpevfd[VMPRESS_LEVEL_COUNT] = { -1, -1, -1 };
 static bool pidfd_supported;
@@ -1219,15 +1226,12 @@
     }
 }
 
-static void cmd_procprio(LMKD_CTRL_PACKET packet, int field_count, struct ucred* cred) {
+static void apply_proc_prio(const struct lmk_procprio& params, struct ucred* cred) {
     char path[PROCFS_PATH_MAX];
     char val[20];
-    struct lmk_procprio params;
     int64_t tgid;
     char buf[pagesize];
 
-    lmkd_pack_get_procprio(packet, field_count, &params);
-
     if (params.oomadj < OOM_SCORE_ADJ_MIN || params.oomadj > OOM_SCORE_ADJ_MAX) {
         ALOGE("Invalid PROCPRIO oomadj argument %d", params.oomadj);
         return;
@@ -1268,6 +1272,13 @@
     register_oom_adj_proc(params, cred);
 }
 
+static void cmd_procprio(LMKD_CTRL_PACKET packet, int field_count, struct ucred* cred) {
+    struct lmk_procprio proc_prio;
+
+    lmkd_pack_get_procprio(packet, field_count, &proc_prio);
+    apply_proc_prio(proc_prio, cred);
+}
+
 static void cmd_procremove(LMKD_CTRL_PACKET packet, struct ucred *cred) {
     struct lmk_procremove params;
     struct proc *procp;
@@ -1475,6 +1486,196 @@
     }
 }
 
+static void handle_io_uring_procs_prio(const struct lmk_procs_prio& params, const int procs_count,
+                                       struct ucred* cred) {
+    struct io_uring_sqe* sqe;
+    struct io_uring_cqe* cqe;
+    int fds[PROCS_PRIO_MAX_RECORD_COUNT];
+    char buffers[PROCS_PRIO_MAX_RECORD_COUNT]
+                [256]; /* For reading /proc/<pid>/status and writing /proc/<pid>/oom_score_adj */
+    char path[PROCFS_PATH_MAX];
+    char val[20];
+    int64_t tgid;
+    int ret;
+    int num_requests = 0;
+
+    ret = io_uring_queue_init(PROCS_PRIO_MAX_RECORD_COUNT, &lmk_io_uring_ring, 0);
+    if (ret) {
+        ALOGE("LMK_PROCS_PRIO failed to setup io_uring ring: %s", strerror(-ret));
+        return;
+    }
+
+    std::fill_n(fds, PROCS_PRIO_MAX_RECORD_COUNT, -1);
+    for (int i = 0; i < procs_count; i++) {
+        if (params.procs[i].oomadj < OOM_SCORE_ADJ_MIN ||
+            params.procs[i].oomadj > OOM_SCORE_ADJ_MAX)
+            ALOGW("Skipping invalid PROCS_PRIO oomadj=%d for pid=%d", params.procs[i].oomadj,
+                  params.procs[i].pid);
+        else if (params.procs[i].ptype < PROC_TYPE_FIRST ||
+                 params.procs[i].ptype >= PROC_TYPE_COUNT)
+            ALOGW("Skipping invalid PROCS_PRIO pid=%d for invalid process type arg %d",
+                  params.procs[i].pid, params.procs[i].ptype);
+        else {
+            snprintf(path, PROCFS_PATH_MAX, "/proc/%d/status", params.procs[i].pid);
+            fds[i] = open(path, O_RDONLY | O_CLOEXEC);
+            if (fds[i] < 0) continue;
+
+            sqe = io_uring_get_sqe(&lmk_io_uring_ring);
+            if (!sqe) {
+                ALOGE("LMK_PROCS_PRIO skipping pid (%d), failed to get SQE for read proc status",
+                      params.procs[i].pid);
+                close(fds[i]);
+                fds[i] = -1;
+                continue;
+            }
+
+            io_uring_prep_read(sqe, fds[i], &buffers[i], sizeof(buffers[i]), 0);
+            sqe->user_data = i;
+            num_requests++;
+        }
+    }
+
+    if (num_requests == 0) {
+        ALOGW("LMK_PROCS_PRIO has no read proc status requests to process");
+        goto err;
+    }
+
+    ret = io_uring_submit(&lmk_io_uring_ring);
+    if (ret <= 0 || ret != num_requests) {
+        ALOGE("Error submitting read processes' status to SQE: %s", strerror(ret));
+        goto err;
+    }
+
+    for (int i = 0; i < num_requests; i++) {
+        ret = TEMP_FAILURE_RETRY(io_uring_wait_cqe(&lmk_io_uring_ring, &cqe));
+        if (ret < 0 || !cqe) {
+            ALOGE("Failed to get CQE, in LMK_PROCS_PRIO, for read batching: %s", strerror(-ret));
+            goto err;
+        }
+        if (cqe->res < 0) {
+            ALOGE("Error in LMK_PROCS_PRIO for async proc status read operation: %s",
+                  strerror(-cqe->res));
+            continue;
+        }
+        if (cqe->user_data < 0 || static_cast<int>(cqe->user_data) > procs_count) {
+            ALOGE("Invalid LMK_PROCS_PRIO CQE read data: %llu", cqe->user_data);
+            continue;
+        }
+
+        const int procs_idx = cqe->user_data;
+        close(fds[procs_idx]);
+        fds[procs_idx] = -1;
+        io_uring_cqe_seen(&lmk_io_uring_ring, cqe);
+
+        if (parse_status_tag(buffers[procs_idx], PROC_STATUS_TGID_FIELD, &tgid) &&
+            tgid != params.procs[procs_idx].pid) {
+            ALOGE("Attempt to register a task that is not a thread group leader "
+                  "(tid %d, tgid %" PRId64 ")",
+                  params.procs[procs_idx].pid, tgid);
+            continue;
+        }
+
+        /* Open write file to prepare for write batch */
+        snprintf(path, sizeof(path), "/proc/%d/oom_score_adj", params.procs[procs_idx].pid);
+        fds[procs_idx] = open(path, O_WRONLY | O_CLOEXEC);
+        if (fds[procs_idx] < 0) {
+            ALOGW("Failed to open %s; errno=%d: process %d might have been killed, skipping for "
+                  "LMK_PROCS_PRIO",
+                  path, errno, params.procs[procs_idx].pid);
+            continue;
+        }
+    }
+
+    /* Prepare to write the new OOM score */
+    num_requests = 0;
+    for (int i = 0; i < procs_count; i++) {
+        if (fds[i] < 0) continue;
+
+        /* gid containing AID_READPROC required */
+        /* CAP_SYS_RESOURCE required */
+        /* CAP_DAC_OVERRIDE required */
+        snprintf(buffers[i], sizeof(buffers[i]), "%d", params.procs[i].oomadj);
+        sqe = io_uring_get_sqe(&lmk_io_uring_ring);
+        if (!sqe) {
+            ALOGE("LMK_PROCS_PRIO skipping pid (%d), failed to get SQE for write",
+                  params.procs[i].pid);
+            close(fds[i]);
+            fds[i] = -1;
+            continue;
+        }
+        io_uring_prep_write(sqe, fds[i], &buffers[i], sizeof(buffers[i]), 0);
+        sqe->user_data = i;
+        num_requests++;
+    }
+
+    if (num_requests == 0) {
+        ALOGW("LMK_PROCS_PRIO has no write proc oomadj requests to process");
+        goto err;
+    }
+
+    ret = io_uring_submit(&lmk_io_uring_ring);
+    if (ret <= 0 || ret != num_requests) {
+        ALOGE("Error submitting write data to sqe: %s", strerror(ret));
+        goto err;
+    }
+
+    /* Handle async write completions for proc/<pid>/oom_score_adj */
+    for (int i = 0; i < num_requests; i++) {
+        ret = TEMP_FAILURE_RETRY(io_uring_wait_cqe(&lmk_io_uring_ring, &cqe));
+        if (ret < 0 || !cqe) {
+            ALOGE("Failed to get CQE, in LMK_PROCS_PRIO, for write batching: %s", strerror(-ret));
+            goto err;
+        }
+        if (cqe->res < 0) {
+            ALOGE("Error in LMK_PROCS_PRIO for async proc status read operation: %s",
+                  strerror(-cqe->res));
+            continue;
+        }
+        if (cqe->user_data < 0 || static_cast<int>(cqe->user_data) > procs_count) {
+            ALOGE("Invalid LMK_PROCS_PRIO CQE read data: %llu", cqe->user_data);
+            continue;
+        }
+
+        const int procs_idx = cqe->user_data;
+        close(fds[procs_idx]);
+        fds[procs_idx] = -1;
+        io_uring_cqe_seen(&lmk_io_uring_ring, cqe);
+
+        if (use_inkernel_interface) {
+            stats_store_taskname(params.procs[procs_idx].pid,
+                                 proc_get_name(params.procs[procs_idx].pid, path, sizeof(path)));
+            continue;
+        }
+
+        register_oom_adj_proc(params.procs[procs_idx], cred);
+    }
+
+    io_uring_queue_exit(&lmk_io_uring_ring);
+    return;
+
+err:
+    for (int fd : fds)
+        if (fd >= 0) close(fd);
+    io_uring_queue_exit(&lmk_io_uring_ring);
+    return;
+}
+
+static void cmd_procs_prio(LMKD_CTRL_PACKET packet, const int field_count, struct ucred* cred) {
+    struct lmk_procs_prio params;
+
+    const int procs_count = lmkd_pack_get_procs_prio(packet, &params, field_count);
+    if (procs_count < 0) {
+        ALOGE("LMK_PROCS_PRIO received invalid packet format");
+        return;
+    }
+
+    if (isIoUringSupported) {
+        handle_io_uring_procs_prio(params, procs_count, cred);
+    } else {
+        for (int i = 0; i < procs_count; i++) apply_proc_prio(params.procs[i], cred);
+    }
+}
+
 static void ctrl_command_handler(int dsock_idx) {
     LMKD_CTRL_PACKET packet;
     struct ucred cred;
@@ -1623,6 +1824,9 @@
             ALOGE("Failed to report boot-completed operation results");
         }
         break;
+    case LMK_PROCS_PRIO:
+        cmd_procs_prio(packet, nargs, &cred);
+        break;
     default:
         ALOGE("Received unknown command code %d", cmd);
         return;
@@ -2919,8 +3123,7 @@
     } else if (reclaim == DIRECT_RECLAIM && direct_reclaim_threshold_ms > 0 &&
                direct_reclaim_duration_ms > direct_reclaim_threshold_ms) {
         kill_reason = DIRECT_RECL_STUCK;
-        snprintf(kill_desc, sizeof(kill_desc),
-                 "device is stuck in direct reclaim (%" PRId64 "ms > %dms)",
+        snprintf(kill_desc, sizeof(kill_desc), "device is stuck in direct reclaim (%ldms > %dms)",
                  direct_reclaim_duration_ms, direct_reclaim_threshold_ms);
     } else if (check_filecache) {
         int64_t file_lru_kb = (vs.field.nr_inactive_file + vs.field.nr_active_file) * page_k;
diff --git a/tests/lmkd_tests.cpp b/tests/lmkd_tests.cpp
index 9ad3d3b..9b70d38 100644
--- a/tests/lmkd_tests.cpp
+++ b/tests/lmkd_tests.cpp
@@ -26,6 +26,7 @@
 #include <liblmkd_utils.h>
 #include <log/log_properties.h>
 #include <private/android_filesystem_config.h>
+#include <stdlib.h>
 
 using namespace android::base;
 
@@ -113,6 +114,16 @@
         }
     }
 
+    void SendProcsPrioRequest(struct lmk_procs_prio procs_prio_request, int procs_count) {
+        ASSERT_FALSE(lmkd_register_procs(sock, &procs_prio_request, procs_count) < 0)
+                << "Failed to communicate with lmkd, err=" << strerror(errno);
+    }
+
+    void SendGetKillCountRequest(struct lmk_getkillcnt* get_kill_cnt_request) {
+        ASSERT_GE(lmkd_get_kill_count(sock, get_kill_cnt_request), 0)
+                << "Failed fetching lmkd kill count";
+    }
+
     static std::string ExecCommand(const std::string& command) {
         FILE* fp = popen(command.c_str(), "r");
         std::string content;
@@ -170,6 +181,8 @@
                reap_pid == pid;
     }
 
+    uid_t getLmkdTestUid() const { return uid; }
+
   private:
     int sock;
     uid_t uid;
@@ -239,6 +252,88 @@
    }
 }
 
+/*
+ * Verify that the `PROCS_PRIO` cmd is able to receive a batch of processes and adjust
+ * those processes' OOM scores.
+ */
+TEST_F(LmkdTest, batch_procs_oom_score_adj) {
+    struct ChildProcessInfo {
+        pid_t pid;
+        int original_oom_score;
+        int req_new_oom_score;
+    };
+
+    struct ChildProcessInfo children_info[PROCS_PRIO_MAX_RECORD_COUNT];
+
+    for (unsigned int i = 0; i < PROCS_PRIO_MAX_RECORD_COUNT; i++) {
+        children_info[i].pid = fork();
+        if (children_info[i].pid < 0) {
+            for (const auto child : children_info)
+                if (child.pid >= 0) kill(child.pid, SIGKILL);
+            FAIL() << "Failed forking process in iteration=" << i;
+        } else if (children_info[i].pid == 0) {
+            /*
+             * Keep the children alive; the parent process will kill them
+             * once we are done with them.
+             */
+            while (true) {
+                sleep(20);
+            }
+        }
+    }
+
+    struct lmk_procs_prio procs_prio_request;
+    const uid_t parent_uid = getLmkdTestUid();
+
+    for (unsigned int i = 0; i < PROCS_PRIO_MAX_RECORD_COUNT; i++) {
+        if (children_info[i].pid < 0) continue;
+
+        const std::string process_oom_path =
+                "proc/" + std::to_string(children_info[i].pid) + "/oom_score_adj";
+        std::string curr_oom_score;
+        if (!ReadFileToString(process_oom_path, &curr_oom_score) || curr_oom_score.empty()) {
+            for (const auto child : children_info)
+                if (child.pid >= 0) kill(child.pid, SIGKILL);
+            FAIL() << "Failed reading original oom score for child process: "
+                   << children_info[i].pid;
+        }
+
+        children_info[i].original_oom_score = atoi(curr_oom_score.c_str());
+        children_info[i].req_new_oom_score =
+                ((unsigned int)children_info[i].original_oom_score != i) ? i : (i + 10);
+        procs_prio_request.procs[i] = {.pid = children_info[i].pid,
+                                       .uid = parent_uid,
+                                       .oomadj = children_info[i].req_new_oom_score,
+                                       .ptype = proc_type::PROC_TYPE_APP};
+    }
+
+    /*
+     * Submit batching, then send a new/different request and wait for LMKD
+     * to respond to it. This ensures that LMKD has finished the batching
+     * request and we can now read/validate the new OOM scores.
+     */
+    SendProcsPrioRequest(procs_prio_request, PROCS_PRIO_MAX_RECORD_COUNT);
+    struct lmk_getkillcnt kill_cnt_req = {.min_oomadj = -1000, .max_oomadj = 1000};
+    SendGetKillCountRequest(&kill_cnt_req);
+
+    for (auto child_info : children_info) {
+        if (child_info.pid < 0) continue;
+        const std::string process_oom_path =
+                "proc/" + std::to_string(child_info.pid) + "/oom_score_adj";
+        std::string curr_oom_score;
+        if (!ReadFileToString(process_oom_path, &curr_oom_score) || curr_oom_score.empty()) {
+            for (const auto child : children_info)
+                if (child.pid >= 0) kill(child.pid, SIGKILL);
+            FAIL() << "Failed reading new oom score for child process: " << child_info.pid;
+        }
+        kill(child_info.pid, SIGKILL);
+
+        const int actual_new_oom_score = atoi(curr_oom_score.c_str());
+        ASSERT_EQ(child_info.req_new_oom_score, actual_new_oom_score)
+                << "Child with pid=" << child_info.pid << " didn't update its OOM score";
+    }
+}
+
 int main(int argc, char** argv) {
     ::testing::InitGoogleTest(&argc, argv);
     InitLogging(argv, StderrLogger);