blob: b51b2771db168eab44b8d7d5180d92a3376ef7f2 [file] [log] [blame]
/*
* Copyright (C) 2016 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity;
import static android.content.pm.UserInfo.FLAG_ADMIN;
import static android.content.pm.UserInfo.FLAG_MANAGED_PROFILE;
import static android.content.pm.UserInfo.FLAG_PRIMARY;
import static android.content.pm.UserInfo.FLAG_RESTRICTED;
import static org.mockito.AdditionalMatchers.*;
import static org.mockito.Mockito.*;
import android.annotation.UserIdInt;
import android.app.AppOpsManager;
import android.app.NotificationManager;
import android.content.Context;
import android.content.pm.PackageManager;
import android.content.pm.UserInfo;
import android.net.NetworkInfo.DetailedState;
import android.net.UidRange;
import android.os.INetworkManagementService;
import android.os.Looper;
import android.os.UserHandle;
import android.os.UserManager;
import android.test.AndroidTestCase;
import android.test.suitebuilder.annotation.SmallTest;
import android.util.ArrayMap;
import android.util.ArraySet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/**
* Tests for {@link Vpn}.
*
* Build, install and run with:
* runtest --path src/com/android/server/connectivity/VpnTest.java
*/
public class VpnTest extends AndroidTestCase {
// Tag passed as the second argument to Vpn#updateState in the notification test.
private static final String TAG = "VpnTest";
// Mock users
// Admin + primary user (id 27); parent of restrictedProfileA and profile group of managedProfileA.
static final UserInfo primaryUser = new UserInfo(27, "Primary", FLAG_ADMIN | FLAG_PRIMARY);
// A second admin user (id 15); parent of restrictedProfileB.
static final UserInfo secondaryUser = new UserInfo(15, "Secondary", FLAG_ADMIN);
// Restricted profiles; their parents are assigned in the static initializer below.
static final UserInfo restrictedProfileA = new UserInfo(40, "RestrictedA", FLAG_RESTRICTED);
static final UserInfo restrictedProfileB = new UserInfo(42, "RestrictedB", FLAG_RESTRICTED);
// Managed profile; its profile group is assigned in the static initializer below.
static final UserInfo managedProfileA = new UserInfo(45, "ManagedA", FLAG_MANAGED_PROFILE);
static {
// Link each profile to its owning user: A -> primary, B -> secondary, managed -> primary.
restrictedProfileA.restrictedProfileParentId = primaryUser.id;
restrictedProfileB.restrictedProfileParentId = secondaryUser.id;
managedProfileA.profileGroupId = primaryUser.id;
}
/**
 * Fake package names and their app UIDs; PKGS[i] is paired with PKG_UIDS[i].
 * The tests rely on two properties of this data:
 * - The UIDs are listed in increasing order.
 * - One pair of packages (UIDs 77 and 78) has consecutive UIDs.
 */
static final String[] PKGS = {"com.example", "org.example", "net.example", "web.vpn"};
static final int[] PKG_UIDS = {66, 77, 78, 400};
// Package name -> app UID lookup installed into the mocked PackageManager by setUp().
static final Map<String, Integer> mPackages = new ArrayMap<>();
static {
    int index = 0;
    for (String packageName : PKGS) {
        mPackages.put(packageName, PKG_UIDS[index]);
        index++;
    }
}
// Mockito mocks; instantiated by MockitoAnnotations.initMocks(this) in setUp().
// mContext is stubbed in setUp() to hand out the other mocked services below.
@Mock private Context mContext;
@Mock private UserManager mUserManager;
@Mock private PackageManager mPackageManager;
// Passed to the Vpn constructor in spyVpn(); tests verify setAllowOnlyVpnForUids calls on it.
@Mock private INetworkManagementService mNetService;
@Mock private AppOpsManager mAppOps;
@Mock private NotificationManager mNotificationManager;
/**
 * Creates the annotated mocks and wires the mocked {@link Context} so that it returns the
 * mocked package manager, package name and system services used by the tests.
 */
@Override
public void setUp() throws Exception {
    MockitoAnnotations.initMocks(this);

    // Package-related lookups.
    final String vpnPackageName = Vpn.class.getPackage().getName();
    when(mContext.getPackageName()).thenReturn(vpnPackageName);
    when(mContext.getPackageManager()).thenReturn(mPackageManager);
    setMockedPackages(mPackages);

    // System services handed out by the mocked context.
    when(mContext.getSystemService(Context.USER_SERVICE)).thenReturn(mUserManager);
    when(mContext.getSystemService(Context.APP_OPS_SERVICE)).thenReturn(mAppOps);
    when(mContext.getSystemService(Context.NOTIFICATION_SERVICE))
            .thenReturn(mNotificationManager);

    doNothing().when(mNetService).registerObserver(any());
}
/**
 * The ranges created for the primary user must cover the primary user and the restricted
 * profile whose parent is the primary user, and nothing else.
 */
@SmallTest
public void testRestrictedProfilesAreAddedToVpn() {
    setMockedUsers(primaryUser, secondaryUser, restrictedProfileA, restrictedProfileB);
    final Vpn vpn = spyVpn(primaryUser.id);

    final Set<UidRange> actual =
            vpn.createUserAndRestrictedProfilesRanges(primaryUser.id, null, null);

    final Set<UidRange> expected = new ArraySet<>();
    expected.add(UidRange.createForUser(primaryUser.id));
    expected.add(UidRange.createForUser(restrictedProfileA.id));
    assertEquals(expected, actual);
}
/**
 * A managed profile in the primary user's profile group must not be included in the ranges
 * created for the primary user.
 */
@SmallTest
public void testManagedProfilesAreNotAddedToVpn() {
    setMockedUsers(primaryUser, managedProfileA);
    final Vpn vpn = spyVpn(primaryUser.id);

    final Set<UidRange> actual =
            vpn.createUserAndRestrictedProfilesRanges(primaryUser.id, null, null);

    final Set<UidRange> expected = new ArraySet<>();
    expected.add(UidRange.createForUser(primaryUser.id));
    assertEquals(expected, actual);
}
/**
 * addUserToRanges must add only the requested user's range, even when that user has
 * restricted and managed profiles.
 */
@SmallTest
public void testAddUserToVpnOnlyAddsOneUser() {
    setMockedUsers(primaryUser, restrictedProfileA, managedProfileA);
    final Vpn vpn = spyVpn(primaryUser.id);

    final Set<UidRange> actual = new ArraySet<>();
    vpn.addUserToRanges(actual, primaryUser.id, null, null);

    final Set<UidRange> expected = new ArraySet<>();
    expected.add(UidRange.createForUser(primaryUser.id));
    assertEquals(expected, actual);
}
/**
 * Checks the UID ranges produced when the first three fake packages are passed as the
 * allowed list, and then as the disallowed list, for the primary user.
 */
@SmallTest
public void testUidWhiteAndBlacklist() throws Exception {
    final Vpn vpn = spyVpn(primaryUser.id);
    final UidRange user = UidRange.createForUser(primaryUser.id);
    final String[] packages = Arrays.copyOf(PKGS, 3);

    // Absolute UIDs of the three packages inside the primary user's range.
    final int uid0 = user.start + PKG_UIDS[0];
    final int uid1 = user.start + PKG_UIDS[1];
    final int uid2 = user.start + PKG_UIDS[2];

    // Whitelist: the two consecutive UIDs are expected to merge into a single range.
    final Set<UidRange> expectedAllow = new ArraySet<>();
    expectedAllow.add(new UidRange(uid0, uid0));
    expectedAllow.add(new UidRange(uid1, uid2));
    final Set<UidRange> allow = vpn.createUserAndRestrictedProfilesRanges(
            primaryUser.id, Arrays.asList(packages), null);
    assertEquals(expectedAllow, allow);

    // Blacklist: everything in the user's range except the three package UIDs.
    final Set<UidRange> expectedDisallow = new ArraySet<>();
    expectedDisallow.add(new UidRange(user.start, uid0 - 1));
    expectedDisallow.add(new UidRange(uid0 + 1, uid1 - 1));
    // The range between PKG_UIDS[1] and PKG_UIDS[2] is empty and should be excluded.
    expectedDisallow.add(new UidRange(uid2 + 1, user.stop));
    final Set<UidRange> disallow = vpn.createUserAndRestrictedProfilesRanges(
            primaryUser.id, null, Arrays.asList(packages));
    assertEquals(expectedDisallow, disallow);
}
/**
 * Walks the always-on package through "unset", "always-on without lockdown", "always-on with
 * lockdown" and "lockdown for a different app", checking the rules pushed to the network
 * management service and which UIDs the Vpn reports as blocked.
 */
@SmallTest
public void testLockdownChangingPackage() throws Exception {
    final Vpn vpn = spyVpn(primaryUser.id);
    final UidRange user = UidRange.createForUser(primaryUser.id);

    // Absolute UIDs of the fake packages inside the primary user's range.
    final int[] uids = new int[PKG_UIDS.length];
    for (int i = 0; i < uids.length; i++) {
        uids[i] = user.start + PKG_UIDS[i];
    }

    // Whole user range minus PKGS[1], and minus PKGS[3], respectively.
    final UidRange[] allButPkg1 = {
        new UidRange(user.start, uids[1] - 1),
        new UidRange(uids[1] + 1, user.stop)
    };
    final UidRange[] allButPkg3 = {
        new UidRange(user.start, uids[3] - 1),
        new UidRange(uids[3] + 1, user.stop)
    };

    // Default state.
    assertUnblocked(vpn, uids[0], uids[1], uids[2], uids[3]);

    // Set always-on without lockdown.
    assertTrue(vpn.setAlwaysOnPackage(PKGS[1], false));
    assertUnblocked(vpn, uids[0], uids[1], uids[2], uids[3]);

    // Set always-on with lockdown.
    assertTrue(vpn.setAlwaysOnPackage(PKGS[1], true));
    verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(allButPkg1));
    assertBlocked(vpn, uids[0], uids[2], uids[3]);
    assertUnblocked(vpn, uids[1]);

