Merge "Fix data stall false alarm caused by TCP signal" am: c062a7341c

Original change: https://android-review.googlesource.com/c/platform/packages/modules/NetworkStack/+/2264841

Change-Id: Iff48855469eae2ebd1ac359ed712f5d78519c6ba
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/src/com/android/networkstack/netlink/TcpSocketTracker.java b/src/com/android/networkstack/netlink/TcpSocketTracker.java
index 33d3e61..b9fbde4 100644
--- a/src/com/android/networkstack/netlink/TcpSocketTracker.java
+++ b/src/com/android/networkstack/netlink/TcpSocketTracker.java
@@ -39,7 +39,10 @@
 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP;
 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST;
 
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
 import android.net.INetd;
 import android.net.MarkMaskParcel;
 import android.net.Network;
@@ -47,6 +50,7 @@
 import android.os.AsyncTask;
 import android.os.Build;
 import android.os.IBinder;
+import android.os.PowerManager;
 import android.os.RemoteException;
 import android.os.SystemClock;
 import android.provider.DeviceConfig;
@@ -60,6 +64,7 @@
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.net.module.util.netlink.NetlinkConstants;
@@ -128,6 +133,13 @@
     private final int mNetworkMask;
     private int mMinPacketsThreshold = DEFAULT_DATA_STALL_MIN_PACKETS_THRESHOLD;
     private int mTcpPacketsFailRateThreshold = DEFAULT_TCP_PACKETS_FAIL_PERCENTAGE;
+
+    // Doze-mode state. Written from the device idle receiver (via setDozeMode) and read by
+    // pollSocketsInfo() and isDataStallSuspected(), so every access goes through this lock.
+    private final Object mDozeModeLock = new Object();
+    @GuardedBy("mDozeModeLock")
+    private boolean mInDozeMode;
+
     @VisibleForTesting
     protected final DeviceConfig.OnPropertiesChangedListener mConfigListener =
             new DeviceConfig.OnPropertiesChangedListener() {
@@ -144,6 +153,24 @@
                 }
             };
 
+    /**
+     * Listens for doze (device idle) mode transitions so that TCP polling and data stall
+     * detection can be skipped while the system restricts network traffic.
+     */
+    final BroadcastReceiver mDeviceIdleReceiver = new BroadcastReceiver() {
+        @Override
+        public void onReceive(Context context, Intent intent) {
+            if (intent == null) return;
+            if (!PowerManager.ACTION_DEVICE_IDLE_MODE_CHANGED.equals(intent.getAction())) return;
+
+            // Context#getSystemService(Class) is @Nullable; keep the last known state rather
+            // than crash the receiver if PowerManager is unavailable.
+            final PowerManager powerManager = context.getSystemService(PowerManager.class);
+            if (powerManager == null) return;
+            setDozeMode(powerManager.isDeviceIdleMode());
+        }
+    };
+
     public TcpSocketTracker(@NonNull final Dependencies dps, @NonNull final Network network) {
         mDependencies = dps;
         mNetwork = network;
@@ -172,6 +194,8 @@
                             TCP_MONITOR_STATE_FILTER));
         }
         mDependencies.addDeviceConfigChangedListener(mConfigListener);
+        // Pause TCP polling and data stall checks while the device is in doze mode.
+        mDependencies.addDeviceIdleReceiver(mDeviceIdleReceiver);
     }
 
     @Nullable
@@ -191,10 +214,20 @@
      * Request to send a SockDiag Netlink request. Receive and parse the returned message. This
      * function is not thread-safe and should only be called from only one thread.
      *
