blob: 2119c0c1757519a16cf11d4cc509a46128960bae [file] [log] [blame]
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.rkpdapp.provisioner;
import android.content.Context;
import android.os.RemoteException;
import android.util.Log;
import com.android.rkpdapp.GeekResponse;
import com.android.rkpdapp.RkpdException;
import com.android.rkpdapp.database.InstantConverter;
import com.android.rkpdapp.database.ProvisionedKey;
import com.android.rkpdapp.database.ProvisionedKeyDao;
import com.android.rkpdapp.database.RkpKey;
import com.android.rkpdapp.interfaces.ServerInterface;
import com.android.rkpdapp.interfaces.SystemInterface;
import com.android.rkpdapp.metrics.ProvisioningAttempt;
import com.android.rkpdapp.utils.Settings;
import com.android.rkpdapp.utils.StatsProcessor;
import com.android.rkpdapp.utils.X509Utils;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import co.nstant.in.cbor.CborException;
/**
 * Provides an easy package to run the provisioning process from start to finish, interfacing
 * with the system interface and the server backend in order to provision attestation certificates
 * to the device.
 *
 * <p>{@link #provisionKeys} holds a static lock for its whole duration, so provisioning runs
 * started through any instance of this class never overlap.
 */
public class Provisioner {
    private static final String TAG = "RkpdProvisioner";
    // When the failure counter stored in Settings exceeds this value after a failed run,
    // provisionKeys() resets the configuration to its defaults.
    private static final int FAILURE_MAXIMUM = 5;
    // Static on purpose: shared by every Provisioner instance.
    private static final Object provisionKeysLock = new Object();

    private final Context mContext;
    private final ProvisionedKeyDao mKeyDao;
    private final boolean mIsAsync;

    /**
     * @param applicationContext context used for {@link Settings} access and server requests
     * @param keyDao database accessor used to read, insert and delete provisioned keys
     * @param isAsync forwarded to {@link ServerInterface} when certificates are requested
     */
    public Provisioner(final Context applicationContext, ProvisionedKeyDao keyDao,
            boolean isAsync) {
        mContext = applicationContext;
        mKeyDao = keyDao;
        mIsAsync = isAsync;
    }

    /**
     * Check to see if we need to perform provisioning or not for the given
     * IRemotelyProvisionedComponent.
     *
     * @param metrics updated to record whether the pool currently has no unassigned keys
     * @param serviceName the name of the remotely provisioned component to be provisioned
     * @return true if the remotely provisioned component requires more keys, false if the pool
     *         of available keys is healthy.
     */
    public boolean isProvisioningNeeded(ProvisioningAttempt metrics, String serviceName) {
        return calculateKeysRequired(metrics, serviceName) > 0;
    }

