Merge "Reduce expected network usage of RemoteProvisioner"
diff --git a/src/com/android/remoteprovisioner/BootReceiver.java b/src/com/android/remoteprovisioner/BootReceiver.java
index aabbff3..bf58b9c 100644
--- a/src/com/android/remoteprovisioner/BootReceiver.java
+++ b/src/com/android/remoteprovisioner/BootReceiver.java
@@ -16,14 +16,23 @@
 
 package com.android.remoteprovisioner;
 
+import static java.lang.Math.max;
+
 import android.app.job.JobInfo;
 import android.app.job.JobScheduler;
 import android.content.BroadcastReceiver;
 import android.content.ComponentName;
 import android.content.Context;
 import android.content.Intent;
+import android.os.RemoteException;
+import android.os.ServiceManager;
+import android.security.remoteprovisioning.AttestationPoolStatus;
+import android.security.remoteprovisioning.ImplInfo;
+import android.security.remoteprovisioning.IRemoteProvisioning;
 import android.util.Log;
 
+import java.time.Duration;
+
 /**
  * A receiver class that listens for boot to be completed and then starts a recurring job that will
  * monitor the status of the attestation key pool on device, purging old certificates and requesting
@@ -31,15 +40,34 @@
  */
 public class BootReceiver extends BroadcastReceiver {
     private static final String TAG = "RemoteProvisioningBootReceiver";
+    private static final String SERVICE = "android.security.remoteprovisioning";
+
+    private static final Duration SCHEDULER_PERIOD = Duration.ofDays(1);
+    // Estimated bytes for one provisioning run: a fixed cost plus a per-key cost.
+    private static final int ESTIMATED_DOWNLOAD_BYTES_STATIC = 2300;
+    private static final int ESTIMATED_X509_CERT_BYTES = 540;
+    private static final int ESTIMATED_UPLOAD_BYTES_STATIC = 600;
+    private static final int ESTIMATED_CSR_KEY_BYTES = 44;
+
     @Override
     public void onReceive(Context context, Intent intent) {
-        Log.d(TAG, "Caught boot intent, waking up.");
+        Log.i(TAG, "Caught boot intent, waking up.");
         SettingsManager.generateAndSetId(context);
+        // Size the estimate for the wake-up that actually provisions keys (about once a month,
+        // ~8-10KB) rather than a typical check (~500 bytes). The key count is the larger of the
+        // configured number of extra keys and the number of keys currently in use.
+        int numKeysNeeded = max(SettingsManager.getExtraSignedKeysAvailable(context),
+                                calcNumPotentialKeysToDownload());
+        int estimatedDlBytes =
+                ESTIMATED_DOWNLOAD_BYTES_STATIC + (ESTIMATED_X509_CERT_BYTES * numKeysNeeded);
+        int estimatedUploadBytes =
+                ESTIMATED_UPLOAD_BYTES_STATIC + (ESTIMATED_CSR_KEY_BYTES * numKeysNeeded);
+
         JobInfo info = new JobInfo
                 .Builder(1, new ComponentName(context, PeriodicProvisioner.class))
                 .setRequiredNetworkType(JobInfo.NETWORK_TYPE_ANY)
-                .setEstimatedNetworkBytes(1000, 1000)
-                .setPeriodic(1000 * 60 * 60 * 24)
+                .setEstimatedNetworkBytes(estimatedDlBytes, estimatedUploadBytes)
+                .setPeriodic(SCHEDULER_PERIOD.toMillis())
                 .build();
         if (((JobScheduler) context.getSystemService(Context.JOB_SCHEDULER_SERVICE)).schedule(info)
                 != JobScheduler.RESULT_SUCCESS) {
@@ -47,4 +75,31 @@
         }
     }
 
+    private int calcNumPotentialKeysToDownload() {
+        try {
+            IRemoteProvisioning binder =
+                IRemoteProvisioning.Stub.asInterface(ServiceManager.getService(SERVICE));
+            int totalKeysAssigned = 0; // Sum of (attested - unassigned) across secLevels.
+            if (binder == null) {
+                Log.e(TAG, "Binder returned null pointer to RemoteProvisioning service.");
+                return totalKeysAssigned;
+            }
+            ImplInfo[] implInfos = binder.getImplementationInfo();
+            if (implInfos == null) {
+                Log.e(TAG, "No instances of IRemotelyProvisionedComponent registered in "
+                           + SERVICE);
+                return totalKeysAssigned;
+            }
+            for (int i = 0; i < implInfos.length; i++) {
+                AttestationPoolStatus pool = binder.getPoolStatus(0, implInfos[i].secLevel);
+                if (pool != null) {
+                    totalKeysAssigned += pool.attested - pool.unassigned;
+                }
+            }
+            return totalKeysAssigned;
+        } catch (RemoteException e) {
+            Log.e(TAG, "Failure on the RemoteProvisioning backend.", e);
+            return 0;
+        }
+    }
 }
