Persist rate limiter data

- Add proto for rate limiter data
- Add read/write logic and hook up
- Add a test to verify persisting and restoring

Test: new test
Bug: 334086203
Change-Id: Ic567853f7bdc4fdc5e31476c47da2fb740c0e11b
diff --git a/service/Android.bp b/service/Android.bp
index cf33f6b..5baf05b 100644
--- a/service/Android.bp
+++ b/service/Android.bp
@@ -36,6 +36,7 @@
     static_libs: [
         "modules-utils-build",
         "android.os.profiling.flags-aconfig-java",
+        "service-profiling-proto",
     ],
     lint: {
         strict_updatability_linting: true,
@@ -46,3 +47,27 @@
     min_sdk_version: "current",
     installable: true,
 }
+
+java_library {
+    name: "service-profiling-proto",
+    proto: {
+        type: "lite",
+        canonical_path_from_root: false,
+        include_dirs: [
+            "external/protobuf/src",
+            "external/protobuf/java",
+        ],
+    },
+    srcs: [
+        "proto/**/*.proto",
+    ],
+    visibility: [
+        "//packages/modules/Profiling/tests:__subpackages__",
+    ],
+    installable: false,
+    min_sdk_version: "current",
+    sdk_version: "system_server_current",
+    apex_available: [
+        "com.android.profiling",
+    ],
+}
diff --git a/service/java/com/android/os/profiling/ProfilingService.java b/service/java/com/android/os/profiling/ProfilingService.java
index 955293b..89adf75 100644
--- a/service/java/com/android/os/profiling/ProfilingService.java
+++ b/service/java/com/android/os/profiling/ProfilingService.java
@@ -611,7 +611,12 @@
 
     private RateLimiter getRateLimiter() {
         if (mRateLimiter == null) {
-            mRateLimiter = new RateLimiter(mContext);
+            mRateLimiter = new RateLimiter(new RateLimiter.HandlerCallback() {
+                @Override
+                public Handler obtainHandler() {
+                    return getHandler();
+                }
+            });
         }
         return mRateLimiter;
     }
diff --git a/service/java/com/android/os/profiling/RateLimiter.java b/service/java/com/android/os/profiling/RateLimiter.java
index bc25e6b..1c27a2c 100644
--- a/service/java/com/android/os/profiling/RateLimiter.java
+++ b/service/java/com/android/os/profiling/RateLimiter.java
@@ -18,24 +18,38 @@
 
 import android.annotation.IntDef;
 import android.annotation.Nullable;
-import android.content.Context;
 import android.os.Bundle;
+import android.os.Environment;
+import android.os.Handler;
 import android.os.ProfilingManager;
 import android.os.ProfilingResult;
+import android.os.RateLimiterRecordsWrapper;
+import android.util.AtomicFile;
+import android.util.Log;
 import android.util.SparseIntArray;
 
 import com.android.internal.annotations.GuardedBy;
