Merge "Fixed testPrivateDnsPolicy on headless system user mode." into sc-dev
diff --git a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DeviceOwnerHelper.java b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DeviceOwnerHelper.java
index 0e7ce36..ed99c17d 100644
--- a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DeviceOwnerHelper.java
+++ b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DeviceOwnerHelper.java
@@ -24,6 +24,7 @@
 import static com.android.bedstead.dpmwrapper.Utils.EXTRA_METHOD;
 import static com.android.bedstead.dpmwrapper.Utils.EXTRA_NUMBER_ARGS;
 import static com.android.bedstead.dpmwrapper.Utils.VERBOSE;
+import static com.android.bedstead.dpmwrapper.Utils.callOnHandlerThread;
 import static com.android.bedstead.dpmwrapper.Utils.isHeadlessSystemUser;
 
 import android.annotation.Nullable;
@@ -76,7 +77,7 @@
             Log.d(TAG, "runManagerMethod(): userId=" + context.getUserId()
                     + ", intent=" + intent.getAction() + ", class=" + className
                     + ", methodName=" + methodName + ", numberArgs=" + numberArgs);
-            Object[] args = null;
+            final Object[] args;
             Class<?>[] parameterTypes = null;
             if (numberArgs > 0) {
                 args = new Object[numberArgs];
@@ -88,6 +89,8 @@
                 Log.d(TAG, "runManagerMethod(): args=" + Arrays.toString(args) + ", types="
                         + Arrays.toString(parameterTypes));
 
+            } else {
+                args = null;
             }
             Class<?> managerClass = Class.forName(className);
             Method method = findMethod(managerClass, methodName, parameterTypes);
@@ -99,8 +102,8 @@
             Object manager = managerClass.equals(DevicePolicyManager.class)
                     ? receiver.getManager(context)
                     : context.getSystemService(managerClass);
-
-            Object result = method.invoke(manager, args);
+            // Must run on a separate thread as some APIs fail when called from the main thread
+            Object result = callOnHandlerThread(() -> method.invoke(manager, args));
 
             if (VERBOSE) {
                 // Some results - like network logging events - are quite large
diff --git a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DevicePolicyManagerWrapper.java b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DevicePolicyManagerWrapper.java
index 484a9e4..377439e 100644
--- a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DevicePolicyManagerWrapper.java
+++ b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/DevicePolicyManagerWrapper.java
@@ -172,6 +172,11 @@
             doAnswer(answer).when(spy).retrievePreRebootSecurityLogs(any());
             doAnswer(answer).when(spy).getLastNetworkLogRetrievalTime();
 
+            // Used by PrivateDnsPolicyTest
+            doAnswer(answer).when(spy).getGlobalPrivateDnsHost(any());
+            doAnswer(answer).when(spy).getGlobalPrivateDnsMode(any());
+            doAnswer(answer).when(spy).setGlobalPrivateDnsModeSpecifiedHost(any(), any());
+
             // TODO(b/176993670): add more methods below as tests are converted
         } catch (Exception e) {
             // Should never happen, but needs to be catch as some methods declare checked exceptions
diff --git a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/TestAppSystemServiceFactory.java b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/TestAppSystemServiceFactory.java
index f5c618b..7bb9084 100644
--- a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/TestAppSystemServiceFactory.java
+++ b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/TestAppSystemServiceFactory.java
@@ -24,6 +24,7 @@
 import static com.android.bedstead.dpmwrapper.Utils.EXTRA_METHOD;
 import static com.android.bedstead.dpmwrapper.Utils.EXTRA_NUMBER_ARGS;
 import static com.android.bedstead.dpmwrapper.Utils.VERBOSE;
+import static com.android.bedstead.dpmwrapper.Utils.getHandler;
 
 import android.annotation.Nullable;
 import android.app.admin.DeviceAdminReceiver;
@@ -37,8 +38,6 @@
 import android.content.pm.PackageManager.NameNotFoundException;
 import android.net.wifi.WifiManager;
 import android.os.Bundle;
-import android.os.Handler;
-import android.os.HandlerThread;
 import android.os.UserHandle;
 import android.os.UserManager;
 import android.util.Log;
@@ -72,10 +71,6 @@
     // 6 minutes for network monitoring events.
     private static final long TIMEOUT_MS = TimeUnit.MINUTES.toMillis(10);
 
-    private static final HandlerThread HANDLER_THREAD = new HandlerThread(TAG + "HandlerThread");
-
-    private static Handler sHandler;
-
     // Caches whether the package declares the required receiver (otherwise each test would be
     // querying package manager, which is expensive)
     private static final HashMap<String, Boolean> sHasRequiredReceiver = new HashMap<>();
@@ -164,12 +159,6 @@
             return manager;
         }
 