    // Switch to another app: the old rule is removed and the new one is added.
    assertTrue(vpn.setAlwaysOnPackage(PKGS[3], true));
    verify(mNetService).setAllowOnlyVpnForUids(eq(false), aryEq(allButPkg1));
    verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(allButPkg3));
    assertBlocked(vpn, uids[0], uids[1], uids[2]);
    assertUnblocked(vpn, uids[3]);
}
/**
 * With lockdown enabled for the primary user, adding a restricted profile must push a
 * matching rule for the profile's UID range, and removing the profile must retract it.
 */
@SmallTest
public void testLockdownAddingAProfile() throws Exception {
    final Vpn vpn = spyVpn(primaryUser.id);
    setMockedUsers(primaryUser);

    // Work on a copy of the restricted profile because it gets flagged as partial below.
    final UserInfo tempProfile = new UserInfo(
            restrictedProfileA.id, restrictedProfileA.name, restrictedProfileA.flags);
    tempProfile.restrictedProfileParentId = primaryUser.id;

    final UidRange user = UidRange.createForUser(primaryUser.id);
    final UidRange profile = UidRange.createForUser(tempProfile.id);

    // Each range below is the full user/profile range with the hole left for PKGS[3].
    final UidRange[] userRules = {
        new UidRange(user.start, user.start + PKG_UIDS[3] - 1),
        new UidRange(user.start + PKG_UIDS[3] + 1, user.stop)
    };
    final UidRange[] profileRules = {
        new UidRange(profile.start, profile.start + PKG_UIDS[3] - 1),
        new UidRange(profile.start + PKG_UIDS[3] + 1, profile.stop)
    };

    // Set lockdown.
    assertTrue(vpn.setAlwaysOnPackage(PKGS[3], true));
    verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(userRules));

    // Verify restricted user isn't affected at first.
    assertUnblocked(vpn, profile.start + PKG_UIDS[0]);

    // Add the restricted user.
    setMockedUsers(primaryUser, tempProfile);
    vpn.onUserAdded(tempProfile.id);
    verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(profileRules));

    // Remove the restricted user.
    tempProfile.partial = true;
    vpn.onUserRemoved(tempProfile.id);
    verify(mNetService).setAllowOnlyVpnForUids(eq(false), aryEq(profileRules));
}
/**
 * Verifies, in order, the calls made to updateAlwaysOnNotificationInternal as the Vpn
 * changes connection state and as the always-on package is set and then cleared.
 */