+import com.android.internal.annotations.VisibleForTesting;
 
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
 import java.util.ArrayDeque;
 import java.util.Queue;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 public class RateLimiter {
+    private static final String TAG = RateLimiter.class.getSimpleName();
+    private static final boolean DEBUG = false;
 
-    private static final long TIME_1_HOUR_MS = 60 * 60 * 1000;
-    private static final long TIME_24_HOUR_MS = 24 * 60 * 60 * 1000;
-    private static final long TIME_7_DAY_MS = 7 * 24 * 60 * 60 * 1000;
+    private static final String RATE_LIMITER_STORE_DIR = "profiling_rate_limiter_store";
+    private static final String RATE_LIMITER_INFO_FILE = "profiling_rate_limiter_info";
+
+    private static final long TIME_HOUR_MS = 60 * 60 * 1000;
+    private static final long TIME_DAY_MS = 24 * 60 * 60 * 1000;
+    private static final long TIME_WEEK_MS = 7 * 24 * 60 * 60 * 1000;
 
     private static final int DEFAULT_MAX_COST_SYSTEM_HOUR = 2;
     private static final int DEFAULT_MAX_COST_PROCESS_HOUR = 1;
@@ -51,7 +65,6 @@
 
     private final Object mLock = new Object();
 
-    private final Context mContext;
     private final long mPersistToDiskFrequency;
 
     /** To be disabled for testing only. */
@@ -59,19 +72,42 @@
     private boolean mRateLimiterDisabled = false;
 
     /** Collection of run costs and entries from the last hour. */
-    private final EntryGroupWrapper mPastRuns1Hour;
-    /** Collection of run costs and entries from the 24 hours. */
-    private final EntryGroupWrapper mPastRuns24Hour;
-    /** Collection of run costs and entries from the 7 days. */
-    private final EntryGroupWrapper mPastRuns7Day;
+    @VisibleForTesting
+    public final EntryGroupWrapper mPastRunsHour;
+
+    /** Collection of run costs and entries from the last day. */
+    @VisibleForTesting
+    public final EntryGroupWrapper mPastRunsDay;
+
+    /** Collection of run costs and entries from the last week. */
+    @VisibleForTesting
+    public final EntryGroupWrapper mPastRunsWeek;
 
     private final int mCostJavaHeapDump;
     private final int mCostHeapProfile;
     private final int mCostStackSampling;
     private final int mCostSystemTrace;
 
+    private final HandlerCallback mHandlerCallback;
+
+    private Runnable mPersistRunnable = null;
+    private boolean mPersistScheduled = false;
+
     private long mLastPersistedTimestampMs;
 
+    /**
+     * The path to the directory which includes the historical rate limiter data file as specified
+     * in {@link #mPersistFile}.
+     */
+    @VisibleForTesting
+    public File mPersistStoreDir;
+
+    /** The historical rate limiter data file, persisted in the storage. */
+    @VisibleForTesting
+    public File mPersistFile;
+
+    private AtomicBoolean mDataLoaded = new AtomicBoolean();
+
     @IntDef(value = {
         RATE_LIMIT_RESULT_ALLOWED,
         RATE_LIMIT_RESULT_BLOCKED_PROCESS,
@@ -80,27 +116,31 @@
     @Retention(RetentionPolicy.SOURCE)
     @interface RateLimitResult {}
 
-    public RateLimiter(Context context) {
-        mContext = context;
+    /**
+     * @param handlerCallback Callback for rate limiter to obtain a {@link Handler} to schedule
+     *                        work such as persisting to storage.
+     */
+    public RateLimiter(HandlerCallback handlerCallback) {
+        mHandlerCallback = handlerCallback;
 
-        mPastRuns1Hour = new EntryGroupWrapper(
+        mPastRunsHour = new EntryGroupWrapper(
                 DeviceConfigHelper.getInt(DeviceConfigHelper.MAX_COST_SYSTEM_1_HOUR,
                         DEFAULT_MAX_COST_SYSTEM_HOUR),
                 DeviceConfigHelper.getInt(DeviceConfigHelper.MAX_COST_PROCESS_1_HOUR,
                         DEFAULT_MAX_COST_PROCESS_HOUR),
-                TIME_1_HOUR_MS);
-        mPastRuns24Hour = new EntryGroupWrapper(
+                TIME_HOUR_MS);
+        mPastRunsDay = new EntryGroupWrapper(
                 DeviceConfigHelper.getInt(DeviceConfigHelper.MAX_COST_SYSTEM_24_HOUR,
                         DEFAULT_MAX_COST_SYSTEM_DAY),
                 DeviceConfigHelper.getInt(DeviceConfigHelper.MAX_COST_PROCESS_24_HOUR,
                         DEFAULT_MAX_COST_PROCESS_DAY),
-                TIME_24_HOUR_MS);
-        mPastRuns7Day = new EntryGroupWrapper(
+                TIME_DAY_MS);
+        mPastRunsWeek = new EntryGroupWrapper(
                 DeviceConfigHelper.getInt(DeviceConfigHelper.MAX_COST_SYSTEM_7_DAY,
                         DEFAULT_MAX_COST_SYSTEM_WEEK),
                 DeviceConfigHelper.getInt(DeviceConfigHelper.MAX_COST_PROCESS_7_DAY,
                         DEFAULT_MAX_COST_PROCESS_WEEK),
-                TIME_7_DAY_MS);
+                TIME_WEEK_MS);
 
         mCostJavaHeapDump = DeviceConfigHelper.getInt(DeviceConfigHelper.COST_JAVA_HEAP_DUMP,
                 DEFAULT_COST_PER_SESSION);
@@ -115,14 +155,49 @@
                 DeviceConfigHelper.PERSIST_TO_DISK_FREQUENCY_MS, 0);
         mLastPersistedTimestampMs = System.currentTimeMillis();
 
-        loadFromDisk();
-
         // Get initial value for whether rate limiter should be enforcing or if it should always
         // allow profiling requests. This is used for (automated and manual) testing only.
         synchronized (mLock) {
             mRateLimiterDisabled = DeviceConfigHelper.getTestBoolean(
                     DeviceConfigHelper.RATE_LIMITER_DISABLE_PROPERTY, false);
         }
+
+        try {
+            if (setupPersistFiles()) {
+                loadFromDisk();
+            } else {
+                // Directory doesn't exist so file must not exist. Nothing to load.
+                // Mark complete and return.
+                if (DEBUG) {
+                    Log.d(TAG, "Persist file directory does not exist, skipping load from disk.");
+                }
+                mDataLoaded.set(true);
+            }
+        } catch (SecurityException e) {
+            if (DEBUG) Log.d(TAG, "Exception creating directory.", e);
+        } finally {
+            if (!mDataLoaded.get()) {
+                // Loading of persisted data failed for a reason other than the file not existing.
+                // Delete the file, pad history with fake entries to reduce availability,
+                // and then mark complete.
+                if (mPersistFile != null) {
+                    try {
+                        mPersistFile.delete();
+                        if (DEBUG) Log.d(TAG, "Deleted persist file which could not be parsed.");
+                    } catch (SecurityException e) {
+                        if (DEBUG) Log.d(TAG, "Failed to delete persist file", e);
+                    }
+                }
+
+                // TODO: b/335542725 - revisit how to deal with failed load case
+                final long timestamp = System.currentTimeMillis();
+                mPastRunsHour.add(-1 /*fake uid*/, mPastRunsHour.mMaxCost / 2, timestamp);
+                mPastRunsDay.add(-1 /*fake uid*/, mPastRunsDay.mMaxCost / 2, timestamp);
+                mPastRunsWeek.add(-1 /*fake uid*/, mPastRunsWeek.mMaxCost / 2, timestamp);
+
+                mDataLoaded.set(true);
+            }
+        }
     }
 
     public @RateLimitResult int isProfilingRequestAllowed(int uid,
@@ -130,21 +205,27 @@
         synchronized (mLock) {
             if (mRateLimiterDisabled) {
                 // Rate limiter is disabled for testing, approve request and don't store cost.
+                Log.w(TAG, "Rate limiter disabled, request allowed.");
                 return RATE_LIMIT_RESULT_ALLOWED;
             }
+            if (!mDataLoaded.get()) {
+                // Requests before rate limiter data are all rejected.
+                Log.e(TAG, "Data loading in progress, request denied.");
+                return RATE_LIMIT_RESULT_BLOCKED_SYSTEM;
+            }
             final int cost = getCostForProfiling(profilingType);
             final long currentTimeMillis = System.currentTimeMillis();
-            int status = mPastRuns1Hour.isProfilingAllowed(uid, cost, currentTimeMillis);
+            int status = mPastRunsHour.isProfilingAllowed(uid, cost, currentTimeMillis);
             if (status == RATE_LIMIT_RESULT_ALLOWED) {
-                status = mPastRuns24Hour.isProfilingAllowed(uid, cost, currentTimeMillis);
+                status = mPastRunsDay.isProfilingAllowed(uid, cost, currentTimeMillis);
             }
             if (status == RATE_LIMIT_RESULT_ALLOWED) {
-                status = mPastRuns7Day.isProfilingAllowed(uid, cost, currentTimeMillis);
+                status = mPastRunsWeek.isProfilingAllowed(uid, cost, currentTimeMillis);
             }
             if (status == RATE_LIMIT_RESULT_ALLOWED) {
-                mPastRuns1Hour.add(uid, cost, currentTimeMillis);
-                mPastRuns24Hour.add(uid, cost, currentTimeMillis);
-                mPastRuns7Day.add(uid, cost, currentTimeMillis);
+                mPastRunsHour.add(uid, cost, currentTimeMillis);
+                mPastRunsDay.add(uid, cost, currentTimeMillis);
+                mPastRunsWeek.add(uid, cost, currentTimeMillis);
                 maybePersistToDisk();
                 return RATE_LIMIT_RESULT_ALLOWED;
             }
@@ -167,22 +248,159 @@
         }
     }
 
+    /**
+     * This method is meant to be called every time a profiling record is added to the history.
+     * - If persist frequency is set to 0, it will immediately persist the records to disk.
+     * - If a persist is already scheduled, it will do nothing.
+     * - If the last records persist occurred longer ago than the persist frequency, it will
+     *      persist immediately.
+     * - In all other cases, it will schedule a persist event at persist frequency after the last
+     *      persist event.
+     */
     void maybePersistToDisk() {
+        if (mPersistScheduled) {
+            // We're already waiting on a scheduled persist job, do nothing.
+            return;
+        }
+
         if (mPersistToDiskFrequency == 0
-                || System.currentTimeMillis() - mLastPersistedTimestampMs
-                >= mPersistToDiskFrequency) {
+                || (System.currentTimeMillis() - mLastPersistedTimestampMs
+                        >= mPersistToDiskFrequency)) {
+            // If persist frequency is 0 or if it's already been longer than persist frequency since
+            // the last persist then persist immediately.
             persistToDisk();
         } else {
-            // TODO: queue persist job b/293957254
+            // Schedule the persist job.
+            if (mPersistRunnable == null) {
+                mPersistRunnable = new Runnable() {
+                    @Override
+                    public void run() {
+                        persistToDisk();
+                        mPersistScheduled = false;
+                    }
+                };
+            }
+            mPersistScheduled = true;
+            long persistDelay = mLastPersistedTimestampMs + mPersistToDiskFrequency
+                    - System.currentTimeMillis();
+            mHandlerCallback.obtainHandler().postDelayed(mPersistRunnable, persistDelay);
         }
     }
 
-    void persistToDisk() {
-        // TODO: b/293957254
+    /**
+     * Clean up records and persist to disk.
+     *
+     * Skips if {@link mPersistFile} is not accessible to write to.
+     */
+    public void persistToDisk() {
+        // Check if file exists
+        try {
+            if (mPersistFile == null) {
+                // Try again to create the necessary files.
+                if (!setupPersistFiles()) {
+                    // No file, nowhere to save.
+                    if (DEBUG) Log.d(TAG, "Failed setting up persist files so nowhere to save to.");
+                    return;
+                }
+            }
+
+            if (!mPersistFile.exists()) {
+                // File doesn't exist, try to create it.
+                mPersistFile.createNewFile();
+            }
+        } catch (Exception e) {
+            if (DEBUG) Log.d(TAG, "Exception accessing persisted records", e);
+            return;
+        }
+
+        // Clean up old records to reduce extraneous writes
+        mPastRunsWeek.cleanUpOldRecords();
+
+        // Generate proto for records. We only persist week records as this contains all smaller
+        // time ranges.
+        RateLimiterRecordsWrapper outerWrapper = RateLimiterRecordsWrapper.newBuilder()
+                .setRecords(mPastRunsWeek.toProto())
+                .build();
+
+        // Write to disk
+        byte[] protoBytes = outerWrapper.toByteArray();
+        AtomicFile persistFile = new AtomicFile(mPersistFile);
+        FileOutputStream out = null;
+        try {
+            out = persistFile.startWrite();
+            out.write(protoBytes);
+            persistFile.finishWrite(out);
+        } catch (IOException e) {
+            if (DEBUG) Log.d(TAG, "Exception writing records", e);
+            persistFile.failWrite(out);
+        }
     }
 
-    void loadFromDisk() {
-        // TODO: b/293957254
+    /**
+     * Load initial records data from disk.
+     *
+     * If the file doesn't exist or if it hits an error loading then it marks complete and exits.
+     */
+    public void loadFromDisk() {
+        // Check if file exists
+        try {
+            if (mPersistFile == null || !mPersistFile.exists()) {
+                // No file, nothing to load. Mark complete and return.
+                if (DEBUG) Log.d(TAG, "Persist file does not exist, skipping load from disk.");
+                mDataLoaded.set(true);
+                return;
+            }
+        } catch (SecurityException e) {
+            // Can't access file.
+            if (DEBUG) Log.d(TAG, "Exception accessing persist file", e);
+            return;
+        }
+
+        // Read the file
+        AtomicFile persistFile = new AtomicFile(mPersistFile);
+        byte[] bytes = null;
+        try {
+            bytes = persistFile.readFully();
+        } catch (IOException e) {
+            if (DEBUG) Log.d(TAG, "Exception reading persist file", e);
+        }
+        if (bytes == null) {
+            // Failed to read file.
+            if (DEBUG) Log.d(TAG, "Persist file loaded bytes empty.");
+            return;
+        }
+
+        // Parse file bytes to proto
+        RateLimiterRecordsWrapper outerWrapper;
+        try {
+            outerWrapper = RateLimiterRecordsWrapper.parseFrom(bytes);
+        } catch (Exception e) {
+            // Failed to parse.
+            if (DEBUG) Log.d(TAG, "Error parsing proto from persisted bytes", e);
+            return;
+        }
+
+        // Populate in memory records stores
+        RateLimiterRecordsWrapper.EntryGroupWrapper weekGroupWrapper =
+                outerWrapper.getRecords();
+        final long currentTimeMillis = System.currentTimeMillis();
+        for (int i = 0; i < weekGroupWrapper.getEntriesCount(); i++) {
+            RateLimiterRecordsWrapper.EntryGroupWrapper.Entry entry =
+                    weekGroupWrapper.getEntries(i);
+            // Check if this timestamp fits the time range for each records collection.
+            if (entry.getTimestamp() > currentTimeMillis - mPastRunsHour.mTimeRangeMs) {
+                mPastRunsHour.add(entry.getUid(), entry.getCost(), entry.getTimestamp());
+            }
+            if (entry.getTimestamp() > currentTimeMillis - mPastRunsDay.mTimeRangeMs) {
+                mPastRunsDay.add(entry.getUid(), entry.getCost(), entry.getTimestamp());
+            }
+            if (entry.getTimestamp() > currentTimeMillis - mPastRunsWeek.mTimeRangeMs) {
+                mPastRunsWeek.add(entry.getUid(), entry.getCost(), entry.getTimestamp());
+            }
+        }
+
+        // Set loaded to api usage can start
+        mDataLoaded.set(true);
     }
 
     /** Update the disable rate limiter flag. */
@@ -203,8 +421,35 @@
         }
     }
 
-    static final class EntryGroupWrapper {
+    private boolean setupPersistFiles() throws SecurityException {
+        File dataDir = Environment.getDataDirectory();
+        File systemDir = new File(dataDir, "system");
+        mPersistStoreDir = new File(systemDir, RATE_LIMITER_STORE_DIR);
+        if (createDir(mPersistStoreDir)) {
+            mPersistFile = new File(mPersistStoreDir, RATE_LIMITER_INFO_FILE);
+            return true;
+        }
+        return false;
+    }
+
+    private static boolean createDir(File dir) throws SecurityException {
+        if (dir.mkdir()) {
+            return true;
+        }
+
+        if (dir.exists()) {
+            return dir.isDirectory();
+        }
+
+        return false;
+    }
+
+    public static final class EntryGroupWrapper {
+        private final Object mLock = new Object();
+
+        @GuardedBy("mLock")
         final Queue<CollectionEntry> mEntries;
+
         int mTotalCost;
         // uid indexed
         final SparseIntArray mPerUidCost;
@@ -213,22 +458,34 @@
         final long mTimeRangeMs;
 
         EntryGroupWrapper(final int maxCost, final int maxPerUidCost, final long timeRangeMs) {
-            mMaxCost = maxCost;
-            mMaxCostPerUid = maxPerUidCost;
-            mTimeRangeMs = timeRangeMs;
-            mEntries = new ArrayDeque<>();
-            mPerUidCost = new SparseIntArray();
+            synchronized (mLock) {
+                mMaxCost = maxCost;
+                mMaxCostPerUid = maxPerUidCost;
+                mTimeRangeMs = timeRangeMs;
+                mEntries = new ArrayDeque<>();
+                mPerUidCost = new SparseIntArray();
+            }
         }
 
-        void add(final int uid, final int cost, final long timestamp) {
-            mTotalCost += cost;
-            final int index = mPerUidCost.indexOfKey(uid);
-            if (index < 0) {
-                mPerUidCost.put(uid, cost);
-            } else {
-                mPerUidCost.put(uid, mPerUidCost.valueAt(index) + cost);
+        /** Add a record and update cached costs accordingly. */
+        public void add(final int uid, final int cost, final long timestamp) {
+            synchronized (mLock) {
+                mTotalCost += cost;
+                final int index = mPerUidCost.indexOfKey(uid);
+                if (index < 0) {
+                    mPerUidCost.put(uid, cost);
+                } else {
+                    mPerUidCost.put(uid, mPerUidCost.valueAt(index) + cost);
+                }
+                mEntries.offer(new CollectionEntry(uid, cost, timestamp));
             }
-            mEntries.offer(new CollectionEntry(uid, cost, timestamp));
+        }
+
+        /**
+         * Clean up the queue by removing entries that are too old.
+         */
+        public void cleanUpOldRecords() {
+            removeOlderThan(System.currentTimeMillis() - mTimeRangeMs);
         }
 
         /**
@@ -236,19 +493,23 @@
          *
          * @param olderThanTimestamp timestamp to remove record which are older than.
          */
-        void removeOlderThan(final long olderThanTimestamp) {
-            while (mEntries.peek() != null && mEntries.peek().mTimestamp <= olderThanTimestamp) {
-                final CollectionEntry entry = mEntries.poll();
-                if (entry == null) {
-                    return;
-                }
-                mTotalCost -= entry.mCost;
-                if (mTotalCost < 0) {
-                    mTotalCost = 0;
-                }
-                final int index = mPerUidCost.indexOfKey(entry.mUid);
-                if (index >= 0) {
-                    mPerUidCost.setValueAt(index, mPerUidCost.valueAt(index) - entry.mCost);
+        public void removeOlderThan(final long olderThanTimestamp) {
+            synchronized (mLock) {
+                while (mEntries.peek() != null
+                        && mEntries.peek().mTimestamp <= olderThanTimestamp) {
+                    final CollectionEntry entry = mEntries.poll();
+                    if (entry == null) {
+                        return;
+                    }
+                    mTotalCost -= entry.mCost;
+                    if (mTotalCost < 0) {
+                        mTotalCost = 0;
+                    }
+                    final int index = mPerUidCost.indexOfKey(entry.mUid);
+                    if (index >= 0) {
+                        mPerUidCost.setValueAt(index, Math.max(0,
+                                mPerUidCost.valueAt(index) - entry.mCost));
+                    }
                 }
             }
         }
@@ -265,26 +526,77 @@
          */
         @RateLimitResult int isProfilingAllowed(final int uid, final int cost,
                 final long currentTimeMillis) {
-            removeOlderThan(currentTimeMillis - mTimeRangeMs);
-            if (mTotalCost + cost > mMaxCost) {
-                return RATE_LIMIT_RESULT_BLOCKED_SYSTEM;
+            synchronized (mLock) {
+                removeOlderThan(currentTimeMillis - mTimeRangeMs);
+                if (mTotalCost + cost > mMaxCost) {
+                    return RATE_LIMIT_RESULT_BLOCKED_SYSTEM;
+                }
+                final int index = mPerUidCost.indexOfKey(uid);
+                return ((index < 0 ? 0 : mPerUidCost.valueAt(index)) + cost < mMaxCostPerUid)
+                        ? RATE_LIMIT_RESULT_ALLOWED : RATE_LIMIT_RESULT_BLOCKED_PROCESS;
             }
-            final int index = mPerUidCost.indexOfKey(uid);
-            return ((index < 0 ? 0 : mPerUidCost.valueAt(index)) + cost < mMaxCostPerUid)
-                    ? RATE_LIMIT_RESULT_ALLOWED : RATE_LIMIT_RESULT_BLOCKED_PROCESS;
+        }
+
+        RateLimiterRecordsWrapper.EntryGroupWrapper toProto() {
+            synchronized (mLock) {
+                RateLimiterRecordsWrapper.EntryGroupWrapper.Builder builder =
+                        RateLimiterRecordsWrapper.EntryGroupWrapper.newBuilder();
+
+                CollectionEntry[] entries = mEntries.toArray(new CollectionEntry[mEntries.size()]);
+                for (int i = 0; i < entries.length; i++) {
+                    builder.addEntries(entries[i].toProto());
+                }
+
+                return builder.build();
+            }
+        }
+
+        void populateFromProto(RateLimiterRecordsWrapper.EntryGroupWrapper group) {
+            synchronized (mLock) {
+                final long currentTimeMillis = System.currentTimeMillis();
+                for (int i = 0; i < group.getEntriesCount(); i++) {
+                    RateLimiterRecordsWrapper.EntryGroupWrapper.Entry entry = group.getEntries(i);
+                    if (entry.getTimestamp() > currentTimeMillis - mTimeRangeMs) {
+                        add(entry.getUid(), entry.getCost(), entry.getTimestamp());
+                    }
+                }
+            }
+        }
+
+        /** Get a copied array of the backing data. */
+        public CollectionEntry[] getEntriesCopy() {
+            synchronized (mLock) {
+                CollectionEntry[] array = new CollectionEntry[mEntries.size()];
+                array = mEntries.toArray(array);
+                return array.clone();
+            }
         }
     }
 
-    static final class CollectionEntry {
-        final int mUid;
-        final int mCost;
-        final Long mTimestamp;
+    public static final class CollectionEntry {
+        public final int mUid;
+        public final int mCost;
+        public final Long mTimestamp;
 
         CollectionEntry(final int uid, final int cost, final Long timestamp) {
             mUid = uid;
             mCost = cost;
             mTimestamp = timestamp;
         }
+
+        RateLimiterRecordsWrapper.EntryGroupWrapper.Entry toProto() {
+            return RateLimiterRecordsWrapper.EntryGroupWrapper.Entry.newBuilder()
+                    .setUid(mUid)
+                    .setCost(mCost)
+                    .setTimestamp(mTimestamp)
+                    .build();
+        }
+    }
+
+    public interface HandlerCallback {
+
+        /** Obtain a handler to schedule persisting records to disk. */
+        Handler obtainHandler();
     }
 }
 
diff --git a/service/proto/android/os/ratelimiter.proto b/service/proto/android/os/ratelimiter.proto
new file mode 100644
index 0000000..9275d94
--- /dev/null
+++ b/service/proto/android/os/ratelimiter.proto
@@ -0,0 +1,18 @@
+syntax = "proto2";
+
+package android.os;
+
+option java_outer_classname = "RateLimiterProto";
+option java_multiple_files = true;
+
+message RateLimiterRecordsWrapper {
+  message EntryGroupWrapper {
+    message Entry {
+      required int32 uid = 1;
+      required int32 cost = 2;
+      required int64 timestamp = 3;
+    }
+    repeated Entry entries = 1;
+  }
+  required EntryGroupWrapper records = 1;
+}
diff --git a/tests/cts/src/android/profiling/cts/ProfilingServiceTests.java b/tests/cts/src/android/profiling/cts/ProfilingServiceTests.java
index 9914ba0..f134fce 100644
--- a/tests/cts/src/android/profiling/cts/ProfilingServiceTests.java
+++ b/tests/cts/src/android/profiling/cts/ProfilingServiceTests.java
@@ -27,13 +27,16 @@
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.spy;
 
+import android.app.Instrumentation;
 import android.content.Context;
 import android.content.pm.PackageManager;
 import android.os.Binder;
+import android.os.Handler;
 import android.os.IProfilingResultCallback;
 import android.os.ParcelFileDescriptor;
 import android.os.ProfilingManager;
 import android.os.ProfilingResult;
+import android.os.profiling.DeviceConfigHelper;
 import android.os.profiling.ProfilingService;
 import android.os.profiling.RateLimiter;
 import android.os.profiling.TracingSession;
@@ -41,8 +44,13 @@
 import android.platform.test.flag.junit.DeviceFlagsValueProvider;
 
 import androidx.test.core.app.ApplicationProvider;
+import androidx.test.platform.app.InstrumentationRegistry;
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.compatibility.common.util.SystemUtil;
+
+import com.google.errorprone.annotations.FormatMethod;
+
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -50,6 +58,7 @@
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import java.io.File;
 import java.util.UUID;
 
 /**
@@ -64,6 +73,8 @@
     private static final String APP_PACKAGE_NAME = "com.profiling.test";
     private static final String REQUEST_TAG = "some unique string";
 
+    private static final String OVERRIDE_DEVICE_CONFIG_INT = "device_config put %s %s %d";
+
     // Key most and least significant bits are used to generate a unique key specific to each
     // request. Key is used to pair request back to caller and callbacks so test to keep consistent.
     private static final long KEY_MOST_SIG_BITS = 456l;
@@ -76,15 +87,22 @@
     @Mock private Process mActiveTrace;
 
     private Context mContext = ApplicationProvider.getApplicationContext();
+    private Instrumentation mInstrumentation;
     private ProfilingService mProfilingService;
     private RateLimiter mRateLimiter;
 
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
+        mInstrumentation = InstrumentationRegistry.getInstrumentation();
         mContext = spy(ApplicationProvider.getApplicationContext());
         mProfilingService = spy(new ProfilingService(mContext));
-        mRateLimiter = spy(new RateLimiter(mContext));
+        mRateLimiter = spy(new RateLimiter(new RateLimiter.HandlerCallback() {
+            @Override
+            public Handler obtainHandler() {
+                return null;
+            }
+        }));
         doReturn(mPackageManager).when(mContext).getPackageManager();
         mProfilingService.mRateLimiter = mRateLimiter;
         doReturn(APP_PACKAGE_NAME).when(mPackageManager).getNameForUid(anyInt());
@@ -296,6 +314,117 @@
         assertFalse(callback.mResultSent);
     }
 
+    /** Test that rate limiter correctly persists and restores data. */
+    @Test
+    public void testRateLimiter_PersistAndRestore() throws Exception {
+        // Update DeviceConfig defaults to high enough limits, cost of 1, and persist frequency 0.
+        overrideRateLimiterDefaults(5, 10, 20, 50, 50, 100, 1, 1, 1, 1, 0);
+
+        // Override file path because test app context that can't access /data/system/
+        mRateLimiter.mPersistStoreDir = new File(mContext.getFilesDir(), "testdir");
+        mRateLimiter.mPersistStoreDir.mkdir();
+        mRateLimiter.mPersistFile = new File(mRateLimiter.mPersistStoreDir, "testfile");
+
+        // Remove all records
+        long currentTimeMillis = System.currentTimeMillis();
+        mRateLimiter.mPastRunsHour.removeOlderThan(currentTimeMillis);
+        mRateLimiter.mPastRunsDay.removeOlderThan(currentTimeMillis);
+        mRateLimiter.mPastRunsWeek.removeOlderThan(currentTimeMillis);
+
+        // Add some records. Since records are being added directly rather than through normal
+        // request flow, this will not trigger a persist regardless of persist frequency.
+        mRateLimiter.mPastRunsHour.add(1, 1, currentTimeMillis - 1000);
+        mRateLimiter.mPastRunsDay.add(1, 1, currentTimeMillis - 1000);
+        mRateLimiter.mPastRunsWeek.add(1, 1, currentTimeMillis - 1000);
+        mRateLimiter.mPastRunsDay.add(2, 1, currentTimeMillis - (60 * 60 * 1000) - 1000);
+        mRateLimiter.mPastRunsWeek.add(2, 1, currentTimeMillis - (60 * 60 * 1000) - 1000);
+        mRateLimiter.mPastRunsWeek.add(2, 1, currentTimeMillis - (24 * 60 * 60 * 1000) - 1000);
+
+        // Store a copy of the backing data for each type
+        RateLimiter.CollectionEntry[] hourEntriesOriginal =
+                mRateLimiter.mPastRunsHour.getEntriesCopy();
+        RateLimiter.CollectionEntry[] dayEntriesOriginal =
+                mRateLimiter.mPastRunsDay.getEntriesCopy();
+        RateLimiter.CollectionEntry[] weekEntriesOriginal =
+                mRateLimiter.mPastRunsWeek.getEntriesCopy();
+
+        // Confirm collections are correct size.
+        assertEquals(1, hourEntriesOriginal.length);
+        assertEquals(2, dayEntriesOriginal.length);
+        assertEquals(3, weekEntriesOriginal.length);
+
+        // Now persist the records to disk
+        mRateLimiter.persistToDisk();
+
+        // Remove all records again
+        currentTimeMillis = System.currentTimeMillis();
+        mRateLimiter.mPastRunsHour.removeOlderThan(currentTimeMillis);
+        mRateLimiter.mPastRunsDay.removeOlderThan(currentTimeMillis);
+        mRateLimiter.mPastRunsWeek.removeOlderThan(currentTimeMillis);
+
+        // Confirm records have been removed
+        assertEquals(0, mRateLimiter.mPastRunsHour.getEntriesCopy().length);
+        assertEquals(0, mRateLimiter.mPastRunsDay.getEntriesCopy().length);
+        assertEquals(0, mRateLimiter.mPastRunsWeek.getEntriesCopy().length);
+
+        // Now load the persisted records from disk
+        mRateLimiter.loadFromDisk();
+
+        // Finally, verify the records.
+        confirmRateLimiterEntriesEqual(hourEntriesOriginal,
+                mRateLimiter.mPastRunsHour.getEntriesCopy());
+        confirmRateLimiterEntriesEqual(dayEntriesOriginal,
+                mRateLimiter.mPastRunsDay.getEntriesCopy());
+        confirmRateLimiterEntriesEqual(weekEntriesOriginal,
+                mRateLimiter.mPastRunsWeek.getEntriesCopy());
+    }
+
+    // TODO: b/333579817 - Add more rate limiter tests
+
+    private void overrideRateLimiterDefaults(int systemHour, int processHour, int systemDay,
+            int processDay, int systemWeek, int processWeek, int costHeapDump, int costHeapProfile,
+            int costStackSampling, int costSystemTrace, int persistToDiskFrequency)
+            throws Exception {
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.MAX_COST_SYSTEM_1_HOUR, systemHour);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.MAX_COST_PROCESS_1_HOUR, processHour);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.MAX_COST_SYSTEM_24_HOUR, systemDay);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.MAX_COST_PROCESS_24_HOUR, processDay);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.MAX_COST_SYSTEM_7_DAY, systemWeek);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.MAX_COST_PROCESS_7_DAY, processWeek);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.COST_JAVA_HEAP_DUMP, costHeapDump);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.COST_HEAP_PROFILE, costHeapProfile);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.COST_STACK_SAMPLING, costStackSampling);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.COST_SYSTEM_TRACE, costSystemTrace);
+        executeShellCmd(OVERRIDE_DEVICE_CONFIG_INT, DeviceConfigHelper.NAMESPACE,
+                DeviceConfigHelper.PERSIST_TO_DISK_FREQUENCY_MS, persistToDiskFrequency);
+    }
+
+    @FormatMethod
+    private String executeShellCmd(String cmdFormat, Object... args) throws Exception {
+        String cmd = String.format(cmdFormat, args);
+        return SystemUtil.runShellCommand(mInstrumentation, cmd);
+    }
+
+    private void confirmRateLimiterEntriesEqual(RateLimiter.CollectionEntry[] collectionOne,
+            RateLimiter.CollectionEntry[] collectionTwo) {
+        assertEquals(collectionOne.length, collectionTwo.length);
+        for (int i = 0; i < collectionOne.length; i++) {
+            assertEquals(collectionOne[i].mUid, collectionTwo[i].mUid);
+            assertEquals(collectionOne[i].mCost, collectionTwo[i].mCost);
+            assertEquals(collectionOne[i].mTimestamp, collectionTwo[i].mTimestamp);
+        }
+    }
+
     /** Confirm that all fields returned by callback match expectation. */
     private void confirmResultCallback(ProfilingResultCallback callback, String resultFile,
             long keyMostSigBits, long keyLeastSigBits, int status, String tag,