blob: 4acab0de66224bb9a7be1facee2a458f1b5173f1 [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.rkpdapp.unittest;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import android.content.Context;
import android.os.RemoteException;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.android.rkpdapp.GeekResponse;
import com.android.rkpdapp.RkpdException;
import com.android.rkpdapp.database.ProvisionedKey;
import com.android.rkpdapp.database.ProvisionedKeyDao;
import com.android.rkpdapp.database.RkpKey;
import com.android.rkpdapp.database.RkpdDatabase;
import com.android.rkpdapp.interfaces.SystemInterface;
import com.android.rkpdapp.metrics.ProvisioningAttempt;
import com.android.rkpdapp.provisioner.Provisioner;
import com.android.rkpdapp.testutil.FakeRkpServer;
import com.android.rkpdapp.utils.Settings;
import com.google.crypto.tink.subtle.Random;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import java.security.KeyPair;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import co.nstant.in.cbor.model.Array;
@RunWith(AndroidJUnit4.class)
public class ProvisionerTest {
private static final byte[] FAKE_RKP_KEY_BLOB_1 = Random.randBytes(10);
private static final byte[] FAKE_RKP_KEY_BLOB_2 = Random.randBytes(10);
private static final byte[] FAKE_RKP_KEY_BLOB_3 = Random.randBytes(10);
private static final Instant NOW = Instant.now().truncatedTo(ChronoUnit.SECONDS);
private static final RkpKey FAKE_RKP_KEY = new RkpKey(FAKE_RKP_KEY_BLOB_1, new byte[2],
new Array(), "hal", new byte[3]);
private static Context sContext;
private Provisioner mProvisioner;
private ProvisionedKeyDao mKeyDao;
@BeforeClass
public static void init() {
sContext = ApplicationProvider.getApplicationContext();
}
@Before
public void setUp() {
Settings.clearPreferences(sContext);
mKeyDao = RkpdDatabase.getDatabase(sContext).provisionedKeyDao();
mKeyDao.deleteAllKeys();
mProvisioner = new Provisioner(sContext, mKeyDao, false);
}
@After
public void tearDown() {
Settings.clearPreferences(sContext);
}
@Test
public void testProvisionerUsesCorrectBatchSize() throws Exception {
try (FakeRkpServer server = new FakeRkpServer(FakeRkpServer.Response.FETCH_EEK_OK,
FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
Settings.setDeviceConfig(sContext, 20, Duration.ofDays(1), server.getUrl());
final int batchSize = 13;
ProvisioningAttempt atom = ProvisioningAttempt.createScheduledAttemptMetrics(sContext);
SystemInterface mockSystem = mock(SystemInterface.class);
doReturn(batchSize).when(mockSystem).getBatchSize();
doReturn(FAKE_RKP_KEY).when(mockSystem).generateKey(eq(atom));
doReturn(new byte[1]).when(mockSystem).generateCsr(eq(atom), notNull(), notNull());
GeekResponse geekResponse = new GeekResponse();
geekResponse.setChallenge(new byte[1]);
mProvisioner.provisionKeys(atom, mockSystem, geekResponse);
verify(mockSystem).generateCsr(any(), any(),
argThat(keysGenerated -> keysGenerated.size() == 13));
verify(mockSystem).generateCsr(any(), any(),
argThat(keysGenerated -> keysGenerated.size() == 7));
}
}
@Test
public void testProvisionerHandlesExceptionOnGetBatchSize() throws Exception {
try (FakeRkpServer server = new FakeRkpServer(FakeRkpServer.Response.FETCH_EEK_OK,
FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
Settings.setDeviceConfig(sContext, 20, Duration.ofDays(1), server.getUrl());
ProvisioningAttempt atom = ProvisioningAttempt.createScheduledAttemptMetrics(sContext);
SystemInterface mockSystem = mock(SystemInterface.class);
doThrow(new RemoteException()).when(mockSystem).getBatchSize();
doReturn(FAKE_RKP_KEY).when(mockSystem).generateKey(eq(atom));
GeekResponse geekResponse = new GeekResponse();
geekResponse.setChallenge(new byte[1]);
assertThrows(RkpdException.class, () ->
mProvisioner.provisionKeys(atom, mockSystem, geekResponse));
}
}
private byte[] generateCertificateChain(Instant rootCreationTime, Instant leafCreationTime)
throws Exception {
KeyPair rootKey = Utils.generateEcdsaKeyPair();
KeyPair leafKey = Utils.generateEcdsaKeyPair();
// Just so that we don't get expired certificates by default.
Instant expirationTime = NOW.plus(Duration.ofDays(1));
byte[] rootCertEncoded = Utils.signPublicKey(rootKey, rootKey.getPublic(), rootCreationTime,
expirationTime).getEncoded();
byte[] leafCertEncoded = Utils.signPublicKey(rootKey, leafKey.getPublic(), leafCreationTime,
expirationTime).getEncoded();
byte[] encodedCertChain = new byte[leafCertEncoded.length + rootCertEncoded.length];
System.arraycopy(leafCertEncoded, 0, encodedCertChain, 0, leafCertEncoded.length);
System.arraycopy(rootCertEncoded, 0, encodedCertChain, leafCertEncoded.length,
rootCertEncoded.length);
return encodedCertChain;
}
private void setUpClearAttestationKeyTests(Instant failureStart, Instant failureEnd)
throws Exception {
Instant expiration = NOW.plus(Duration.ofDays(1));
Instant rootCreationTime = failureStart.minus(Duration.ofDays(10));
// add a fake key to the database with certificate time that is in the bad cert range.
ProvisionedKey keyBeforeFailure = new ProvisionedKey(
FAKE_RKP_KEY_BLOB_1,
"fakeHal1",
new byte[0],
generateCertificateChain(rootCreationTime, failureStart.minus(Duration.ofDays(1))),
expiration);
ProvisionedKey keyBadCert = new ProvisionedKey(
FAKE_RKP_KEY_BLOB_2,
"fakeHal2",
new byte[0],
generateCertificateChain(rootCreationTime, failureStart.plus(Duration.ofHours(1))),
expiration);
ProvisionedKey keyAfterFailure = new ProvisionedKey(
FAKE_RKP_KEY_BLOB_3,
"fakeHal3",
new byte[0],
generateCertificateChain(rootCreationTime, failureEnd.plus(Duration.ofDays(1))),
expiration);
mKeyDao.insertKeys(List.of(keyBeforeFailure, keyBadCert, keyAfterFailure));
}
@Test
public void testProvisionerClearsAttestationKeysOnResponse() throws Exception {
Instant failureTimeStart = NOW.minus(Duration.ofDays(5));
Instant failureTimeEnd = NOW.minus(Duration.ofDays(2));
setUpClearAttestationKeyTests(failureTimeStart, failureTimeEnd);
assertThat(mKeyDao.getAllKeys()).hasSize(3);
GeekResponse resp = new GeekResponse();
resp.lastBadCertTimeStart = failureTimeStart;
resp.lastBadCertTimeEnd = failureTimeEnd;
mProvisioner.clearBadAttestationKeys(resp);
assertThat(mKeyDao.getAllKeys()).hasSize(2);
}
@Test
public void testProvisionerClearsAttestationKeysOnlyOnce() throws Exception {
Instant failureTimeStart = NOW.minus(Duration.ofDays(5));
Instant failureTimeEnd = NOW.minus(Duration.ofDays(2));
setUpClearAttestationKeyTests(failureTimeStart, failureTimeEnd);
assertThat(mKeyDao.getAllKeys()).hasSize(3);
GeekResponse resp = new GeekResponse();
resp.lastBadCertTimeStart = failureTimeStart;
resp.lastBadCertTimeEnd = failureTimeEnd;
Settings.setLastBadCertTimeRange(sContext, failureTimeStart, failureTimeEnd);
mProvisioner.clearBadAttestationKeys(resp);
assertThat(mKeyDao.getAllKeys()).hasSize(3);
}
}