// blob: c72efb0344d1a8045e39ce9f55ca21c4fe648ad7 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.tethering;
import static android.net.ConnectivityManager.TYPE_MOBILE_DUN;
import static android.net.ConnectivityManager.TYPE_MOBILE_HIPRI;
import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import android.content.Context;
import android.os.Handler;
import android.os.Message;
import android.net.ConnectivityManager;
import android.net.ConnectivityManager.NetworkCallback;
import android.net.IConnectivityManager;
import android.net.Network;
import android.net.NetworkCapabilities;
import android.net.NetworkRequest;
import android.support.test.filters.SmallTest;
import android.support.test.runner.AndroidJUnit4;
import com.android.internal.util.State;
import com.android.internal.util.StateMachine;
import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
 * Unit tests for {@link UpstreamNetworkMonitor}.
 *
 * <p>The monitor under test is wired to a {@link TestConnectivityManager}, a fake that records
 * every callback registration in plain collections, and to a {@link TestStateMachine} that
 * collects the messages it is asked to process.
 */
@RunWith(AndroidJUnit4.class)
@SmallTest
public class UpstreamNetworkMonitorTest {
    /** Event code passed to the {@link UpstreamNetworkMonitor} constructor. */
    private static final int EVENT_UNM_UPDATE = 1;

    @Mock private Context mContext;
    @Mock private IConnectivityManager mCS;

    private TestStateMachine mSM;
    private TestConnectivityManager mCM;
    private UpstreamNetworkMonitor mUNM;

    /** Builds fresh mocks, a spied fake ConnectivityManager and a new monitor for each test. */
    @Before public void setUp() throws Exception {
        MockitoAnnotations.initMocks(this);
        reset(mContext);
        reset(mCS);

        // Spy on the fake so tests can inspect its bookkeeping and also verify() calls made on it.
        mCM = spy(new TestConnectivityManager(mContext, mCS));
        mSM = new TestStateMachine();
        mUNM = new UpstreamNetworkMonitor(mSM, EVENT_UNM_UPDATE, (ConnectivityManager) mCM);
    }

    /** Quits the test state machine, if one was created, and drops the reference to it. */
    @After public void tearDown() throws Exception {
        if (mSM != null) {
            mSM.quit();
            mSM = null;
        }
    }

    /** Before start(), toggling the DUN requirement must not register any callback. */
    @Test
    public void testDoesNothingBeforeStarted() {
        assertTrue(mCM.hasNoCallbacks());
        assertFalse(mUNM.mobileNetworkRequested());

        mUNM.updateMobileRequiresDun(true);
        assertTrue(mCM.hasNoCallbacks());
        mUNM.updateMobileRequiresDun(false);
        assertTrue(mCM.hasNoCallbacks());
    }

    /** start() registers exactly one default-network callback; stop() releases everything. */
    @Test
    public void testDefaultNetworkIsTracked() throws Exception {
        assertEquals(0, mCM.trackingDefault.size());

        mUNM.start();
        assertEquals(1, mCM.trackingDefault.size());

        mUNM.stop();
        assertTrue(mCM.hasNoCallbacks());
    }

    /** start() registers a listen callback with empty capabilities; stop() releases everything. */
    @Test
    public void testListensForAllNetworks() throws Exception {
        assertTrue(mCM.listening.isEmpty());

        mUNM.start();
        assertFalse(mCM.listening.isEmpty());
        assertTrue(mCM.isListeningForAll());

        mUNM.stop();
        assertTrue(mCM.hasNoCallbacks());
    }

    /**
     * With DUN not required, a mobile request is only filed once
     * registerMobileNetworkRequest() is called, and it is a HIPRI (non-DUN) request.
     */
    @Test
    public void testRequestsMobileNetwork() throws Exception {
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.start();
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.updateMobileRequiresDun(false);
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.registerMobileNetworkRequest();
        assertTrue(mUNM.mobileNetworkRequested());
        assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
        assertFalse(mCM.isDunRequested());

        mUNM.stop();
        assertFalse(mUNM.mobileNetworkRequested());
        assertTrue(mCM.hasNoCallbacks());
    }