@SmallTest
public void testNotificationShownForAlwaysOnApp() {
    final Vpn vpn = spyVpn(primaryUser.id);
    final InOrder sequence = inOrder(vpn);
    setMockedUsers(primaryUser);

    // A plain disconnect (no always-on package) must not show the notification.
    vpn.updateState(DetailedState.DISCONNECTED, TAG);
    sequence.verify(vpn).updateAlwaysOnNotificationInternal(false);

    // Once an always-on package is set while disconnected, the notification appears.
    vpn.setAlwaysOnPackage(PKGS[0], false);
    sequence.verify(vpn).updateAlwaysOnNotificationInternal(true);

    // Connecting hides the notification.
    vpn.updateState(DetailedState.CONNECTED, TAG);
    sequence.verify(vpn).updateAlwaysOnNotificationInternal(false);

    // Disconnecting again brings the notification back.
    vpn.updateState(DetailedState.DISCONNECTED, TAG);
    sequence.verify(vpn).updateAlwaysOnNotificationInternal(true);

    // Clearing the always-on package clears the notification.
    vpn.setAlwaysOnPackage(null, false);
    sequence.verify(vpn).updateAlwaysOnNotificationInternal(false);
}
/**
 * Builds a Mockito spy around a real {@link Vpn} for the given user, with
 * updateAlwaysOnNotificationInternal stubbed out.
 *
 * @param userId id of the user the Vpn is created for
 * @return the spied Vpn instance
 */
