Merge "Update UpstreamNetworkMonitor to use custom Handlers"
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index 7c3e153..30afdc2 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -2923,15 +2923,6 @@
      * @hide
      */
     public void requestNetwork(NetworkRequest request, NetworkCallback networkCallback,
-            int timeoutMs, int legacyType) {
-        requestNetwork(request, networkCallback, timeoutMs, legacyType, getDefaultHandler());
-    }
-
-    /**
-     * Helper function to request a network with a particular legacy type.
-     * @hide
-     */
-    private void requestNetwork(NetworkRequest request, NetworkCallback networkCallback,
             int timeoutMs, int legacyType, Handler handler) {
         CallbackHandler cbHandler = new CallbackHandler(handler);
         NetworkCapabilities nc = request.networkCapabilities;
@@ -3033,7 +3024,7 @@
     public void requestNetwork(NetworkRequest request, NetworkCallback networkCallback,
             int timeoutMs) {
         int legacyType = inferLegacyTypeForNetworkCapabilities(request.networkCapabilities);
-        requestNetwork(request, networkCallback, timeoutMs, legacyType);
+        requestNetwork(request, networkCallback, timeoutMs, legacyType, getDefaultHandler());
     }
 
 
diff --git a/services/core/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitor.java b/services/core/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitor.java
index 63024db..6106093 100644
--- a/services/core/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitor.java
+++ b/services/core/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitor.java
@@ -20,6 +20,8 @@
 import static android.net.ConnectivityManager.TYPE_MOBILE_HIPRI;
 
 import android.content.Context;
+import android.os.Handler;
+import android.os.Looper;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.LinkProperties;
@@ -72,6 +74,7 @@
 
     private final Context mContext;
     private final StateMachine mTarget;
+    private final Handler mHandler;
     private final int mWhat;
     private final HashMap<Network, NetworkState> mNetworkMap = new HashMap<>();
     private ConnectivityManager mCM;
@@ -84,6 +87,7 @@
     public UpstreamNetworkMonitor(Context ctx, StateMachine tgt, int what) {
         mContext = ctx;
         mTarget = tgt;
+        mHandler = mTarget.getHandler();
         mWhat = what;
     }
 
@@ -99,10 +103,10 @@
         final NetworkRequest listenAllRequest = new NetworkRequest.Builder()
                 .clearCapabilities().build();
         mListenAllCallback = new UpstreamNetworkCallback(CALLBACK_LISTEN_ALL);
-        cm().registerNetworkCallback(listenAllRequest, mListenAllCallback);
+        cm().registerNetworkCallback(listenAllRequest, mListenAllCallback, mHandler);
 
         mDefaultNetworkCallback = new UpstreamNetworkCallback(CALLBACK_TRACK_DEFAULT);
-        cm().registerDefaultNetworkCallback(mDefaultNetworkCallback);
+        cm().registerDefaultNetworkCallback(mDefaultNetworkCallback, mHandler);
     }
 
     public void stop() {
@@ -154,7 +158,7 @@
         // Additionally, we log a message to aid in any subsequent debugging.
         Log.d(TAG, "requesting mobile upstream network: " + mobileUpstreamRequest);
 
-        cm().requestNetwork(mobileUpstreamRequest, mMobileNetworkCallback, 0, legacyType);
+        cm().requestNetwork(mobileUpstreamRequest, mMobileNetworkCallback, 0, legacyType, mHandler);
     }
 
     public void releaseMobileNetworkRequest() {
@@ -185,11 +189,13 @@
             case CALLBACK_TRACK_DEFAULT:
                 if (mDefaultNetworkCallback == null) {
                     // The callback was unregistered in the interval between
-                    // ConnectivityService calling onAvailable() and our
-                    // handling of it here on the mTarget.getHandler() thread.
+                    // ConnectivityService enqueueing onAvailable() and our
+                    // handling of it here on the mHandler thread.
+                    //
                     // Clean-up of this network entry is deferred to the
                     // handling of onLost() by other callbacks.
-                    // TODO: change to Log.wtf() after oag/331764 is merged.
+                    //
+                    // These request*() calls can be deleted post oag/339444.
                     return;
                 }
 
@@ -200,11 +206,13 @@
             case CALLBACK_MOBILE_REQUEST:
                 if (mMobileNetworkCallback == null) {
                     // The callback was unregistered in the interval between
-                    // ConnectivityService calling onAvailable() and our
-                    // handling of it here on the mTarget.getHandler() thread.
+                    // ConnectivityService enqueueing onAvailable() and our
+                    // handling of it here on the mHandler thread.
+                    //
                     // Clean-up of this network entry is deferred to the
                     // handling of onLost() by other callbacks.
-                    // TODO: change to Log.wtf() after oag/331764 is merged.
+                    //
+                    // These request*() calls can be deleted post oag/339444.
                     return;
                 }
 
@@ -312,8 +320,9 @@
     }
 
     /**
-     * A NetworkCallback class that relays information of interest to the
-     * tethering master state machine thread for subsequent processing.
+     * A NetworkCallback class that handles information of interest directly
+     * in the thread on which it is invoked. To avoid locking, this MUST be
+     * run on the same thread as the target state machine's handler.
      */
     private class UpstreamNetworkCallback extends NetworkCallback {
         private final int mCallbackType;
@@ -324,22 +333,35 @@
 
         @Override
         public void onAvailable(Network network) {
-            mTarget.getHandler().post(() -> handleAvailable(mCallbackType, network));
+            checkExpectedThread();
+            handleAvailable(mCallbackType, network);
         }
 
         @Override
         public void onCapabilitiesChanged(Network network, NetworkCapabilities newNc) {
-            mTarget.getHandler().post(() -> handleNetCap(network, newNc));
+            checkExpectedThread();
+            handleNetCap(network, newNc);
         }
 
         @Override
         public void onLinkPropertiesChanged(Network network, LinkProperties newLp) {
-            mTarget.getHandler().post(() -> handleLinkProp(network, newLp));
+            checkExpectedThread();
+            handleLinkProp(network, newLp);
         }
 
+        // TODO: Handle onNetworkSuspended();
+        // TODO: Handle onNetworkResumed();
+
         @Override
         public void onLost(Network network) {
-            mTarget.getHandler().post(() -> handleLost(mCallbackType, network));
+            checkExpectedThread();
+            handleLost(mCallbackType, network);
+        }
+
+        private void checkExpectedThread() {  // Reports callbacks run off mHandler's looper.
+            if (Looper.myLooper() != mHandler.getLooper()) {
+                Log.wtf(TAG, "Handling callback in unexpected thread.");
+            }
         }
     }
 
