wifi: propogate certification event to the framework

Send CA certificate to the wifi framework for a network
which is configured to use Trust On First Use.

Bug: 196180536
Test: atest FrameworksWifiTests
Change-Id: I9024a39e744641d94af2cfdb559953e999ab217e
diff --git a/service/java/com/android/server/wifi/SupplicantStaNetworkCallbackAidlImpl.java b/service/java/com/android/server/wifi/SupplicantStaNetworkCallbackAidlImpl.java
index 0fb7514..61453b6 100644
--- a/service/java/com/android/server/wifi/SupplicantStaNetworkCallbackAidlImpl.java
+++ b/service/java/com/android/server/wifi/SupplicantStaNetworkCallbackAidlImpl.java
@@ -25,6 +25,17 @@
 
 import com.android.server.wifi.util.NativeUtil;
 
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.charset.CharacterCodingException;
+import java.nio.charset.CharsetDecoder;
+import java.nio.charset.StandardCharsets;
+import java.security.cert.CertificateException;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+
 class SupplicantStaNetworkCallbackAidlImpl extends ISupplicantStaNetworkCallback.Stub {
     private final SupplicantStaNetworkHalAidlImpl mNetworkHal;
     /**
@@ -111,4 +122,95 @@
                     mIfaceName, mFrameworkNetworkId, frameworkBits);
         }
     }
+
+    private String byteArrayToString(byte[] byteArray) {
+        // Not a valid bytes for a string
+        if (byteArray == null) return null;
+        CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder();
+        try {
+            CharBuffer decoded = decoder.decode(ByteBuffer.wrap(byteArray));
+            return decoded.toString();
+        } catch (CharacterCodingException cce) {
+        }
+        return null;
+    }
+
+    @Override
+    public void onServerCertificateAvailable(
+            int depth,
+            byte[] subjectBytes,
+            byte[] certHashBytes,
+            byte[] certBytes) {
+        synchronized (mLock) {
+            // OpenSSL default maximum depth is 100.
+            if (depth < 0 || depth > 100) {
+                mNetworkHal.logCallback("onServerCertificateAvailable: invalid depth " + depth);
+                return;
+            }
+            if (null == subjectBytes) {
+                mNetworkHal.logCallback("onServerCertificateAvailable: subject is null.");
+                return;
+            }
+            if (null == certHashBytes) {
+                mNetworkHal.logCallback("onServerCertificateAvailable: cert hash is null.");
+                return;
+            }
+            if (null == certBytes) {
+                mNetworkHal.logCallback("onServerCertificateAvailable: cert is null.");
+                return;
+            }
+
+            mNetworkHal.logCallback("onServerCertificateAvailable: "
+                    + " depth=" + depth
+                    + " subjectBytes size=" + subjectBytes.length
+                    + " certHashBytes size=" + certHashBytes.length
+                    + " certBytes size=" + certBytes.length);
+
+            if (0 == certHashBytes.length) return;
+            if (0 == certBytes.length) return;
+
+            String subject = byteArrayToString(subjectBytes);
+            if (null == subject) {
+                mNetworkHal.logCallback(
+                        "onServerCertificateAvailable: cannot convert subject bytes to string.");
+                return;
+            }
+            String certHash = byteArrayToString(certHashBytes);
+            if (null == subject) {
+                mNetworkHal.logCallback(
+                        "onServerCertificateAvailable: cannot convert cert hash bytes to string.");
+                return;
+            }
+            X509Certificate cert = null;
+            try {
+                CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
+                InputStream in = new ByteArrayInputStream(certBytes);
+                cert = (X509Certificate) certFactory.generateCertificate(in);
+            } catch (CertificateException e) {
+                cert = null;
+                mNetworkHal.logCallback(
+                        "onServerCertificateAvailable: "
+                        + "Failed to get instance for CertificateFactory: " + e);
+            } catch (IllegalArgumentException e) {
+                cert = null;
+                mNetworkHal.logCallback(
+                        "onServerCertificateAvailable: Failed to decode the data: " + e);
+            }
+            if (null == cert) {
+                mNetworkHal.logCallback(
+                        "onServerCertificateAvailable: Failed to read certificate.");
+                return;
+            }
+            // Not a CA certificate, ignore it.
+            if (cert.getBasicConstraints() < 0) return;
+
+            mNetworkHal.logCallback("onServerCertificateAvailable:"
+                    + " depth=" + depth
+                    + " subject=" + subject
+                    + " certHash=" + certHash
+                    + " cert=" + cert);
+            mWifiMonitor.broadcastCertificationEvent(
+                    mIfaceName, mFrameworkNetworkId, mSsid, cert);
+        }
+    }
 }
diff --git a/service/java/com/android/server/wifi/WifiMonitor.java b/service/java/com/android/server/wifi/WifiMonitor.java
index 5fe6c46..eaf3546 100644
--- a/service/java/com/android/server/wifi/WifiMonitor.java
+++ b/service/java/com/android/server/wifi/WifiMonitor.java
@@ -37,6 +37,7 @@
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
+import java.security.cert.X509Certificate;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
@@ -104,6 +105,8 @@
     /* Transition Disable Indication */
     public static final int TRANSITION_DISABLE_INDICATION        = BASE + 72;
 
+    /* Trust On First Use Root CA Certification */
+    public static final int TOFU_ROOT_CA_CERTIFICATE             = BASE + 73;
 
     /* WPS config errrors */
     private static final int CONFIG_MULTIPLE_PBC_DETECTED = 12;
@@ -596,4 +599,17 @@
     public void broadcastNetworkNotFoundEvent(String iface, String ssid) {
         sendMessage(iface, NETWORK_NOT_FOUND_EVENT, ssid);
     }
+
+    /**
+     * Broadcast the certification event which takes place during TOFU process.
+     *
+     * @param iface Name of iface on which this occurred.
+     * @param networkId ID of the network in wpa_supplicant.
+     * @param ssid SSID of the network.
+     * @param cert the certificate data.
+     */
+    public void broadcastCertificationEvent(String iface, int networkId, String ssid,
+            X509Certificate cert) {
+        sendMessage(iface, TOFU_ROOT_CA_CERTIFICATE, networkId, 0, cert);
+    }
 }