private Vpn spyVpn(@UserIdInt int userId) {
    final Vpn realVpn = new Vpn(Looper.myLooper(), mContext, mNetService, userId);
    final Vpn spiedVpn = spy(realVpn);

    // Keep the spy from reaching the NotificationManager or PendingIntent#getActivity.
    doNothing().when(spiedVpn).updateAlwaysOnNotificationInternal(anyBoolean());
    return spiedVpn;
}
/**
 * Asserts that {@code vpn.isBlockingUid} returns true for every given UID.
 *
 * @param vpn the Vpn under test
 * @param uids UIDs expected to be blocked
 */
private static void assertBlocked(Vpn vpn, int... uids) {
    for (int i = 0; i < uids.length; i++) {
        final int uid = uids[i];
        final boolean blocked = vpn.isBlockingUid(uid);
        assertTrue("Uid " + uid + " should be blocked", blocked);
    }
}
/**
 * Asserts that {@code vpn.isBlockingUid} returns false for every given UID.
 *
 * @param vpn the Vpn under test
 * @param uids UIDs expected not to be blocked
 */
private static void assertUnblocked(Vpn vpn, int... uids) {
    for (int i = 0; i < uids.length; i++) {
        final int uid = uids[i];
        final boolean blocked = vpn.isBlockingUid(uid);
        assertFalse("Uid " + uid + " should not be blocked", blocked);
    }
}
/**
 * Populate {@link #mUserManager} with a list of fake users.
 *
 * <p>Stubs {@code getUsers(boolean)}, {@code getUserInfo(int)} and
 * {@code canHaveRestrictedProfile(int)}. Calling this again replaces the previous stubbing.
 * The answers read the passed {@link UserInfo} objects at call time, so later changes to
 * their fields (for example {@code partial}) are reflected.
 *
 * @param users the complete set of users the mocked UserManager should know about
 */
private void setMockedUsers(UserInfo... users) {
    final Map<Integer, UserInfo> userMap = new ArrayMap<>();
    for (UserInfo user : users) {
        userMap.put(user.id, user);
    }

    // See UserManagerService#getUsers(boolean): when excludeDying is set, users that are
    // disabled or partial are left out of the result.
    doAnswer(invocation -> {
        final boolean excludeDying = (boolean) invocation.getArguments()[0];
        final ArrayList<UserInfo> result = new ArrayList<>(users.length);
        for (UserInfo ui : users) {
            if (!excludeDying || (ui.isEnabled() && !ui.partial)) {
                result.add(ui);
            }
        }
        return result;
    }).when(mUserManager).getUsers(anyBoolean());

    // Unknown ids yield null.
    doAnswer(invocation -> {
        final int id = (int) invocation.getArguments()[0];
        return userMap.get(id);
    }).when(mUserManager).getUserInfo(anyInt());

    // Only admin users can have restricted profiles. An id that is not in the map answers
    // false instead of failing with a NullPointerException inside the stub.
    doAnswer(invocation -> {
        final int id = (int) invocation.getArguments()[0];
        final UserInfo info = userMap.get(id);
        return info != null && (info.flags & UserInfo.FLAG_ADMIN) != 0;
    }).when(mUserManager).canHaveRestrictedProfile(anyInt());
}
/**
 * Populate {@link #mPackageManager} with a fake packageName-to-UID mapping.
 *
 * <p>{@code getPackageUidAsUser(name, userId)} answers with the mapped app UID combined with
 * the user id via {@link UserHandle#getUid}. A package name that is not in the map makes the
 * stub throw {@link PackageManager.NameNotFoundException} rather than failing with a
 * NullPointerException while unboxing the missing UID.
 *
 * @param packages package name to app UID mapping backing the stub
 * @throws IllegalStateException if installing the stub itself fails
 */
private void setMockedPackages(final Map<String, Integer> packages) {
    try {
        doAnswer(invocation -> {
            final String appName = (String) invocation.getArguments()[0];
            final int userId = (int) invocation.getArguments()[1];
            final Integer appUid = packages.get(appName);
            if (appUid == null) {
                throw new PackageManager.NameNotFoundException(appName);
            }
            return UserHandle.getUid(userId, appUid);
        }).when(mPackageManager).getPackageUidAsUser(anyString(), anyInt());
    } catch (Exception e) {
        // The stubbing call is not expected to throw. If it does, surface the failure here
        // instead of letting tests run against a half-configured mock.
        throw new IllegalStateException(
                "Failed to stub PackageManager#getPackageUidAsUser", e);
    }
}
}