diff --git a/tests/net/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitorTest.java b/tests/net/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitorTest.java
index 3ed56df..c72efb0 100644
--- a/tests/net/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitorTest.java
+++ b/tests/net/java/com/android/server/connectivity/tethering/UpstreamNetworkMonitorTest.java
@@ -32,6 +32,8 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 import android.content.Context;
+import android.os.Handler;
+import android.os.Message;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.IConnectivityManager;
@@ -42,6 +44,10 @@
 import android.support.test.filters.SmallTest;
 import android.support.test.runner.AndroidJUnit4;
 
+import com.android.internal.util.State;
+import com.android.internal.util.StateMachine;
+
+import org.junit.After;
 import org.junit.Before;
 import org.junit.runner.RunWith;
 import org.junit.Test;
@@ -49,6 +55,7 @@
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
@@ -63,6 +70,7 @@
     @Mock private Context mContext;
     @Mock private IConnectivityManager mCS;
 
+    private TestStateMachine mSM;
     private TestConnectivityManager mCM;
     private UpstreamNetworkMonitor mUNM;
 
@@ -72,7 +80,15 @@
         reset(mCS);
 
         mCM = spy(new TestConnectivityManager(mContext, mCS));
-        mUNM = new UpstreamNetworkMonitor(null, EVENT_UNM_UPDATE, (ConnectivityManager) mCM);
+        mSM = new TestStateMachine();
+        mUNM = new UpstreamNetworkMonitor(mSM, EVENT_UNM_UPDATE, (ConnectivityManager) mCM);
+    }
+
+    @After public void tearDown() throws Exception {  // Quit mSM, if set, and clear the field.
+        if (mSM != null) {
+            mSM.quit();
+            mSM = null;
+        }
     }
 
     @Test
@@ -139,15 +155,17 @@
 
         mUNM.start();
         verify(mCM, Mockito.times(1)).registerNetworkCallback(
-                any(NetworkRequest.class), any(NetworkCallback.class));
-        verify(mCM, Mockito.times(1)).registerDefaultNetworkCallback(any(NetworkCallback.class));
+                any(NetworkRequest.class), any(NetworkCallback.class), any(Handler.class));
+        verify(mCM, Mockito.times(1)).registerDefaultNetworkCallback(
+                any(NetworkCallback.class), any(Handler.class));
         assertFalse(mUNM.mobileNetworkRequested());
         assertEquals(0, mCM.requested.size());
 
         mUNM.updateMobileRequiresDun(true);
         mUNM.registerMobileNetworkRequest();
         verify(mCM, Mockito.times(1)).requestNetwork(
-                any(NetworkRequest.class), any(NetworkCallback.class), anyInt(), anyInt());
+                any(NetworkRequest.class), any(NetworkCallback.class), anyInt(), anyInt(),
+                any(Handler.class));
 
         assertTrue(mUNM.mobileNetworkRequested());
         assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
