Merge "Use a dedicated class to store stats"
diff --git a/tests/integration/signature/android/net/NetworkStatsIntegrationTest.kt b/tests/integration/signature/android/net/NetworkStatsIntegrationTest.kt
index 8ee6670..eb24d08 100644
--- a/tests/integration/signature/android/net/NetworkStatsIntegrationTest.kt
+++ b/tests/integration/signature/android/net/NetworkStatsIntegrationTest.kt
@@ -154,8 +154,8 @@
         // Wait for 464Xlat to be ready.
         val internalInterfaceName = waitFor464XlatReady(packetBridge.internalNetwork)
 
-        val (_, rxBytesBeforeTest) = getTotalTxRxBytes(internalInterfaceName)
-        val (_, rxTaggedBytesBeforeTest) = getTaggedTxRxBytes(internalInterfaceName, TEST_TAG)
+        val statsBeforeTest = getNetworkSummary(internalInterfaceName)
+        val taggedStatsBeforeTest = getTaggedNetworkSummary(internalInterfaceName, TEST_TAG)
 
         // Generate the download traffic.
         genHttpTraffic(packetBridge.internalNetwork, uploadSize = 0L, TEST_DOWNLOAD_SIZE)
@@ -163,21 +163,17 @@
         // In practice, for one way 10k download payload, the download usage is about
         // 11222~12880 bytes. And the upload usage is about 1279~1626 bytes, which is majorly
         // contributed by TCP ACK packets.
-        val (txBytesAfterDownload, rxBytesAfterDownload) =
-            getTotalTxRxBytes(internalInterfaceName)
-        val (txTaggedBytesAfterDownload, rxTaggedBytesAfterDownload) = getTaggedTxRxBytes(
-            internalInterfaceName,
-            TEST_TAG
-        )
+        val statsAfterDownload = getNetworkSummary(internalInterfaceName)
+        val taggedStatsAfterDownload = getTaggedNetworkSummary(internalInterfaceName, TEST_TAG)
         assertInRange(
             "Download size", internalInterfaceName,
-            rxBytesAfterDownload - rxBytesBeforeTest,
+            statsAfterDownload.rxBytes - statsBeforeTest.rxBytes,
             TEST_DOWNLOAD_SIZE, (TEST_DOWNLOAD_SIZE * 1.3).toLong()
         )
         // Increment of tagged data should be zero since no tagged traffic was generated.
         assertEquals(
-            rxTaggedBytesBeforeTest,
-            rxTaggedBytesAfterDownload,
+            taggedStatsBeforeTest.rxBytes,
+            taggedStatsAfterDownload.rxBytes,
             "Tagged download size of uid ${Process.myUid()} on $internalInterfaceName"
         )
 
@@ -190,17 +186,17 @@
         )
 
         // Verify upload data usage accounting.
-        val (txBytesAfterUpload, _) = getTotalTxRxBytes(internalInterfaceName)
-        val (txTaggedBytesAfterUpload, _) = getTaggedTxRxBytes(internalInterfaceName, TEST_TAG)
+        val statsAfterUpload = getNetworkSummary(internalInterfaceName)
+        val taggedStatsAfterUpload = getTaggedNetworkSummary(internalInterfaceName, TEST_TAG)
         assertInRange(
             "Upload size", internalInterfaceName,
-            txBytesAfterUpload - txBytesAfterDownload,
+            statsAfterUpload.txBytes - statsAfterDownload.txBytes,
             TEST_UPLOAD_SIZE, (TEST_UPLOAD_SIZE * 1.3).toLong()
         )
         assertInRange(
             "Tagged upload size of uid ${Process.myUid()}",
             internalInterfaceName,
-            txTaggedBytesAfterUpload - txTaggedBytesAfterDownload,
+            taggedStatsAfterUpload.txBytes - taggedStatsAfterDownload.txBytes,
             TEST_UPLOAD_SIZE,
             (TEST_UPLOAD_SIZE * 1.3).toLong()
         )
@@ -259,13 +255,42 @@
         }
     }
 
-    private fun getTotalTxRxBytes(iface: String): Pair<Long, Long> {
+    // NetworkStats.Bucket cannot be written, so its counters are copied into this
+    // immutable holder; `+` adds field-wise and EMPTY is the all-zero value.
+    private data class BareStats(
+        val rxBytes: Long,
+        val rxPackets: Long,
+        val txBytes: Long,
+        val txPackets: Long
+    ) {
+        operator fun plus(other: BareStats): BareStats = BareStats(
+            rxBytes = rxBytes + other.rxBytes,
+            rxPackets = rxPackets + other.rxPackets,
+            txBytes = txBytes + other.txBytes,
+            txPackets = txPackets + other.txPackets
+        )
+
+        companion object {
+            val EMPTY = BareStats(0L, 0L, 0L, 0L)
+        }
+    }
+
+    // Exposes the hasNextBucket()/getNextBucket() cursor of NetworkStats as an Iterable.
+    private fun NetworkStats.buckets(): Iterable<NetworkStats.Bucket> = Iterable {
+        object : Iterator<NetworkStats.Bucket> {
+            override fun hasNext(): Boolean = hasNextBucket()
+            override fun next(): NetworkStats.Bucket =
+                NetworkStats.Bucket().also { bucket -> assertTrue(getNextBucket(bucket)) }
+        }
+    }
+
+    private fun getNetworkSummary(iface: String): BareStats {
         return getNetworkStatsThat(iface, TAG_NONE) { nsm, template ->
             nsm.querySummary(template, Long.MIN_VALUE, Long.MAX_VALUE)
         }
     }
 
-    private fun getTaggedTxRxBytes(iface: String, tag: Int): Pair<Long, Long> {
+    private fun getTaggedNetworkSummary(iface: String, tag: Int): BareStats {
         return getNetworkStatsThat(iface, tag) { nsm, template ->
             nsm.queryTaggedSummary(template, Long.MIN_VALUE, Long.MAX_VALUE)
         }
@@ -275,22 +300,21 @@
         iface: String,
         tag: Int,
         queryApi: (nsm: NetworkStatsManager, template: NetworkTemplate) -> NetworkStats
-    ): Pair<Long, Long> {
+    ): BareStats {
         val nsm = context.getSystemService(NetworkStatsManager::class.java)
         nsm.forceUpdate()
         val testTemplate = NetworkTemplate.Builder(MATCH_TEST)
             .setWifiNetworkKeys(setOf(iface)).build()
         val stats = queryApi.invoke(nsm, testTemplate)
-        val recycled = NetworkStats.Bucket()
-        var rx = 0L
-        var tx = 0L
-        while (stats.hasNextBucket()) {
-            stats.getNextBucket(recycled)
-            if (recycled.uid != Process.myUid() || recycled.tag != tag) continue
-            rx += recycled.rxBytes
-            tx += recycled.txBytes
+        val filteredBuckets = stats.buckets().filter { it.uid == Process.myUid() && it.tag == tag }
+        return filteredBuckets.fold(BareStats.EMPTY) { acc, it ->
+            acc + BareStats(
+                it.rxBytes,
+                it.rxPackets,
+                it.txBytes,
+                it.txPackets
+            )
         }
-        return tx to rx
     }
 
     /** Verify the given value is in range [lower, upper]  */