diff --git a/service/tests/wifitests/src/com/android/server/wifi/SupplicantStaNetworkCallbackAidlImplTest.java b/service/tests/wifitests/src/com/android/server/wifi/SupplicantStaNetworkCallbackAidlImplTest.java
new file mode 100644
index 0000000..11b3eff
--- /dev/null
+++ b/service/tests/wifitests/src/com/android/server/wifi/SupplicantStaNetworkCallbackAidlImplTest.java
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.server.wifi;
+
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.validateMockitoUsage;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.withSettings;
+
+import androidx.test.filters.SmallTest;
+
+import com.android.dx.mockito.inline.extended.ExtendedMockito;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.MockitoSession;
+import org.mockito.quality.Strictness;
+
+import java.security.cert.CertificateException;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+
+/**
+ * Unit tests for SupplicantStaNetworkHalAidlImpl
+ */
+@SmallTest
+public class SupplicantStaNetworkCallbackAidlImplTest extends WifiBaseTest {
+    private static final int TEST_NETWORK_ID = 9;
+    private static final String TEST_SSID = "TestSsid";
+    private static final String TEST_INTERFACE = "wlan1";
+
+    @Mock private SupplicantStaNetworkHalAidlImpl mSupplicantStaNetworkHalAidlImpl;
+    @Mock private Object mLock;
+    @Mock private WifiMonitor mWifiMonitor;
+    @Mock private CertificateFactory mCertificateFactory;
+    @Mock private X509Certificate mX509Certificate;
+
+    private MockitoSession mSession;
+    private SupplicantStaNetworkCallbackAidlImpl mSupplicantStaNetworkCallbackAidlImpl;
+
+    @Before
+    public void setUp() throws Exception {
+        MockitoAnnotations.initMocks(this);
+        // static mocking
+        mSession = ExtendedMockito.mockitoSession()
+                .mockStatic(CertificateFactory.class, withSettings().lenient())
+                .strictness(Strictness.LENIENT)
+                .startMocking();
+        when(CertificateFactory.getInstance(any())).thenReturn(mCertificateFactory);
+        when(mCertificateFactory.generateCertificate(any())).thenReturn(mX509Certificate);
+        when(mX509Certificate.getBasicConstraints()).thenReturn(0);
+
+        mSupplicantStaNetworkCallbackAidlImpl =  new SupplicantStaNetworkCallbackAidlImpl(
+                mSupplicantStaNetworkHalAidlImpl, TEST_NETWORK_ID, TEST_SSID, TEST_INTERFACE,
+                mLock, mWifiMonitor);
+    }
+
+    /**
+     * Called after each test
+     */
+    @After
+    public void cleanup() {
+        validateMockitoUsage();
+        if (mSession != null) {
+            mSession.finishMocking();
+        }
+    }
+
+    /** verify onServerCertificateAvailable sunny case. */
+    @Test
+    public void testOnCertificateSuccess() throws Exception {
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                0, "subject".getBytes(), "certHash".getBytes(), "cert".getBytes());
+        verify(mWifiMonitor).broadcastCertificationEvent(
+                eq(TEST_INTERFACE), eq(TEST_NETWORK_ID),
+                eq(TEST_SSID), eq(mX509Certificate));
+    }
+
+    /** verify onServerCertificateAvailable with illegal arguments. */
+    @Test
+    public void testOnCertificateIllegalInput() throws Exception {
+        // Illegal argument: negative depth.
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                -1, "subject".getBytes(), "certHash".getBytes(), "cert".getBytes());
+        verify(mWifiMonitor, never()).broadcastCertificationEvent(
+                any(), anyInt(), any(), any());
+
+        // Illegal argument: depth over 100.
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                101, "subject".getBytes(), "certHash".getBytes(), "cert".getBytes());
+        verify(mWifiMonitor, never()).broadcastCertificationEvent(
+                any(), anyInt(), any(), any());
+
+        // Illegal argument: null subject
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                0, null, "certHash".getBytes(), "cert".getBytes());
+        verify(mWifiMonitor, never()).broadcastCertificationEvent(
+                any(), anyInt(), any(), any());
+
+        // Illegal argument: null cert hash
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                0, "subject".getBytes(), null, "cert".getBytes());
+        verify(mWifiMonitor, never()).broadcastCertificationEvent(
+                any(), anyInt(), any(), any());
+
+        // Illegal argument: null cert.
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                0, "subject".getBytes(), "certHash".getBytes(), null);
+        verify(mWifiMonitor, never()).broadcastCertificationEvent(
+                any(), anyInt(), any(), any());
+    }
+
+    /** verify onServerCertificateAvailable with CertificateException. */
+    @Test
+    public void testOnCertificateWithCertificateException() throws Exception {
+        doThrow(new CertificateException())
+                .when(mCertificateFactory).generateCertificate(any());
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                0, "subject".getBytes(), "certHash".getBytes(), "cert".getBytes());
+        verify(mWifiMonitor, never()).broadcastCertificationEvent(
+                any(), anyInt(), any(), any());
+    }
+
+    /** verify onServerCertificateAvailable with IllegalArgumentException. */
+    @Test
+    public void testOnCertificateWithIllegalArgumentException() throws Exception {
+        doThrow(new IllegalArgumentException())
+                .when(mCertificateFactory).generateCertificate(any());
+        mSupplicantStaNetworkCallbackAidlImpl.onServerCertificateAvailable(
+                0, "subject".getBytes(), "certHash".getBytes(), "cert".getBytes());
+        verify(mWifiMonitor, never()).broadcastCertificationEvent(
+                any(), anyInt(), any(), any());
+    }
+}
diff --git a/service/tests/wifitests/src/com/android/server/wifi/WifiMonitorTest.java b/service/tests/wifitests/src/com/android/server/wifi/WifiMonitorTest.java
index 5902af6..f8b430b 100644
--- a/service/tests/wifitests/src/com/android/server/wifi/WifiMonitorTest.java
+++ b/service/tests/wifitests/src/com/android/server/wifi/WifiMonitorTest.java
@@ -47,6 +47,8 @@
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 
+import java.security.cert.X509Certificate;
+
 /**
  * Unit tests for {@link com.android.server.wifi.WifiMonitor}.
  */
@@ -726,4 +728,23 @@
         String ssid = (String) messageCaptor.getValue().obj;
         assertEquals(SSID, ssid);
     }
+
+    /**
+     * Broadcast Certification event.
+     */
+    @Test
+    public void testBroadcastCertificateEvent() {
+        mWifiMonitor.registerHandler(
+                WLAN_IFACE_NAME, WifiMonitor.TOFU_ROOT_CA_CERTIFICATE, mHandlerSpy);
+        mWifiMonitor.broadcastCertificationEvent(
+                WLAN_IFACE_NAME, NETWORK_ID, SSID, FakeKeys.CA_CERT0);
+        mLooper.dispatchAll();
+
+        ArgumentCaptor<Message> messageCaptor = ArgumentCaptor.forClass(Message.class);
+        verify(mHandlerSpy).handleMessage(messageCaptor.capture());
+        assertEquals(WifiMonitor.TOFU_ROOT_CA_CERTIFICATE, messageCaptor.getValue().what);
+        assertEquals(NETWORK_ID, messageCaptor.getValue().arg1);
+        X509Certificate cert = (X509Certificate) messageCaptor.getValue().obj;
+        assertEquals(FakeKeys.CA_CERT0, cert);
+    }
 }