-        if (sHandler == null) {
-            Log.i(TAG, "Starting handler thread " + HANDLER_THREAD);
-            HANDLER_THREAD.start();
-            sHandler = new Handler(HANDLER_THREAD.getLooper());
-        }
-
         String receiverClassName = receiverClass.getName();
         final String wrappedClassName = wrappedClass.getName();
         if (VERBOSE) {
@@ -222,7 +211,7 @@
                         + "grantDpmWrapper() (for user " + userId + ") in the host-side test?");
             }
             context.sendOrderedBroadcastAsUser(intent,
-                    UserHandle.SYSTEM, /* permission= */ null, myReceiver, sHandler,
+                    UserHandle.SYSTEM, /* permission= */ null, myReceiver, getHandler(),
                     RESULT_NOT_SENT_TO_ANY_RECEIVER, /* initialData= */ null,
                     /* initialExtras= */ null);
 
diff --git a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/Utils.java b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/Utils.java
index 5fc2518..03b8963 100644
--- a/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/Utils.java
+++ b/common/device-side/bedstead/dpmwrapper/src/main/java/com/android/bedstead/dpmwrapper/Utils.java
@@ -20,16 +20,28 @@
 import android.content.Intent;
 import android.content.IntentFilter;
 import android.os.Bundle;
+import android.os.Handler;
+import android.os.HandlerThread;
 import android.os.UserHandle;
 import android.os.UserManager;
+import android.util.Log;
+
+import com.android.internal.annotations.GuardedBy;
 
 import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * Generic helpers.
  */
 public final class Utils {
 
+    private static final String TAG = "DpmWrapperUtils";
+
     static final boolean VERBOSE = false;
 
     static final int MY_USER_ID = UserHandle.myUserId();
@@ -41,6 +53,14 @@
     static final String EXTRA_NUMBER_ARGS = "number_args";
     static final String EXTRA_ARG_PREFIX = "arg_";
 
+    private static final Object LOCK = new Object();
+
+    @GuardedBy("LOCK")
+    private static HandlerThread sHandlerThread;
+
+    @GuardedBy("LOCK")
+    private static Handler sHandler;
+
     static boolean isHeadlessSystemUser() {
         return UserManager.isHeadlessSystemUserMode() && MY_USER_ID == UserHandle.USER_SYSTEM;
     }
@@ -65,6 +85,78 @@
         return builder.append(']').toString();
     }
 