    /**
     * Repeating registerMobileNetworkRequest() / updateMobileRequiresDun(true) once a DUN
     * request is outstanding must not cause any further call into the ConnectivityManager.
     */
    @Test
    public void testDuplicateMobileRequestsIgnored() throws Exception {
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.start();
        verify(mCM, times(1)).registerNetworkCallback(
                any(NetworkRequest.class), any(NetworkCallback.class), any(Handler.class));
        verify(mCM, times(1)).registerDefaultNetworkCallback(
                any(NetworkCallback.class), any(Handler.class));
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.updateMobileRequiresDun(true);
        mUNM.registerMobileNetworkRequest();
        verify(mCM, times(1)).requestNetwork(
                any(NetworkRequest.class), any(NetworkCallback.class), anyInt(), anyInt(),
                any(Handler.class));
        assertTrue(mUNM.mobileNetworkRequested());
        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
        assertTrue(mCM.isDunRequested());

        // Try a few things that must not result in any state change.
        mUNM.registerMobileNetworkRequest();
        mUNM.updateMobileRequiresDun(true);
        mUNM.registerMobileNetworkRequest();
        assertTrue(mUNM.mobileNetworkRequested());
        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
        assertTrue(mCM.isDunRequested());

        mUNM.stop();
        // One unregistration for each of the three callbacks whose registration was verified
        // above (listen, default, mobile request).
        verify(mCM, times(3)).unregisterNetworkCallback(any(NetworkCallback.class));

        // Any extra call caused by the duplicate operations would be reported here.
        verifyNoMoreInteractions(mCM);
    }

    /**
     * With DUN required, a mobile request is only filed once
     * registerMobileNetworkRequest() is called, and it is a DUN request.
     */
    @Test
    public void testRequestsDunNetwork() throws Exception {
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.start();
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.updateMobileRequiresDun(true);
        assertFalse(mUNM.mobileNetworkRequested());
        assertEquals(0, mCM.requested.size());

        mUNM.registerMobileNetworkRequest();
        assertTrue(mUNM.mobileNetworkRequested());
        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
        assertTrue(mCM.isDunRequested());

        mUNM.stop();
        assertFalse(mUNM.mobileNetworkRequested());
        assertTrue(mCM.hasNoCallbacks());
    }

    /**
     * Flipping the DUN requirement while a mobile request is outstanding must swap the request
     * for one of the other kind, always leaving exactly one request registered.
     */
    @Test
    public void testUpdateMobileRequiresDun() throws Exception {
        mUNM.start();

        // Test going from no-DUN to DUN correctly re-registers callbacks.
        mUNM.updateMobileRequiresDun(false);
        mUNM.registerMobileNetworkRequest();
        assertTrue(mUNM.mobileNetworkRequested());
        assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
        assertFalse(mCM.isDunRequested());
        mUNM.updateMobileRequiresDun(true);
        assertTrue(mUNM.mobileNetworkRequested());
        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
        assertTrue(mCM.isDunRequested());

        // Test going from DUN to no-DUN correctly re-registers callbacks.
        mUNM.updateMobileRequiresDun(false);
        assertTrue(mUNM.mobileNetworkRequested());
        assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
        assertFalse(mCM.isDunRequested());

        mUNM.stop();
        assertFalse(mUNM.mobileNetworkRequested());
        // As in the other mobile-request tests: nothing may stay registered after stop(),
        // including callbacks from the re-registrations above.
        assertTrue(mCM.hasNoCallbacks());
    }

    /**
     * Asserts that exactly one network request is outstanding on the fake and that the single
     * legacy type recorded for it equals {@code upstreamType}.
     *
     * @param upstreamType the expected legacy network type, e.g. {@code TYPE_MOBILE_DUN}
     */
    private void assertUpstreamTypeRequested(int upstreamType) {
        assertEquals(1, mCM.requested.size());
        assertEquals(1, mCM.legacyTypeMap.size());
        assertEquals(Integer.valueOf(upstreamType),
                mCM.legacyTypeMap.values().iterator().next());
    }

    /**
     * Fake {@link ConnectivityManager} that records callback registrations in collections.
     *
     * <p>Only the Handler-taking registration overloads are accepted; the overloads without a
     * Handler fail the test. Registering the same callback twice also fails the test.
     */
    public static class TestConnectivityManager extends ConnectivityManager {
        /** Every currently registered callback, mapped to the Handler it was registered with. */
        public Map<NetworkCallback, Handler> allCallbacks = new HashMap<>();
        /** Callbacks registered through registerDefaultNetworkCallback(). */
        public Set<NetworkCallback> trackingDefault = new HashSet<>();
        /** Callbacks registered through registerNetworkCallback(), with their requests. */
        public Map<NetworkCallback, NetworkRequest> listening = new HashMap<>();
        /** Callbacks registered through requestNetwork(), with their requests. */
        public Map<NetworkCallback, NetworkRequest> requested = new HashMap<>();
        /** Legacy type of each request filed with a legacy type other than TYPE_NONE. */
        public Map<NetworkCallback, Integer> legacyTypeMap = new HashMap<>();

        public TestConnectivityManager(Context ctx, IConnectivityManager svc) {
            super(ctx, svc);
        }

