blob: 2564daeaa1c0f7f0d0e3523ed81641d0da0f1d4d [file] [log] [blame]
/*
* Copyright (C) 2016 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.net;
import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES;
import static com.android.internal.util.Preconditions.checkArgument;
import android.app.usage.NetworkStatsManager;
import android.net.DataUsageRequest;
import android.net.NetworkStats;
import android.net.NetworkStatsHistory;
import android.net.NetworkTemplate;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.IBinder;
import android.os.Looper;
import android.os.Message;
import android.os.Messenger;
import android.os.Process;
import android.os.RemoteException;
import android.util.ArrayMap;
import android.util.Slog;
import android.util.SparseArray;
import com.android.internal.annotations.VisibleForTesting;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Manages observers of {@link NetworkStats}. Allows observers to be notified when
* data usage has been reported in {@link NetworkStatsService}. An observer can set
* a threshold of how much data it cares about to be notified.
*/
class NetworkStatsObservers {
    private static final String TAG = "NetworkStatsObservers";
    private static final boolean LOGV = false;

    private static final int MSG_REGISTER = 1;
    private static final int MSG_UNREGISTER = 2;
    private static final int MSG_UPDATE_STATS = 3;

    // All access to this map must be done from the handler thread.
    // Indexed by DataUsageRequest#requestId.
    private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();

    // Sequence number of DataUsageRequests
    private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();

    // Lazily instantiated on first use by getHandler(), i.e. by the first call to
    // register(), unregister() or updateStats().
    private volatile Handler mHandler;