@@ -224,6 +242,7 @@
     }
 
     public static class TestConnectivityManager extends ConnectivityManager {
+        public Map<NetworkCallback, Handler> allCallbacks = new HashMap<>();
         public Set<NetworkCallback> trackingDefault = new HashSet<>();
         public Map<NetworkCallback, NetworkRequest> listening = new HashMap<>();
         public Map<NetworkCallback, NetworkRequest> requested = new HashMap<>();
@@ -234,7 +253,8 @@
         }
 
         boolean hasNoCallbacks() {
-            return trackingDefault.isEmpty() &&
+            return allCallbacks.isEmpty() &&
+                   trackingDefault.isEmpty() &&
                    listening.isEmpty() &&
                    requested.isEmpty() &&
                    legacyTypeMap.isEmpty();
@@ -262,14 +282,23 @@
         }
 
         @Override
-        public void requestNetwork(NetworkRequest req, NetworkCallback cb) {
+        public void requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h) {
+            assertFalse(allCallbacks.containsKey(cb));
+            allCallbacks.put(cb, h);
             assertFalse(requested.containsKey(cb));
             requested.put(cb, req);
         }
 
         @Override
+        public void requestNetwork(NetworkRequest req, NetworkCallback cb) {
+            fail("Should never be called.");
+        }
+
+        @Override
         public void requestNetwork(NetworkRequest req, NetworkCallback cb,
-                int timeoutMs, int legacyType) {
+                int timeoutMs, int legacyType, Handler h) {
+            assertFalse(allCallbacks.containsKey(cb));
+            allCallbacks.put(cb, h);
             assertFalse(requested.containsKey(cb));
             requested.put(cb, req);
             assertFalse(legacyTypeMap.containsKey(cb));
@@ -279,18 +308,32 @@
         }
 
         @Override
-        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb) {
+        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h) {
+            assertFalse(allCallbacks.containsKey(cb));
+            allCallbacks.put(cb, h);
             assertFalse(listening.containsKey(cb));
             listening.put(cb, req);
         }
 
         @Override
-        public void registerDefaultNetworkCallback(NetworkCallback cb) {
+        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb) {
+            fail("Should never be called.");
+        }
+
+        @Override
+        public void registerDefaultNetworkCallback(NetworkCallback cb, Handler h) {
+            assertFalse(allCallbacks.containsKey(cb));
+            allCallbacks.put(cb, h);
             assertFalse(trackingDefault.contains(cb));
             trackingDefault.add(cb);
         }
 
         @Override
+        public void registerDefaultNetworkCallback(NetworkCallback cb) {
+            fail("Should never be called.");
+        }
+
+        @Override
         public void unregisterNetworkCallback(NetworkCallback cb) {
             if (trackingDefault.contains(cb)) {
                 trackingDefault.remove(cb);
@@ -302,10 +345,35 @@
             } else {
                 fail("Unexpected callback removed");
             }
+            allCallbacks.remove(cb);
 
+            assertFalse(allCallbacks.containsKey(cb));
             assertFalse(trackingDefault.contains(cb));
             assertFalse(listening.containsKey(cb));
             assertFalse(requested.containsKey(cb));
         }
     }
+
+    public static class TestStateMachine extends StateMachine {  // Single-state test target.
+        public final ArrayList<Message> messages = new ArrayList<>();  // Copies of handled msgs.
+        private final State mLoggingState = new LoggingState();
+        // The only state: clears the log on enter/exit and records each message it is given.
+        class LoggingState extends State {
+            @Override public void enter() { messages.clear(); }
+
+            @Override public void exit() { messages.clear(); }
+            // Store a copy: the Looper recycles msg once dispatch returns, wiping its fields.
+            @Override public boolean processMessage(Message msg) {
+                messages.add(Message.obtain(msg));
+                return true;
+            }
+        }
+        // Installs the logging state as the initial state and starts the machine.
+        public TestStateMachine() {
+            super("UpstreamNetworkMonitor.TestStateMachine");
+            addState(mLoggingState);
+            setInitialState(mLoggingState);
+            super.start();
+        }
+    }
 }