+    /**
+     * Gets the handler whose thread is used to run calls off the main thread, lazily
+     * starting that {@link HandlerThread} on first use.
+     */
+    static Handler getHandler() {
+        synchronized (LOCK) {
+            if (sHandler == null) {
+                sHandlerThread = new HandlerThread("DpmWrapperHandlerThread");
+                Log.i(TAG, "Starting handler thread " + sHandlerThread);
+                sHandlerThread.start();
+                sHandler = new Handler(sHandlerThread.getLooper());
+            }
+            // Read while still holding LOCK, as sHandler is @GuardedBy("LOCK")
+            return sHandler;
+        }
+    }
+
+    /**
+     * Calls {@code callable} on the {@link #getHandler() handler} thread and blocks (for up to
+     * 50 seconds) until it finishes, returning its result or rethrowing whatever it threw.
+     *
+     * <p>If already running on the handler thread, {@code callable} is called directly, as
+     * posting to that same thread and waiting for it would block until the timeout.
+     *
+     * @throws TimeoutException if {@code callable} does not finish within 50 seconds
+     * @throws Exception any exception thrown by {@code callable}
+     */
+    static <T> T callOnHandlerThread(Callable<T> callable) throws Exception {
+        if (VERBOSE) Log.v(TAG, "callOnHandlerThread(): called from " + Thread.currentThread());
+
+        final Handler handler = getHandler();
+        if (handler.getLooper().isCurrentThread()) {
+            // The posted message could only run after the current one returns, so waiting
+            // for it here would never succeed.
+            return callable.call();
+        }
+
+        final CountDownLatch latch = new CountDownLatch(1);
+        final AtomicReference<T> returnRef = new AtomicReference<>();
+        final AtomicReference<Throwable> throwableRef = new AtomicReference<>();
+
+        final boolean posted = handler.post(() -> {
+            Log.d(TAG, "Calling callable on handler thread " + Thread.currentThread());
+            try {
+                T result = callable.call();
+                if (VERBOSE) Log.v(TAG, "Got result: "  + result);
+                returnRef.set(result);
+            } catch (Throwable t) {
+                // Catch Errors too: otherwise the caller would be released by the finally
+                // block and silently return null, while the Error kills the handler thread.
+                Log.e(TAG, "Got exception: "  + t);
+                throwableRef.set(t);
+            } finally {
+                latch.countDown();
+            }
+        });
+        if (!posted) {
+            throw new IllegalStateException("Could not post callable to " + handler);
+        }
+
+        if (!latch.await(50, TimeUnit.SECONDS)) {
+            throw new TimeoutException("didn't get result in 50 seconds");
+        }
+
+        final Throwable throwable = throwableRef.get();
+        if (throwable instanceof Exception) throw (Exception) throwable;
+        if (throwable instanceof Error) throw (Error) throwable;
+        if (throwable != null) throw new RuntimeException(throwable);
+
+        return returnRef.get();
+    }
+
     /**
      * Gets a more detailed description of an intent (for example, including extras).
      */
diff --git a/hostsidetests/devicepolicy/app/DeviceOwner/src/com/android/cts/deviceowner/PrivateDnsPolicyTest.java b/hostsidetests/devicepolicy/app/DeviceOwner/src/com/android/cts/deviceowner/PrivateDnsPolicyTest.java
index c0b0eab..1847889 100644
--- a/hostsidetests/devicepolicy/app/DeviceOwner/src/com/android/cts/deviceowner/PrivateDnsPolicyTest.java
+++ b/hostsidetests/devicepolicy/app/DeviceOwner/src/com/android/cts/deviceowner/PrivateDnsPolicyTest.java
@@ -16,7 +16,6 @@
 
 package com.android.cts.deviceowner;
 
-import static android.app.admin.DevicePolicyManager.PRIVATE_DNS_MODE_OFF;
 import static android.app.admin.DevicePolicyManager.PRIVATE_DNS_MODE_OPPORTUNISTIC;
 import static android.app.admin.DevicePolicyManager.PRIVATE_DNS_MODE_PROVIDER_HOSTNAME;
 
@@ -60,11 +59,10 @@
     }
 
     private void setUserRestriction(String restriction, boolean add) {
-        DevicePolicyManager dpm = mContext.getSystemService(DevicePolicyManager.class);
         if (add) {
-            dpm.addUserRestriction(getWho(), restriction);
+            mDevicePolicyManager.addUserRestriction(getWho(), restriction);
         } else {
-            dpm.clearUserRestriction(getWho(), restriction);
+            mDevicePolicyManager.clearUserRestriction(getWho(), restriction);
         }
     }