diff --git a/src/com/android/remoteprovisioner/PeriodicProvisioner.java b/src/com/android/remoteprovisioner/PeriodicProvisioner.java
index 52ab4bf..fd5e69c 100644
--- a/src/com/android/remoteprovisioner/PeriodicProvisioner.java
+++ b/src/com/android/remoteprovisioner/PeriodicProvisioner.java
@@ -16,9 +16,12 @@
 
 package com.android.remoteprovisioner;
 
+import static java.lang.Math.min;
+
 import android.app.job.JobParameters;
 import android.app.job.JobService;
 import android.content.Context;
+import android.net.ConnectivityManager;
 import android.os.RemoteException;
 import android.os.ServiceManager;
 import android.security.remoteprovisioning.AttestationPoolStatus;
@@ -36,20 +39,24 @@
 public class PeriodicProvisioner extends JobService {
 
     private static final int FAILURE_MAXIMUM = 5;
+    private static final int SAFE_CSR_BATCH_SIZE = 20; // Max keys per provisionCerts() call.
 
     // How long to wait in between key pair generations to avoid flooding keystore with requests.
     private static final Duration KEY_GENERATION_PAUSE = Duration.ofMillis(1000);
 
+    // Shortened expiration look-ahead used on metered networks, to avoid provisioning there.
+    private static final long METERED_CONNECTION_EXPIRATION_CHECK = Duration.ofDays(1).toMillis();
+
     private static final String SERVICE = "android.security.remoteprovisioning";
     private static final String TAG = "RemoteProvisioningService";
     private ProvisionerThread mProvisionerThread;
 
     /**
-     * Starts the periodic provisioning job, which will occasionally check the attestation key pool
+     * Starts the periodic provisioning job, which will check the attestation key pool
      * and provision it as necessary.
      */
     public boolean onStartJob(JobParameters params) {
-        Log.d(TAG, "Starting provisioning job");
+        Log.i(TAG, "Starting provisioning job");
         mProvisionerThread = new ProvisionerThread(params, this);
         mProvisionerThread.start();
         return true;
@@ -59,7 +66,6 @@
      * Allows the job to be stopped if need be.
      */
     public boolean onStopJob(JobParameters params) {
-        mProvisionerThread.stop();
         return false;
     }
 
@@ -74,81 +80,136 @@
 
         public void run() {
             try {
-                if (SettingsManager.getExtraSignedKeysAvailable(mContext) == 0) {
-                    // Provisioning is disabled. Check with the server if it's time to turn it back
-                    // on. If not, quit.
-                    GeekResponse check = ServerInterface.fetchGeek(mContext);
-                    if (check.numExtraAttestationKeys == 0) {
-                        jobFinished(mParams, false /* wantsReschedule */);
-                        return;
-                    }
-                }
                 IRemoteProvisioning binder =
                         IRemoteProvisioning.Stub.asInterface(ServiceManager.getService(SERVICE));
                 if (binder == null) {
                     Log.e(TAG, "Binder returned null pointer to RemoteProvisioning service.");
-                    jobFinished(mParams, true /* wantsReschedule */);
+                    jobFinished(mParams, false /* wantsReschedule */);
                     return;
                 }
+
+                ConnectivityManager cm = (ConnectivityManager) mContext.getSystemService(
+                        Context.CONNECTIVITY_SERVICE);
+                boolean isMetered = cm.isActiveNetworkMetered();
+                long expiringBy;
+                if (isMetered) {
+                    // Check a shortened duration to attempt to avoid metered connection
+                    // provisioning.
+                    expiringBy = System.currentTimeMillis() + METERED_CONNECTION_EXPIRATION_CHECK;
+                } else {
+                    expiringBy = SettingsManager.getExpiringBy(mContext)
+                            .plusMillis(System.currentTimeMillis())
+                            .toMillis();
+                }
                 ImplInfo[] implInfos = binder.getImplementationInfo();
                 if (implInfos == null) {
                     Log.e(TAG, "No instances of IRemotelyProvisionedComponent registered in "
                                + SERVICE);
-                    jobFinished(mParams, true /* wantsReschedule */);
+                    jobFinished(mParams, false /* wantsReschedule */);
                     return;
                 }
                 int[] keysNeededForSecLevel = new int[implInfos.length];
-                boolean provisioningNeeded = false;
-                for (int i = 0; i < implInfos.length; i++) {
-                    keysNeededForSecLevel[i] =
-                            generateNumKeysNeeded(binder,
-                                       SettingsManager.getExpiringBy(mContext)
-                                                      .plusMillis(System.currentTimeMillis())
-                                                      .toMillis(),
-                                       implInfos[i].secLevel);
-                    if (keysNeededForSecLevel[i] > 0) {
-                        provisioningNeeded = true;
-                    }
-                }
-                if (provisioningNeeded) {
-                    GeekResponse resp = ServerInterface.fetchGeek(mContext);
-                    if (resp == null) {
-                        if (SettingsManager.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
-                            SettingsManager.clearPreferences(mContext);
+                boolean provisioningNeeded =
+                        isProvisioningNeeded(binder, expiringBy, implInfos, keysNeededForSecLevel);
+                GeekResponse resp = null;
+                if (!provisioningNeeded) {
+                    if (!isMetered) {
+                        // So long as the connection is unmetered, go ahead and grab an updated
+                        // device configuration file.
+                        resp = ServerInterface.fetchGeek(mContext);
+                        if (!checkGeekResp(resp)) {
+                            jobFinished(mParams, false /* wantsReschedule */);
+                            return;
                         }
-                        jobFinished(mParams, true /* wantsReschedule */);
-                        return;
+                        SettingsManager.setDeviceConfig(mContext,
+                                resp.numExtraAttestationKeys,
+                                resp.timeToRefresh,
+                                resp.provisioningUrl);
+                        if (resp.numExtraAttestationKeys == 0) {
+                            binder.deleteAllKeys();
+                        }
                     }
-                    // Updates to configuration will take effect on the next check.
-                    SettingsManager.setDeviceConfig(mContext,
-                                                    resp.numExtraAttestationKeys,
-                                                    resp.timeToRefresh,
-                                                    resp.provisioningUrl);
-                    if (resp.numExtraAttestationKeys == 0) {
-                        // If the server has sent this, deactivate RKP.
-                        binder.deleteAllKeys();
-                        jobFinished(mParams, false /* wantsReschedule */);
-                        return;
-                    }
-                    for (int i = 0; i < implInfos.length; i++) {
-                        Provisioner.provisionCerts(keysNeededForSecLevel[i],
+                    jobFinished(mParams, false /* wantsReschedule */);
+                    return;
+                }
+                resp = ServerInterface.fetchGeek(mContext);
+                if (!checkGeekResp(resp)) {
+                    jobFinished(mParams, false /* wantsReschedule */);
+                    return;
+                }
+                SettingsManager.setDeviceConfig(mContext,
+                        resp.numExtraAttestationKeys,
+                        resp.timeToRefresh,
+                        resp.provisioningUrl);
+
+                if (resp.numExtraAttestationKeys == 0) {
+                    // The server has turned RKP off for this device: delete all remotely
+                    // provisioned keys so the device keeps using its fallback factory-provisioned
+                    // key, and finish without provisioning.
+                    binder.deleteAllKeys();
+                    jobFinished(mParams, false /* wantsReschedule */);
+                    return;
+                }
+                for (int i = 0; i < implInfos.length; i++) {
+                    // Break very large CSR requests into chunks, so as not to overwhelm the
+                    // backend.
+                    int keysToCertify = keysNeededForSecLevel[i];
+                    while (keysToCertify > 0) {
+                        int batchSize = min(keysToCertify, SAFE_CSR_BATCH_SIZE);
+                        Provisioner.provisionCerts(batchSize,
                                                    implInfos[i].secLevel,
                                                    resp.getGeekChain(implInfos[i].supportedCurve),
                                                    resp.getChallenge(),
                                                    binder,
                                                    mContext);
+                        keysToCertify -= batchSize;
                     }
                 }
                 jobFinished(mParams, false /* wantsReschedule */);
             } catch (RemoteException e) {
-                jobFinished(mParams, true /* wantsReschedule */);
+                jobFinished(mParams, false /* wantsReschedule */);
                 Log.e(TAG, "Error on the binder side during provisioning.", e);
             } catch (InterruptedException e) {
-                jobFinished(mParams, true /* wantsReschedule */);
+                jobFinished(mParams, false /* wantsReschedule */);
                 Log.e(TAG, "Provisioner thread interrupted.", e);
             }
         }
 
+        private boolean checkGeekResp(GeekResponse resp) {
+            if (resp == null) {
+                Log.e(TAG, "Failed to get a response from the server.");
+                if (SettingsManager.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
+                    Log.e(TAG, "Too many failures, resetting defaults.");
+                    SettingsManager.clearPreferences(mContext);
+                }
+                // The caller is responsible for calling jobFinished() after a failure.
+                return false;
+            }
+            return true;
+        }
+
+        private boolean isProvisioningNeeded(
+                IRemoteProvisioning binder, long expiringBy, ImplInfo[] implInfos,
+                int[] keysNeededForSecLevel)
+                throws InterruptedException, RemoteException {
+            if (implInfos == null || keysNeededForSecLevel == null
+                    || keysNeededForSecLevel.length != implInfos.length) {
+                Log.e(TAG, "Invalid argument.");
+                return false;
+            }
+            boolean provisioningNeeded = false;
+            for (int i = 0; i < implInfos.length; i++) {
+                keysNeededForSecLevel[i] =
+                        generateNumKeysNeeded(binder,
+                                   expiringBy,
+                                   implInfos[i].secLevel);
+                if (keysNeededForSecLevel[i] > 0) {
+                    provisioningNeeded = true;
+                }
+            }
+            return provisioningNeeded;
+        }
+
         /**
          * This method will generate and bundle up keys for signing to make sure that there will be
          * enough keys available for use by the system when current keys expire.
@@ -164,12 +225,19 @@
                 throws InterruptedException, RemoteException {
             AttestationPoolStatus pool = binder.getPoolStatus(expiringBy, secLevel);
             int unattestedKeys = pool.total - pool.attested;
-            int validKeys = pool.attested - pool.expiring;
             int keysInUse = pool.attested - pool.unassigned;
             int totalSignedKeys = keysInUse + SettingsManager.getExtraSignedKeysAvailable(mContext);
             int generated;
+            // If fewer keys are expiring than are currently unassigned, and at least the
+            // expected number of keys is already attested, then do nothing. Otherwise, generate
+            // the complete amount of totalSignedKeys. Provisioning an entire new batch in one go
+            // uses less network than repeatedly grabbing a few keys at a time as the expiration
+            // dates become misaligned.
+            if (pool.expiring < pool.unassigned && pool.attested >= totalSignedKeys) {
+                return 0;
+            }
             for (generated = 0;
-                    generated + unattestedKeys + validKeys < totalSignedKeys; generated++) {
+                    generated + unattestedKeys < totalSignedKeys; generated++) {
                 binder.generateKeyPair(false /* isTestMode */, secLevel);
                 // Prioritize provisioning if there are no keys available. No keys being available
                 // indicates that this is the first time a device is being brought online.
@@ -177,7 +245,7 @@
                     Thread.sleep(KEY_GENERATION_PAUSE.toMillis());
                 }
             }
-            if (totalSignedKeys - validKeys > 0) {
+            if (totalSignedKeys > 0) {
                 return generated + unattestedKeys;
             }
             return 0;
diff --git a/src/com/android/remoteprovisioner/Provisioner.java b/src/com/android/remoteprovisioner/Provisioner.java
index 91341fc..06a7f4d 100644
--- a/src/com/android/remoteprovisioner/Provisioner.java
+++ b/src/com/android/remoteprovisioner/Provisioner.java
@@ -81,6 +81,10 @@
                                                   challenge,
                                                   protectedData.protectedData,
                                                   macedKeysToSign);
+        if (certificateRequest == null) { // Don't contact the server without a request.
+            Log.e(TAG, "Failed to serialize the payload generated by keystore.");
+            return 0;
+        }
         List<byte[]> certChains = ServerInterface.requestSignedCertificates(context,
                         certificateRequest, challenge);
         if (certChains == null) {
@@ -101,6 +105,10 @@
             // getTime returns the time in *milliseconds* since the epoch.
             long expirationDate = cert.getNotAfter().getTime();
             byte[] rawPublicKey = X509Utils.getAndFormatRawPublicKey(cert);
+            if (rawPublicKey == null) {
+                Log.e(TAG, "Skipping malformed public key.");
+                continue; // Skip only this chain; the loop moves on to the next one.
+            }
             try {
                 if (SystemInterface.provisionCertChain(rawPublicKey, cert.getEncoded(), certChain,
                                                        expirationDate, secLevel, binder)) {
diff --git a/src/com/android/remoteprovisioner/SystemInterface.java b/src/com/android/remoteprovisioner/SystemInterface.java
index 79f4cea..67ab028 100644
--- a/src/com/android/remoteprovisioner/SystemInterface.java
+++ b/src/com/android/remoteprovisioner/SystemInterface.java
@@ -80,6 +80,10 @@
                                                         secLevel,
                                                         protectedData,
                                                         deviceInfo);
+            if (macedPublicKeys == null) { // Don't hand null to the CBOR decoder below.
+                Log.e(TAG, "Keystore didn't generate a CSR successfully.");
+                return null;
+            }
             ByteArrayInputStream bais = new ByteArrayInputStream(macedPublicKeys);
             List<DataItem> dataItems = new CborDecoder(bais).decode();
             List<DataItem> macInfo = ((Array) dataItems.get(0)).getDataItems();
diff --git a/src/com/android/remoteprovisioner/X509Utils.java b/src/com/android/remoteprovisioner/X509Utils.java
index 60fee4a..d33d573 100644
--- a/src/com/android/remoteprovisioner/X509Utils.java
+++ b/src/com/android/remoteprovisioner/X509Utils.java
@@ -16,8 +16,11 @@
 
 package com.android.remoteprovisioner;
 
+import android.util.Log;
+
 import java.io.ByteArrayInputStream;
 import java.math.BigInteger;
+import java.security.PublicKey;
 import java.security.cert.Certificate;
 import java.security.cert.CertificateException;
 import java.security.cert.CertificateFactory;
@@ -30,6 +33,8 @@
  */
 public class X509Utils {
 
+    private static final String TAG = "RemoteProvisionerX509Utils";
+
     /**
      * Takes a byte array composed of DER encoded certificates and returns the X.509 certificates
      * contained within as an X509Certificate array.
@@ -47,6 +52,12 @@
      * the certificate chain to the proper key when passed into the keystore database.
      */
     public static byte[] getAndFormatRawPublicKey(X509Certificate cert) {
+        PublicKey pubKey = cert.getPublicKey();
+        // Only EC keys are supported; a null return tells callers to skip this certificate.
+        if (!(pubKey instanceof ECPublicKey)) {
+            Log.e(TAG, "Certificate public key is not an instance of ECPublicKey");
+            return null;
+        }
-        ECPublicKey key = (ECPublicKey) cert.getPublicKey();
+        ECPublicKey key = (ECPublicKey) pubKey;
         // Remote key provisioning internally supports the default, uncompressed public key
         // format for ECDSA. This defines the format as (s | x | y), where s is the byte
diff --git a/src/com/android/remoteprovisioner/service/GenerateRkpKeyService.java b/src/com/android/remoteprovisioner/service/GenerateRkpKeyService.java
index b3292b5..73c83b1 100644
--- a/src/com/android/remoteprovisioner/service/GenerateRkpKeyService.java
+++ b/src/com/android/remoteprovisioner/service/GenerateRkpKeyService.java
@@ -86,10 +86,11 @@
                     break;
                 }
             }
-            // If there are no unassigned keys, go ahead and provision some. If there are no keys
-            // at all on system, this implies that it is a hybrid rkp/factory-provisioned system
-            // that has turned off RKP. In that case, do not provision.
-            if (pool.unassigned == 0 && pool.total != 0) {
+            // Provision more keys only when every attested key is already assigned. A pool with
+            // no attested keys at all implies a hybrid RKP/factory-provisioned system that has
+            // turned off RKP, so do not provision in that case. pool.attested is checked rather
+            // than pool.total because unattested keys may remain in the pool while RKP is off.
+            if (pool.unassigned == 0 && pool.attested != 0) {
                 Log.i(TAG, "All signed keys are currently in use, provisioning more.");
                 Context context = getApplicationContext();
                 int keysToProvision = SettingsManager.getExtraSignedKeysAvailable(context);
diff --git a/tests/unittests/src/com/android/remoteprovisioner/unittest/ServerToSystemTest.java b/tests/unittests/src/com/android/remoteprovisioner/unittest/ServerToSystemTest.java
index e8ee610..94206af 100644
--- a/tests/unittests/src/com/android/remoteprovisioner/unittest/ServerToSystemTest.java
+++ b/tests/unittests/src/com/android/remoteprovisioner/unittest/ServerToSystemTest.java
@@ -163,7 +163,9 @@
                                         "Not even a URL" /* url */);
         // Even if there is an unsigned key hanging around, fallback should still occur.
         Certificate[] fallbackKeyCerts2 = generateKeyStoreKey("test3");
-        assertEquals(1, SettingsManager.getFailureCounter(sContext));
+        // With no attested keys in the pool, the service should not even attempt to provision
+        // more certificates, so no server failure is recorded.
+        assertEquals(0, SettingsManager.getFailureCounter(sContext));
         assertTrue(fallbackKeyCerts1.length == fallbackKeyCerts2.length);
         for (int i = 1; i < fallbackKeyCerts1.length; i++) {
             assertArrayEquals("Cert: " + i, fallbackKeyCerts1[i].getEncoded(),