-     * @Return if this polling request executes successfully or not.
+     * @return true if this polling request is sent to the kernel and executes successfully;
+     *         false otherwise, e.g. when TCP info parsing is unsupported or the device is
+     *         in doze mode.
      */
     public boolean pollSocketsInfo() {
         if (!mDependencies.isTcpInfoParsingSupported()) return false;
+        // Traffic will be restricted in doze mode. TCP info may not reflect the correct network
+        // behavior.
+        // TODO: Traffic may be restricted for other reasons. Get the restriction info from bpf
+        // in T+.
+        synchronized (mDozeModeLock) {
+            if (mInDozeMode) return false;
+        }
+
         FileDescriptor fd = null;
 
         try {
@@ -358,6 +388,17 @@
      */
     public boolean isDataStallSuspected() {
         if (!mDependencies.isTcpInfoParsingSupported()) return false;
+
+        // Traffic is restricted while the device is in doze mode, so a high TCP failure rate
+        // then does not indicate a real network stall.
+        // TODO: Traffic may be restricted for other reasons. Get the restriction info from bpf
+        // in T+.
+        final boolean inDozeMode;
+        synchronized (mDozeModeLock) {
+            inDozeMode = mInDozeMode;
+        }
+        if (inDozeMode) return false;
+
         return (getLatestPacketFailPercentage() >= getTcpPacketsFailRateThreshold());
     }
 
@@ -468,6 +506,8 @@
     /** Stops monitoring and releases resources. */
     public void quit() {
         mDependencies.removeDeviceConfigChangedListener(mConfigListener);
+        // Undo the doze-mode receiver registration made in the constructor.
+        mDependencies.removeBroadcastReceiver(mDeviceIdleReceiver);
     }
 
     /**
@@ -564,6 +603,20 @@
         }
     }
 
+    /**
+     * Records whether the device is in doze mode. While it is, {@link #pollSocketsInfo()} and
+     * {@link #isDataStallSuspected()} return false. Logs only when the state actually changes.
+     */
+    private void setDozeMode(boolean isEnabled) {
+        synchronized (mDozeModeLock) {
+            final boolean changed = mInDozeMode != isEnabled;
+            if (changed) {
+                mInDozeMode = isEnabled;
+                log("Doze mode enabled=" + isEnabled);
+            }
+        }
+    }
+
     /**
      * Dependencies class for testing.
      */
@@ -658,5 +705,35 @@
                 @NonNull final DeviceConfig.OnPropertiesChangedListener listener) {
             DeviceConfig.removeOnPropertiesChangedListener(listener);
         }
+
+        /**
+         * Add receiver for detecting doze mode change to control TCP detection.
+         *
+         * <p>{@link PowerManager#ACTION_DEVICE_IDLE_MODE_CHANGED} is only broadcast when the idle
+         * state changes, so the receiver is also invoked once right after registration to pick
+         * up the current state. Otherwise a tracker created while the device is already dozing
+         * would keep polling, and could report false data stalls, until the next transition.
+         */
+        public void addDeviceIdleReceiver(@NonNull final BroadcastReceiver receiver) {
+            mContext.registerReceiver(receiver,
+                    new IntentFilter(PowerManager.ACTION_DEVICE_IDLE_MODE_CHANGED));
+            receiver.onReceive(mContext,
+                    new Intent(PowerManager.ACTION_DEVICE_IDLE_MODE_CHANGED));
+        }
+
+        /**
+         * Remove broadcast receiver.
+         *
+         * <p>Best-effort: {@link Context#unregisterReceiver} throws IllegalArgumentException for
+         * a receiver that is not currently registered (e.g. on a repeated call), and that must
+         * not crash teardown.
+         */
+        public void removeBroadcastReceiver(@NonNull final BroadcastReceiver receiver) {
+            try {
+                mContext.unregisterReceiver(receiver);
+            } catch (IllegalArgumentException ignored) {
+                // Not registered, so there is nothing to remove.
+            }
+        }
     }
 }
diff --git a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java
index 1177f96..7773de4 100644
--- a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java
+++ b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java
@@ -21,8 +21,6 @@
 import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
 import static android.system.OsConstants.AF_INET;
 
-import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
-
 import static com.android.net.module.util.netlink.NetlinkConstants.SOCKDIAG_MSG_HEADER_SIZE;
 
 import static junit.framework.Assert.assertEquals;
@@ -39,10 +37,14 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
+import android.content.BroadcastReceiver;
+import android.content.Context;
+import android.content.Intent;
 import android.net.INetd;
 import android.net.MarkMaskParcel;
 import android.net.Network;
 import android.os.Build;
+import android.os.PowerManager;
 import android.util.Log;
 import android.util.Log.TerribleFailureHandler;
 
@@ -62,6 +64,7 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -190,6 +193,10 @@
     private final Network mNetwork = new Network(TEST_NETID1);
     private final Network mOtherNetwork = new Network(TEST_NETID2);
     private TerribleFailureHandler mOldWtfHandler;
+    // Context passed to the real Dependencies and to the captured device idle receiver.
+    @Mock private Context mContext;
+    // Handed out by mContext; its isDeviceIdleMode() result drives the doze-mode state.
+    @Mock private PowerManager mPowerManager;
 
     @Rule
     public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
@@ -211,6 +216,8 @@
 
         when(mNetd.getFwmarkForNetwork(eq(TEST_NETID1)))
                 .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID1_FWMARK));
+        // Let the device idle receiver resolve PowerManager from the mocked Context.
+        doReturn(mPowerManager).when(mContext).getSystemService(PowerManager.class);
     }
 
     @After
@@ -262,7 +268,7 @@
     @Test @IgnoreUpTo(Build.VERSION_CODES.Q) // TCP info parsing is not supported on Q
     public void testPollSocketsInfo() throws Exception {
         // This test requires shims that provide API 30 access
-        assumeTrue(ConstantsShim.VERSION >= 30);
+        assumeTrue(ConstantsShim.VERSION >= Build.VERSION_CODES.R);
         when(mDependencies.isTcpInfoParsingSupported()).thenReturn(false);
         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
         assertFalse(tst.pollSocketsInfo());
@@ -283,9 +289,7 @@
         assertEquals(-1, tst.getLatestPacketFailPercentage());
         assertEquals(0, tst.getSentSinceLastRecv());
 
-        final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
-        final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);
-        doReturn(tcpBufferV6, tcpBufferV4).when(mDependencies).recvMessage(any());
+        setupNormalTestTcpInfo();
         assertTrue(tst.pollSocketsInfo());
 
         assertEquals(10, tst.getSentSinceLastRecv());
