Try reconnecting to wifi if it doesn't automatically connect.

When meteredness of wifi changes during test, wifi disconnects.
Try reconnecting if it automatically connects after this.
Also, change the way we update metereness so that we can
use callback mechanism to wait for the state change instead of
polling for it regularly.

Bug: 181686645
Test: atest ./tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
Ignore-AOSP-First: Submitting internally first to avoid merge conflicts
Change-Id: I31fb127ef333d39fe4697043876c7cef15d525e3
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
index d7981c9..9d1d418 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
@@ -22,12 +22,13 @@
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+import static android.net.wifi.WifiConfiguration.METERED_OVERRIDE_METERED;
+import static android.net.wifi.WifiConfiguration.METERED_OVERRIDE_NONE;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
 import static com.android.cts.net.hostside.AbstractRestrictBackgroundNetworkTestCase.TAG;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
@@ -35,20 +36,22 @@
 
 import android.app.ActivityManager;
 import android.app.Instrumentation;
+import android.app.UiAutomation;
 import android.content.Context;
 import android.location.LocationManager;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.Network;
 import android.net.NetworkCapabilities;
+import android.net.wifi.WifiConfiguration;
 import android.net.wifi.WifiManager;
+import android.net.wifi.WifiManager.ActionListener;
 import android.os.PersistableBundle;
 import android.os.Process;
 import android.os.UserHandle;
 import android.telephony.CarrierConfigManager;
 import android.telephony.SubscriptionManager;
 import android.telephony.data.ApnSetting;
-import android.text.TextUtils;
 import android.util.Log;
 
 import androidx.test.platform.app.InstrumentationRegistry;
@@ -58,7 +61,12 @@
 import com.android.compatibility.common.util.ShellIdentityUtils;
 import com.android.compatibility.common.util.ThrowingRunnable;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 
 public class NetworkPolicyTestUtils {
@@ -189,19 +197,14 @@
     }
 
     private static String getWifiSsid() {
-        final boolean isLocationEnabled = isLocationEnabled();
+        final UiAutomation uiAutomation = getInstrumentation().getUiAutomation();
         try {
-            if (!isLocationEnabled) {
-                setLocationEnabled(true);
-            }
-            final String ssid = unquoteSSID(getWifiManager().getConnectionInfo().getSSID());
+            uiAutomation.adoptShellPermissionIdentity();
+            final String ssid = getWifiManager().getConnectionInfo().getSSID();
             assertNotEquals(WifiManager.UNKNOWN_SSID, ssid);
             return ssid;
         } finally {
-            // Reset the location enabled state
-            if (!isLocationEnabled) {
-                setLocationEnabled(false);
-            }
+            uiAutomation.dropShellPermissionIdentity();
         }
     }
 
@@ -212,18 +215,71 @@
     }
 
     private static void setWifiMeteredStatus(String ssid, boolean metered) throws Exception {
-        assertFalse("SSID should not be empty", TextUtils.isEmpty(ssid));
-        final String cmd = "cmd netpolicy set metered-network " + ssid + " " + metered;
-        executeShellCommand(cmd);
-        assertWifiMeteredStatus(ssid, metered);
-        assertActiveNetworkMetered(metered);
+        final UiAutomation uiAutomation = getInstrumentation().getUiAutomation();
+        try {
+            uiAutomation.adoptShellPermissionIdentity();
+            final WifiConfiguration currentConfig = getWifiConfiguration(ssid);
+            currentConfig.meteredOverride = metered
+                    ? METERED_OVERRIDE_METERED : METERED_OVERRIDE_NONE;
+            BlockingQueue<Integer> blockingQueue = new LinkedBlockingQueue<>();
+            getWifiManager().save(currentConfig, createActionListener(
+                    blockingQueue, Integer.MAX_VALUE));
+            Integer resultCode = blockingQueue.poll(TIMEOUT_CHANGE_METEREDNESS_MS,
+                    TimeUnit.MILLISECONDS);
+            if (resultCode == null) {
+                fail("Timed out waiting for meteredness to change; ssid=" + ssid
+                        + ", metered=" + metered);
+            } else if (resultCode != Integer.MAX_VALUE) {
+                fail("Error overriding the meteredness; ssid=" + ssid
+                        + ", metered=" + metered + ", error=" + resultCode);
+            }
+            final boolean success = assertActiveNetworkMetered(metered, false /* throwOnFailure */);
+            if (!success) {
+                Log.i(TAG, "Retry connecting to wifi; ssid=" + ssid);
+                blockingQueue = new LinkedBlockingQueue<>();
+                getWifiManager().connect(currentConfig, createActionListener(
+                        blockingQueue, Integer.MAX_VALUE));
+                resultCode = blockingQueue.poll(TIMEOUT_CHANGE_METEREDNESS_MS,
+                        TimeUnit.MILLISECONDS);
+                if (resultCode == null) {
+                    fail("Timed out waiting for wifi to connect; ssid=" + ssid);
+                } else if (resultCode != Integer.MAX_VALUE) {
+                    fail("Error connecting to wifi; ssid=" + ssid
+                            + ", error=" + resultCode);
+                }
+                assertActiveNetworkMetered(metered, true /* throwOnFailure */);
+            }
+        } finally {
+            uiAutomation.dropShellPermissionIdentity();
+        }
     }
 