    /**
     * Normalizes {@code inputRequest} (assigns a fresh request id and raises the threshold to
     * at least {@code MIN_THRESHOLD_BYTES}), wraps it together with the caller's messenger,
     * binder, uid and access level, and queues the wrapper for registration on the handler
     * thread.
     *
     * <p>It will register the observer asynchronously, so it is safe to call from any thread.
     *
     * @param inputRequest the request as supplied by the caller
     * @param messenger where callback messages for this request are sent
     * @param binder the caller's binder, used to detect death of the calling process
     * @param callingUid uid of the caller; only this uid (or the system) may unregister
     * @param accessLevel the caller's access level, selects which usage is visible to it
     * @return the normalized request; this is what should be returned to the calling app
     */
    public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
            IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
        DataUsageRequest request = buildRequest(inputRequest);
        RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
                accessLevel);

        if (LOGV) Slog.v(TAG, "Registering observer for " + request);
        Handler handler = getHandler();
        handler.sendMessage(handler.obtainMessage(MSG_REGISTER, requestInfo));
        return request;
    }

    /**
     * Unregisters a data usage observer. The removal only takes effect if {@code callingUid}
     * is the uid that registered the request, or {@link Process#SYSTEM_UID}.
     *
     * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
     */
    public void unregister(DataUsageRequest request, int callingUid) {
        Handler handler = getHandler();
        handler.sendMessage(handler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
                request));
    }

    /**
     * Updates data usage statistics of registered observers and notifies those whose
     * threshold has been exceeded.
     *
     * <p>It will update stats asynchronously, so it is safe to call from any thread.
     */
    public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
            ArrayMap<String, NetworkIdentitySet> activeIfaces,
            ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
            long currentTime) {
        StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
                activeUidIfaces, currentTime);
        Handler handler = getHandler();
        handler.sendMessage(handler.obtainMessage(MSG_UPDATE_STATS, statsContext));
    }

    /**
     * Returns the handler that serializes all access to {@link #mDataUsageRequests}, creating
     * it on first use. {@link #mHandler} is volatile, so the unsynchronized fast path is safe.
     */
    private Handler getHandler() {
        if (mHandler == null) {
            synchronized (this) {
                if (mHandler == null) {
                    if (LOGV) Slog.v(TAG, "Creating handler");
                    mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
                }
            }
        }
        return mHandler;
    }

    /**
     * Provides the looper for the observer handler. This implementation starts a dedicated
     * {@link HandlerThread}. It is called from {@link #getHandler()} with this object's monitor
     * held. It is visible for testing, presumably so that tests can supply their own looper.
     */
    @VisibleForTesting
    protected Looper getHandlerLooperLocked() {
        HandlerThread handlerThread = new HandlerThread(TAG);
        handlerThread.start();
        return handlerThread.getLooper();
    }

    // Dispatches the MSG_* messages posted by register/unregister/updateStats.
    private final Handler.Callback mHandlerCallback = new Handler.Callback() {
        @Override
        public boolean handleMessage(Message msg) {
            switch (msg.what) {
                case MSG_REGISTER: {
                    handleRegister((RequestInfo) msg.obj);
                    return true;
                }
                case MSG_UNREGISTER: {
                    handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
                    return true;
                }
                case MSG_UPDATE_STATS: {
                    handleUpdateStats((StatsContext) msg.obj);
                    return true;
                }
                default: {
                    return false;
                }
            }
        }
    };

    /**
     * Adds a {@link RequestInfo} as an observer.
     * Should only be called from the handler thread otherwise there will be a race condition
     * on mDataUsageRequests.
     */
    private void handleRegister(RequestInfo requestInfo) {
        mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
    }

    /**
     * Removes a {@link DataUsageRequest} if the calling uid is authorized, unlinks its death
     * recipient and sends it a {@code CALLBACK_RELEASED} message.
     * Should only be called from the handler thread otherwise there will be a race condition
     * on mDataUsageRequests.
     */
    private void handleUnregister(DataUsageRequest request, int callingUid) {
        RequestInfo requestInfo = mDataUsageRequests.get(request.requestId);
        if (requestInfo == null) {
            if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
            return;
        }
        if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
            Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
            return;
        }

        if (LOGV) Slog.v(TAG, "Unregistering " + request);
        mDataUsageRequests.remove(request.requestId);
        requestInfo.unlinkDeathRecipient();
        requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
    }

    /**
     * Feeds a stats sample to every registered observer. Runs on the handler thread.
     */
    private void handleUpdateStats(StatsContext statsContext) {
        for (int i = 0; i < mDataUsageRequests.size(); i++) {
            RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
            requestInfo.updateStats(statsContext);
        }
    }

    /**
     * Returns a copy of {@code request} with a freshly assigned request id and a threshold of
     * at least {@code MIN_THRESHOLD_BYTES}.
     */
    private DataUsageRequest buildRequest(DataUsageRequest request) {
        // Cap the minimum threshold to a safe default to avoid too many callbacks
        long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
        // Math.max never returns less than the requested value, so the threshold was
        // overridden exactly when the result is larger than what the caller asked for.
        if (thresholdInBytes > request.thresholdInBytes) {
            Slog.w(TAG, "Threshold was too low for " + request
                    + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
        }
        return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
                request.template, thresholdInBytes);
    }

    /**
     * Chooses the observer flavor for the caller's access level: per-uid monitoring for
     * levels up to {@code USER}, whole-network monitoring for {@code DEVICESUMMARY} and above.
     */
    private RequestInfo buildRequestInfo(DataUsageRequest request,
            Messenger messenger, IBinder binder, int callingUid,
            @NetworkStatsAccess.Level int accessLevel) {
        if (accessLevel <= NetworkStatsAccess.Level.USER) {
            return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
                    accessLevel);
        } else {
            // Safety check in case a new access level is added and we forgot to update this
            checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
            return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
                    accessLevel);
        }
    }

    /**
     * Tracks information relevant to a data usage observer.
     * It will notice when the calling process dies so we can self-expire.
     */
    private abstract static class RequestInfo implements IBinder.DeathRecipient {
        private final NetworkStatsObservers mStatsObserver;
        protected final DataUsageRequest mRequest;
        private final Messenger mMessenger;
        private final IBinder mBinder;
        protected final int mCallingUid;
        protected final @NetworkStatsAccess.Level int mAccessLevel;
        // Created lazily on the first stats update and recreated after each
        // CALLBACK_LIMIT_REACHED; only touched from the handler thread.
        protected NetworkStatsRecorder mRecorder;
        protected NetworkStatsCollection mCollection;

        RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
                Messenger messenger, IBinder binder, int callingUid,
                @NetworkStatsAccess.Level int accessLevel) {
            mStatsObserver = statsObserver;
            mRequest = request;
            mMessenger = messenger;
            mBinder = binder;
            mCallingUid = callingUid;
            mAccessLevel = accessLevel;

            try {
                mBinder.linkToDeath(this, 0);
            } catch (RemoteException e) {
                // NOTE(review): this runs inside register() *before* MSG_REGISTER is posted,
                // so the MSG_UNREGISTER queued by binderDied() is handled first and finds
                // nothing; the subsequent MSG_REGISTER then keeps a dead request registered.
                // Consider linking to death on the handler thread instead.
                binderDied();
            }
        }

        /**
         * Queues an unregistration (as the system uid, so the ownership check passes) and
         * sends {@code CALLBACK_RELEASED}. Note that handleUnregister also sends
         * {@code CALLBACK_RELEASED}, so the release may be sent twice; failures to deliver to a
         * dead process are caught and logged in callCallback.
         */
        @Override
        public void binderDied() {
            if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
                    + mRequest + ", " + mBinder + ")");
            mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
            callCallback(NetworkStatsManager.CALLBACK_RELEASED);
        }

        @Override
        public String toString() {
            return "RequestInfo from uid:" + mCallingUid
                    + " for " + mRequest + " accessLevel:" + mAccessLevel;
        }

        private void unlinkDeathRecipient() {
            if (mBinder != null) {
                mBinder.unlinkToDeath(this, 0);
            }
        }

        /**
         * Update stats given the samples and interface to identity mappings. The first call
         * only establishes a baseline; later calls fire {@code CALLBACK_LIMIT_REACHED} and
         * restart counting once {@link #checkStats()} reports the threshold was exceeded.
         */
        private void updateStats(StatsContext statsContext) {
            if (mRecorder == null) {
                // First run; establish baseline stats
                resetRecorder();
                recordSample(statsContext);
                return;
            }

            recordSample(statsContext);

            if (checkStats()) {
                resetRecorder();
                callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
            }
        }

        /**
         * Sends {@code callbackType} to the observer's messenger with the request attached
         * under {@link DataUsageRequest#PARCELABLE_KEY}. A {@link RemoteException} is logged
         * and otherwise ignored.
         */
        private void callCallback(int callbackType) {
            Bundle bundle = new Bundle();
            bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
            Message msg = Message.obtain();
            msg.what = callbackType;
            msg.setData(bundle);
            try {
                if (LOGV) {
                    Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
                            + " for " + mRequest);
                }
                mMessenger.send(msg);
            } catch (RemoteException e) {
                // May occur naturally in the race of binder death.
                Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
            }
        }

        /** Starts counting from zero with a fresh recorder and its since-boot collection. */
        private void resetRecorder() {
            mRecorder = new NetworkStatsRecorder();
            mCollection = mRecorder.getSinceBoot();
        }

        /** Returns true when the usage accumulated since the last reset exceeds the threshold. */
        protected abstract boolean checkStats();

        /** Feeds the relevant snapshot from {@code statsContext} into {@link #mRecorder}. */
        protected abstract void recordSample(StatsContext statsContext);

        private static String callbackTypeToName(int callbackType) {
            switch (callbackType) {
                case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
                    return "LIMIT_REACHED";
                case NetworkStatsManager.CALLBACK_RELEASED:
                    return "RELEASED";
                default:
                    return "UNKNOWN";
            }
        }
    }

    /** Observer that compares whole-network usage for the template against the threshold. */
    private static class NetworkUsageRequestInfo extends RequestInfo {
        NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
                Messenger messenger, IBinder binder, int callingUid,
                @NetworkStatsAccess.Level int accessLevel) {
            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
        }

        @Override
        protected boolean checkStats() {
            long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
            if (LOGV) {
                Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
                        + mRequest.template);
            }
            return bytesSoFar > mRequest.thresholdInBytes;
        }

        @Override
        protected void recordSample(StatsContext statsContext) {
            // Recorder does not need to be locked in this context since only the handler
            // thread will update it. Network-level observers record the xt snapshot with its
            // interface-to-identity mapping.
            mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
                    statsContext.mCurrentTime);
        }

        /**
         * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
         * over all buckets, which in this case should be only one since we built it big enough
         * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
         */
        private long getTotalBytesForNetwork(NetworkTemplate template) {
            NetworkStats stats = mCollection.getSummary(template,
                    Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
                    mAccessLevel, mCallingUid);
            return stats.getTotalBytes();
        }
    }

    /**
     * Observer that fires when any single uid visible to the caller exceeds the threshold
     * for the template.
     */
    private static class UserUsageRequestInfo extends RequestInfo {
        UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
                Messenger messenger, IBinder binder, int callingUid,
                @NetworkStatsAccess.Level int accessLevel) {
            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
        }

        @Override
        protected boolean checkStats() {
            int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);

            for (int uid : uidsToMonitor) {
                long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uid);
                if (bytesSoFar > mRequest.thresholdInBytes) {
                    return true;
                }
            }
            return false;
        }

        @Override
        protected void recordSample(StatsContext statsContext) {
            // Recorder does not need to be locked in this context since only the handler
            // thread will update it. Per-uid observers record the uid snapshot with its
            // interface-to-identity mapping.
            mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
                    statsContext.mCurrentTime);
        }

        /**
         * Reads all stats matching the given template and uid. The history will likely only
         * contain one bucket per ident since we build it big enough that it will outlive the
         * caller lifetime. Returns 0 if the caller is no longer allowed to see {@code uid}.
         */
        private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
            try {
                NetworkStatsHistory history = mCollection.getHistory(template, null, uid,
                        NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
                        NetworkStatsHistory.FIELD_ALL,
                        Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
                        mAccessLevel, mCallingUid);
                return history.getTotalBytes();
            } catch (SecurityException e) {
                if (LOGV) {
                    Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
                            + uid);
                }
                return 0;
            }
        }
    }

    /** Immutable bundle of one stats sample, handed from updateStats() to the handler. */
    private static class StatsContext {
        final NetworkStats mXtSnapshot;
        final NetworkStats mUidSnapshot;
        final ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
        final ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
        final long mCurrentTime;

        StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
                ArrayMap<String, NetworkIdentitySet> activeIfaces,
                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
                long currentTime) {
            mXtSnapshot = xtSnapshot;
            mUidSnapshot = uidSnapshot;
            mActiveIfaces = activeIfaces;
            mActiveUidIfaces = activeUidIfaces;
            mCurrentTime = currentTime;
        }
    }
}