@@ -317,15 +321,60 @@
         verifyNoMoreInteractions(mDependencies);
     }
 
+    @Test @IgnoreUpTo(Build.VERSION_CODES.Q) // TCP info parsing is not supported on Q
+    public void testTcpInfoParsingWithDozeMode() throws Exception {
+        // This test requires shims that provide API 30 access
+        assumeTrue(ConstantsShim.VERSION >= Build.VERSION_CODES.R);
+
+        final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
+        final ArgumentCaptor<BroadcastReceiver> receiverCaptor =
+                ArgumentCaptor.forClass(BroadcastReceiver.class);
+
+        verify(mDependencies).addDeviceIdleReceiver(receiverCaptor.capture());
+        setupNormalTestTcpInfo();
+        assertTrue(tst.pollSocketsInfo());
+
+        // Lower the threshold.
+        when(mDependencies.getDeviceConfigPropertyInt(any(), eq(CONFIG_TCP_PACKETS_FAIL_PERCENTAGE),
+                anyInt())).thenReturn(40);
+
+        // Trigger a config update
+        tst.mConfigListener.onPropertiesChanged(null /* properties */);
+        assertEquals(10, tst.getSentSinceLastRecv());
+        assertEquals(50, tst.getLatestPacketFailPercentage());
+        assertTrue(tst.isDataStallSuspected());
+
+        // Enable doze mode
+        doReturn(true).when(mPowerManager).isDeviceIdleMode();
+        final BroadcastReceiver receiver = receiverCaptor.getValue();
+        receiver.onReceive(mContext, new Intent(PowerManager.ACTION_DEVICE_IDLE_MODE_CHANGED));
+        assertFalse(tst.pollSocketsInfo());
+        assertFalse(tst.isDataStallSuspected());
+
+        // Disable doze mode: data stall detection resumes using the last polled result, which
+        // was not modified by the skipped poll above.
+        doReturn(false).when(mPowerManager).isDeviceIdleMode();
+        receiver.onReceive(mContext, new Intent(PowerManager.ACTION_DEVICE_IDLE_MODE_CHANGED));
+        assertEquals(50, tst.getLatestPacketFailPercentage());
+        assertTrue(tst.isDataStallSuspected());
+    }
+
+    /** Makes recvMessage() return the canned IPv6 response, then the canned IPv4 response. */
+    private void setupNormalTestTcpInfo() throws Exception {
+        final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
+        final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);
+        doReturn(tcpBufferV6, tcpBufferV4).when(mDependencies).recvMessage(any());
+    }
+
     @Test @IgnoreAfter(Build.VERSION_CODES.Q)
     public void testTcpInfoParsingNotSupportedOnQ() {
-        assertFalse(new TcpSocketTracker.Dependencies(getInstrumentation().getContext())
+        assertFalse(new TcpSocketTracker.Dependencies(mContext)
                 .isTcpInfoParsingSupported());
     }
 
     @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
     public void testTcpInfoParsingSupportedFromR() {
-        assertTrue(new TcpSocketTracker.Dependencies(getInstrumentation().getContext())
+        assertTrue(new TcpSocketTracker.Dependencies(mContext)
                 .isTcpInfoParsingSupported());
     }
 
@@ -359,11 +400,10 @@
     @Test @IgnoreUpTo(Build.VERSION_CODES.Q) // TCP info parsing is not supported on Q
     public void testPollSocketsInfo_BadFormat() throws Exception {
         // This test requires shims that provide API 30 access
-        assumeTrue(ConstantsShim.VERSION >= 30);
+        assumeTrue(ConstantsShim.VERSION >= Build.VERSION_CODES.R);
         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork);
-        final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
-        final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);
-        doReturn(tcpBufferV6, tcpBufferV4).when(mDependencies).recvMessage(any());
+        // Well-formed IPv6 and IPv4 responses for the first poll.
+        setupNormalTestTcpInfo();
         assertTrue(tst.pollSocketsInfo());
         assertEquals(10, tst.getSentSinceLastRecv());
         assertEquals(50, tst.getLatestPacketFailPercentage());
@@ -383,9 +422,8 @@
         when(mNetd.getFwmarkForNetwork(eq(TEST_NETID2)))
                 .thenReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID2_FWMARK));
         final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mOtherNetwork);
-        final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
-        final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);
-        doReturn(tcpBufferV6, tcpBufferV4).when(mDependencies).recvMessage(any());
+        // Same canned responses as the mNetwork tests.
+        setupNormalTestTcpInfo();
         assertTrue(tst.pollSocketsInfo());
 
         assertEquals(0, tst.getSentSinceLastRecv());