-    private static void assertWifiMeteredStatus(String ssid, boolean expectedMeteredStatus) {
-        final String result = executeShellCommand("cmd netpolicy list wifi-networks");
-        final String expectedLine = ssid + ";" + expectedMeteredStatus;
-        assertTrue("Expected line: " + expectedLine + "; Actual result: " + result,
-                result.contains(expectedLine));
+    private static WifiConfiguration getWifiConfiguration(String ssid) {
+        final List<String> ssids = new ArrayList<>();
+        for (WifiConfiguration config : getWifiManager().getConfiguredNetworks()) {
+            if (config.SSID.equals(ssid)) {
+                return config;
+            }
+            ssids.add(config.SSID);
+        }
+        fail("Couldn't find the wifi config; ssid=" + ssid
+                + ", all=" + Arrays.toString(ssids.toArray()));
+        return null;
+    }
+
+    private static ActionListener createActionListener(BlockingQueue<Integer> blockingQueue,
+            int successCode) {
+        return new ActionListener() {
+            @Override
+            public void onSuccess() {
+                blockingQueue.offer(successCode);
+            }
+
+            @Override
+            public void onFailure(int reason) {
+                blockingQueue.offer(reason);
+            }
+        };
     }
 
     private static void setCellularMeteredStatus(int subId, boolean metered) throws Exception {
@@ -232,11 +288,11 @@
                 new String[] {ApnSetting.TYPE_MMS_STRING});
         ShellIdentityUtils.invokeMethodWithShellPermissionsNoReturn(getCarrierConfigManager(),
                 (cm) -> cm.overrideConfig(subId, metered ? null : bundle));
-        assertActiveNetworkMetered(metered);
+        assertActiveNetworkMetered(metered, true /* throwOnFailure */);
     }
 
-    // Copied from cts/tests/tests/net/src/android/net/cts/ConnectivityManagerTest.java
-    private static void assertActiveNetworkMetered(boolean expectedMeteredStatus) throws Exception {
+    private static boolean assertActiveNetworkMetered(boolean expectedMeteredStatus,
+            boolean throwOnFailure) throws Exception {
         final CountDownLatch latch = new CountDownLatch(1);
         final NetworkCallback networkCallback = new NetworkCallback() {
             @Override
@@ -253,10 +309,16 @@
         getConnectivityManager().registerDefaultNetworkCallback(networkCallback);
         try {
             if (!latch.await(TIMEOUT_CHANGE_METEREDNESS_MS, TimeUnit.MILLISECONDS)) {
-                fail("Timed out waiting for active network metered status to change to "
-                        + expectedMeteredStatus + "; network = "
-                        + getConnectivityManager().getActiveNetwork());
+                final String errorMsg = "Timed out waiting for active network metered status "
+                        + "to change to " + expectedMeteredStatus + "; network = "
+                        + getConnectivityManager().getActiveNetwork();
+                if (throwOnFailure) {
+                    fail(errorMsg);
+                }
+                Log.w(TAG, errorMsg);
+                return false;
             }
+            return true;
         } finally {
             getConnectivityManager().unregisterNetworkCallback(networkCallback);
         }