    /**
     * Generate, sign and store remotely provisioned keys.
     *
     * <p>The outcome is recorded on {@code metrics}: {@code NO_PROVISIONING_NEEDED} when the
     * pool does not require keys, {@code KEYS_SUCCESSFULLY_PROVISIONED} once the new keys are
     * stored, and {@code INTERRUPTED} when the thread was interrupted between steps.
     *
     * @param metrics collector for the status and statistics of this attempt
     * @param systemInterface the component that generates keys and certificate requests
     * @param geekResponse server configuration used when building and sending the request
     * @throws InterruptedException if the calling thread was interrupted between steps
     * @throws RkpdException if any provisioning step fails; when the stored failure counter
     *         exceeds {@code FAILURE_MAXIMUM} the default configuration is restored first
     * @throws CborException propagated from key or certificate request generation
     */
    public void provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface,
            GeekResponse geekResponse) throws CborException, RkpdException, InterruptedException {
        synchronized (provisionKeysLock) {
            try {
                int keysRequired =
                        calculateKeysRequired(metrics, systemInterface.getServiceName());
                Log.i(TAG, "Requested number of keys for provisioning: " + keysRequired);
                // Use the same "> 0 means work to do" rule as isProvisioningNeeded(); a
                // non-positive count must never reach key generation.
                if (keysRequired <= 0) {
                    metrics.setStatus(ProvisioningAttempt.Status.NO_PROVISIONING_NEEDED);
                    return;
                }
                List<RkpKey> keysGenerated = generateKeys(metrics, keysRequired, systemInterface);
                checkForInterrupts();
                List<byte[]> certChains = fetchCertificates(metrics, keysGenerated,
                        systemInterface, geekResponse);
                checkForInterrupts();
                List<ProvisionedKey> keys = associateCertsWithKeys(certChains, keysGenerated);
                mKeyDao.insertKeys(keys);
                Log.i(TAG, "Total provisioned keys: " + keys.size());
                metrics.setStatus(ProvisioningAttempt.Status.KEYS_SUCCESSFULLY_PROVISIONED);
            } catch (InterruptedException e) {
                metrics.setStatus(ProvisioningAttempt.Status.INTERRUPTED);
                throw e;
            } catch (RkpdException e) {
                if (Settings.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
                    Log.e(TAG, "Too many failures, resetting defaults.");
                    Settings.resetDefaultConfig(mContext);
                }
                // Rethrow to provide failure signal to caller
                throw e;
            }
        }
    }

    /**
     * Asks the system interface for {@code numKeysRequired} new keys.
     *
     * @return the generated keys, in generation order
     * @throws InterruptedException if the thread was interrupted before generation started
     */
    private List<RkpKey> generateKeys(ProvisioningAttempt metrics, int numKeysRequired,
            SystemInterface systemInterface)
            throws CborException, RkpdException, InterruptedException {
        List<RkpKey> keyArray = new ArrayList<>(numKeysRequired);
        checkForInterrupts();
        for (int i = 0; i < numKeysRequired; i++) {
            keyArray.add(systemInterface.generateKey(metrics));
        }
        return keyArray;
    }

    /**
     * Requests certificate chains for all generated keys, splitting the work into batches no
     * larger than the batch size reported by the system interface.
     *
     * @return the certificate chains returned by the server for every batch, concatenated
     * @throws RkpdException if the batch size cannot be obtained or is smaller than 1, or if a
     *         batch fails
     */
    private List<byte[]> fetchCertificates(ProvisioningAttempt metrics, List<RkpKey> keysGenerated,
            SystemInterface systemInterface, GeekResponse geekResponse)
            throws RkpdException, CborException, InterruptedException {
        int maxBatchSize;
        try {
            maxBatchSize = systemInterface.getBatchSize();
        } catch (RemoteException e) {
            throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
                    "Error getting batch size from the system", e);
        }
        // A non-positive batch size can never make progress: a negative value would make
        // subList() throw an unchecked IllegalArgumentException. Fail with a clear error.
        if (maxBatchSize < 1) {
            throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
                    "Invalid batch size reported by the system: " + maxBatchSize);
        }

        int totalKeys = keysGenerated.size();
        List<byte[]> certChains = new ArrayList<>(totalKeys);
        int provisionedSoFar = 0;
        while (provisionedSoFar < totalKeys) {
            int batchSize = Math.min(totalKeys - provisionedSoFar, maxBatchSize);
            certChains.addAll(batchProvision(metrics, systemInterface, geekResponse,
                    keysGenerated.subList(provisionedSoFar, provisionedSoFar + batchSize)));
            provisionedSoFar += batchSize;
        }
        return certChains;
    }

    /**
     * Builds a certificate request for one batch of keys and sends it to the server.
     *
     * @param keysGenerated the keys of this batch; must contain at least one key
     * @return the signed certificate chains returned by the server
     * @throws RkpdException if the batch is empty, the request could not be generated, or the
     *         server request fails
     */
    private List<byte[]> batchProvision(ProvisioningAttempt metrics,
            SystemInterface systemInterface,
            GeekResponse response, List<RkpKey> keysGenerated)
            throws RkpdException, CborException, InterruptedException {
        int batchSize = keysGenerated.size();
        if (batchSize < 1) {
            throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
                    "Request at least 1 key to be signed. Num requested: " + batchSize);
        }
        byte[] certRequest = systemInterface.generateCsr(metrics, response, keysGenerated);
        if (certRequest == null) {
            throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
                    "Failed to serialize payload");
        }
        return new ServerInterface(mContext, mIsAsync).requestSignedCertificates(certRequest,
                response.getChallenge(), metrics);
    }

    /**
     * Pairs every returned certificate chain with the generated key whose public key matches
     * the public key of the chain's leaf certificate.
     *
     * <p>Chains whose public key cannot be extracted are logged and skipped; chains that match
     * no generated key are ignored. The {@code keysGenerated} list is not modified.
     *
     * @return one {@link ProvisionedKey} per successfully matched chain
     * @throws RkpdException if a certificate chain cannot be processed
     */
    private List<ProvisionedKey> associateCertsWithKeys(List<byte[]> certChains,
            List<RkpKey> keysGenerated) throws RkpdException {
        // Matched keys are removed from a private copy so that each key is used at most once
        // without mutating the caller's list while it is being scanned.
        List<RkpKey> unmatchedKeys = new ArrayList<>(keysGenerated);
        List<ProvisionedKey> provisionedKeys = new ArrayList<>(certChains.size());
        for (byte[] chain : certChains) {
            X509Certificate[] certChain = X509Utils.formatX509Certs(chain);
            X509Certificate leafCertificate = certChain[0];
            long expirationDate = X509Utils.getExpirationTimeForCertificateChain(certChain)
                    .toInstant().toEpochMilli();
            byte[] rawPublicKey = X509Utils.getAndFormatRawPublicKey(leafCertificate);
            if (rawPublicKey == null) {
                Log.e(TAG, "Skipping malformed public key.");
                continue;
            }
            for (int i = 0; i < unmatchedKeys.size(); i++) {
                RkpKey key = unmatchedKeys.get(i);
                if (Arrays.equals(key.getPublicKey(), rawPublicKey)) {
                    provisionedKeys.add(key.generateProvisionedKey(chain,
                            InstantConverter.fromTimestamp(expirationDate)));
                    // Remove by index: exact element, independent of RkpKey.equals().
                    unmatchedKeys.remove(i);
                    break;
                }
            }
        }
        return provisionedKeys;
    }

    /**
     * Calculate the number of keys to be provisioned.
     *
     * <p>Also records on {@code metrics} whether the pool has no unassigned keys left.
     */
    private int calculateKeysRequired(ProvisioningAttempt metrics, String serviceName) {
        int numExtraAttestationKeys = Settings.getExtraSignedKeysAvailable(mContext);
        Instant expirationTime = Settings.getExpirationTime(mContext);
        StatsProcessor.PoolStats poolStats = StatsProcessor.processPool(mKeyDao, serviceName,
                numExtraAttestationKeys, expirationTime);
        metrics.setIsKeyPoolEmpty(poolStats.keysUnassigned == 0);
        return poolStats.keysToGenerate;
    }

    /**
     * Throws if the current thread has been interrupted. Note that {@link Thread#interrupted()}
     * clears the interrupt flag as a side effect.
     */
    private void checkForInterrupts() throws InterruptedException {
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
    }

    /**
     * Clears bad attestation keys on the basis of information provided in the FetchGeek response.
     *
     * <p>Nothing happens when the response carries no time range, or when the range equals the
     * one already stored in settings. Otherwise matching keys are deleted and the range is stored.
     */
    public void clearBadAttestationKeys(GeekResponse resp) {
        if (resp.lastBadCertTimeStart == null || resp.lastBadCertTimeEnd == null) {
            // if there is no time sent, no need to do anything.
            return;
        }
        if (resp.lastBadCertTimeStart.equals(Settings.getLastBadCertTimeStart(mContext))
                && resp.lastBadCertTimeEnd.equals(Settings.getLastBadCertTimeEnd(mContext))) {
            // if the time is same as already stored version, no need to do anything.
            return;
        }
        // clear the attestation keys on the basis of time.
        checkAndDeleteBadKeys(resp.lastBadCertTimeStart, resp.lastBadCertTimeEnd);
        // store the time.
        Settings.setLastBadCertTimeRange(mContext, resp.lastBadCertTimeStart,
                resp.lastBadCertTimeEnd);
    }

    /**
     * Deletes every stored key whose leaf certificate's not-before time lies within
     * {@code [startTime, endTime]} (both bounds inclusive).
     */
    private void checkAndDeleteBadKeys(Instant startTime, Instant endTime) {
        for (ProvisionedKey key : mKeyDao.getAllKeys()) {
            // Handle failures per key: the caller stores the time range afterwards and will
            // not repeat this sweep for it, so one unparsable chain must not prevent the
            // remaining keys from being examined.
            try {
                X509Certificate[] certChain = X509Utils.formatX509Certs(key.certificateChain);
                X509Certificate leafCertificate = certChain[0];
                Instant creationTime = leafCertificate.getNotBefore().toInstant()
                        .truncatedTo(ChronoUnit.MILLIS);
                if (!creationTime.isBefore(startTime) && !creationTime.isAfter(endTime)) {
                    mKeyDao.deleteKey(key.keyBlob);
                }
            } catch (RkpdException ex) {
                Log.e(TAG, "Could not convert certificate chain to X509 certificates.", ex);
            }
        }
    }
}