        /** Returns true when none of the bookkeeping collections holds any entry. */
        boolean hasNoCallbacks() {
            return allCallbacks.isEmpty() &&
                    trackingDefault.isEmpty() &&
                    listening.isEmpty() &&
                    requested.isEmpty() &&
                    legacyTypeMap.isEmpty();
        }

        /**
         * Returns true when at least one listen request has requestable capabilities equal to
         * those of a fully cleared {@link NetworkCapabilities}.
         */
        boolean isListeningForAll() {
            final NetworkCapabilities empty = new NetworkCapabilities();
            empty.clearAll();

            for (NetworkRequest req : listening.values()) {
                if (req.networkCapabilities.equalRequestableCapabilities(empty)) {
                    return true;
                }
            }
            return false;
        }

        /** Returns true when at least one outstanding request asks for NET_CAPABILITY_DUN. */
        boolean isDunRequested() {
            for (NetworkRequest req : requested.values()) {
                if (req.networkCapabilities.hasCapability(NET_CAPABILITY_DUN)) {
                    return true;
                }
            }
            return false;
        }

        /** Records a network request that carries no legacy type. */
        @Override
        public void requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h) {
            assertFalse(allCallbacks.containsKey(cb));
            allCallbacks.put(cb, h);
            assertFalse(requested.containsKey(cb));
            requested.put(cb, req);
        }

        /** Handler-less overload: fails the test if it is ever called. */
        @Override
        public void requestNetwork(NetworkRequest req, NetworkCallback cb) {
            fail("Should never be called.");
        }

        /**
         * Records a network request and, unless {@code legacyType} is TYPE_NONE, its legacy
         * type. {@code timeoutMs} is ignored by this fake.
         */
        @Override
        public void requestNetwork(NetworkRequest req, NetworkCallback cb,
                int timeoutMs, int legacyType, Handler h) {
            assertFalse(allCallbacks.containsKey(cb));
            allCallbacks.put(cb, h);
            assertFalse(requested.containsKey(cb));
            requested.put(cb, req);
            assertFalse(legacyTypeMap.containsKey(cb));
            if (legacyType != ConnectivityManager.TYPE_NONE) {
                legacyTypeMap.put(cb, legacyType);
            }
        }

        /** Records a listen-only callback together with its request. */
        @Override
        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h) {
            assertFalse(allCallbacks.containsKey(cb));
            allCallbacks.put(cb, h);
            assertFalse(listening.containsKey(cb));
            listening.put(cb, req);
        }

        /** Handler-less overload: fails the test if it is ever called. */
        @Override
        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb) {
            fail("Should never be called.");
        }

        /** Records a default-network tracking callback. */
        @Override
        public void registerDefaultNetworkCallback(NetworkCallback cb, Handler h) {
            assertFalse(allCallbacks.containsKey(cb));
            allCallbacks.put(cb, h);
            assertFalse(trackingDefault.contains(cb));
            trackingDefault.add(cb);
        }

        /** Handler-less overload: fails the test if it is ever called. */
        @Override
        public void registerDefaultNetworkCallback(NetworkCallback cb) {
            fail("Should never be called.");
        }

        /**
         * Removes {@code cb} from whichever collection holds it and fails the test when the
         * callback was never registered.
         */
        @Override
        public void unregisterNetworkCallback(NetworkCallback cb) {
            if (trackingDefault.contains(cb)) {
                trackingDefault.remove(cb);
            } else if (listening.containsKey(cb)) {
                listening.remove(cb);
            } else if (requested.containsKey(cb)) {
                requested.remove(cb);
                legacyTypeMap.remove(cb);
            } else {
                fail("Unexpected callback removed");
            }
            allCallbacks.remove(cb);

            // The callback must now be gone from every collection that hasNoCallbacks() checks.
            assertFalse(allCallbacks.containsKey(cb));
            assertFalse(trackingDefault.contains(cb));
            assertFalse(listening.containsKey(cb));
            assertFalse(requested.containsKey(cb));
            assertFalse(legacyTypeMap.containsKey(cb));
        }
    }

    /**
     * {@link StateMachine} with a single state that appends every message it processes to
     * {@link #messages}. The machine is started from its constructor.
     */
    public static class TestStateMachine extends StateMachine {
        /** Messages handled so far; cleared whenever the logging state is entered or exited. */
        public final ArrayList<Message> messages = new ArrayList<>();
        private final State mLoggingState = new LoggingState();

        /** The only state: records each message and reports it as handled. */
        class LoggingState extends State {
            @Override public void enter() { messages.clear(); }

            @Override public void exit() { messages.clear(); }

            @Override public boolean processMessage(Message msg) {
                messages.add(msg);
                return true;
            }
        }

        public TestStateMachine() {
            super("UpstreamNetworkMonitor.TestStateMachine");
            addState(mLoggingState);
            setInitialState(mLoggingState);
            super.start();
        }
    }
}