BpfBitmap: some minor optimizations

Test: TreeHugger, atest BpfBitmapTest
Flag: EXEMPT mainline
Signed-off-by: Maciej Żenczykowski <maze@google.com>
Change-Id: I5acf5bcb31777e503e58cef733d617a53006d2b9
diff --git a/Tethering/jni/onload.cpp b/Tethering/jni/onload.cpp
index fd40d41..606cc44 100644
--- a/Tethering/jni/onload.cpp
+++ b/Tethering/jni/onload.cpp
@@ -23,6 +23,7 @@
 namespace android {
 
 int register_com_android_net_module_util_BpfMap(JNIEnv* env, char const* class_name);
+int register_com_android_net_module_util_BpfBitmap(JNIEnv* env, char const* class_name);
 int register_com_android_net_module_util_TcUtils(JNIEnv* env, char const* class_name);
 int register_com_android_networkstack_tethering_BpfCoordinator(JNIEnv* env);
 int register_com_android_networkstack_tethering_util_TetheringUtils(JNIEnv* env);
@@ -39,6 +40,9 @@
     if (register_com_android_net_module_util_BpfMap(env,
             "com/android/networkstack/tethering/util/BpfMap") < 0) return JNI_ERR;
 
+    if (register_com_android_net_module_util_BpfBitmap(env,
+            "com/android/networkstack/tethering/util/BpfBitmap") < 0) return JNI_ERR;
+
     if (register_com_android_net_module_util_TcUtils(env,
             "com/android/networkstack/tethering/util/TcUtils") < 0) return JNI_ERR;
 
diff --git a/service/jni/com_android_net_module_util/onload.cpp b/service/jni/com_android_net_module_util/onload.cpp
index d91eb03..f160622 100644
--- a/service/jni/com_android_net_module_util/onload.cpp
+++ b/service/jni/com_android_net_module_util/onload.cpp
@@ -20,6 +20,7 @@
 namespace android {
 
 int register_com_android_net_module_util_BpfMap(JNIEnv* env, char const* class_name);
+int register_com_android_net_module_util_BpfBitmap(JNIEnv* env, char const* class_name);
 int register_com_android_net_module_util_TcUtils(JNIEnv* env, char const* class_name);
 int register_com_android_net_module_util_BpfUtils(JNIEnv* env, char const* class_name);
 
@@ -33,6 +34,9 @@
     if (register_com_android_net_module_util_BpfMap(env,
             "android/net/connectivity/com/android/net/module/util/BpfMap") < 0) return JNI_ERR;
 
+    if (register_com_android_net_module_util_BpfBitmap(env,
+            "android/net/connectivity/com/android/net/module/util/BpfBitmap") < 0) return JNI_ERR;
+
     if (register_com_android_net_module_util_TcUtils(env,
             "android/net/connectivity/com/android/net/module/util/TcUtils") < 0) return JNI_ERR;
 
diff --git a/staticlibs/device/com/android/net/module/util/BpfBitmap.java b/staticlibs/device/com/android/net/module/util/BpfBitmap.java
index b62a430..8f45c65 100644
--- a/staticlibs/device/com/android/net/module/util/BpfBitmap.java
+++ b/staticlibs/device/com/android/net/module/util/BpfBitmap.java
@@ -22,6 +22,8 @@
 import androidx.annotation.NonNull;
 import androidx.annotation.RequiresApi;
 
+import dalvik.annotation.optimization.CriticalNative;
+
  /**
  *
  * Generic bitmap class for use with BPF programs. Corresponds to a BpfMap
@@ -30,7 +32,12 @@
  */
 @RequiresApi(Build.VERSION_CODES.S)
 public class BpfBitmap {
-    private BpfMap<Struct.S32, Struct.S64> mBpfMap;
+    static {
+        System.loadLibrary(JniUtil.getJniLibraryName(BpfBitmap.class.getPackage()));
+    }
+
+    private final BpfMap<Struct.S32, Struct.S64> mBpfMap;
+    private final int mMapFd;
 
     /**
      * Create a BpfBitmap map wrapper with "path" of filesystem.
@@ -39,21 +46,16 @@
      */
     public BpfBitmap(@NonNull String path) throws ErrnoException {
         mBpfMap = new BpfMap<>(path, Struct.S32.class, Struct.S64.class);
+        mMapFd = mBpfMap.getFd();
     }
 
-    /**
-     * Retrieves the value from BpfMap for the given key.
-     *
-     * @param key The key in the map corresponding to the value to return.
-     */
-    private long getBpfMapValue(Struct.S32 key) throws ErrnoException  {
-        Struct.S64 curVal = mBpfMap.getValue(key);
-        if (curVal != null) {
-            return curVal.val;
-        } else {
-            return 0;
-        }
-    }
+    // Returns > 0 if bit is set, 0 if not set, < 0 on error (negative errno).
+    @CriticalNative
+    private static native int nativeGet(int fd, int index);
+
+    // Returns 0 on success, < 0 on error (negative errno).
+    @CriticalNative
+    private static native int nativeSet(int fd, int index, boolean set);
 
     /**
      * Retrieves the bit for the given index in the bitmap.
@@ -63,8 +65,26 @@
     public boolean get(int index) throws ErrnoException  {
         if (index < 0) return false;
 
-        Struct.S32 key = new Struct.S32(index >> 6);
-        return ((getBpfMapValue(key) >>> (index & 63)) & 1L) != 0;
+        final int ret = nativeGet(mMapFd, index);
+        if (ret < 0) {
+            throw new ErrnoException("nativeGet", -ret);
+        }
+        return ret > 0;
+    }
+
+    /**
+     * Change the specified index in the bitmap to set value.
+     *
+     * @param index Position to (un)set in bitmap.
+     * @param set Boolean indicating to set or unset index.
+     */
+    public void set(int index, boolean set) throws ErrnoException {
+        if (index < 0) throw new IllegalArgumentException("Index out of bounds.");
+
+        final int ret = nativeSet(mMapFd, index, set);
+        if (ret < 0) {
+            throw new ErrnoException("nativeSet", -ret);
+        }
     }
 
     /**
@@ -86,22 +106,6 @@
     }
 
     /**
-     * Change the specified index in the bitmap to set value.
-     *
-     * @param index Position to unset in bitmap.
-     * @param set Boolean indicating to set or unset index.
-     */
-    public void set(int index, boolean set) throws ErrnoException {
-        if (index < 0) throw new IllegalArgumentException("Index out of bounds.");
-
-        Struct.S32 key = new Struct.S32(index >> 6);
-        long mask = (1L << (index & 63));
-        long val = getBpfMapValue(key);
-        if (set) val |= mask; else val &= ~mask;
-        mBpfMap.updateEntry(key, new Struct.S64(val));
-    }
-
-    /**
      * Clears the map. The map may already be empty.
      *
      * @throws ErrnoException if updating entry to 0 fails.
@@ -118,7 +122,8 @@
     public boolean isEmpty() throws ErrnoException {
         Struct.S32 key = mBpfMap.getFirstKey();
         while (key != null) {
-            if (getBpfMapValue(key) != 0) {
+            Struct.S64 val = mBpfMap.getValue(key);
+            if (val != null && val.val != 0) {
                 return false;
             }
             key = mBpfMap.getNextKey(key);
diff --git a/staticlibs/device/com/android/net/module/util/BpfMap.java b/staticlibs/device/com/android/net/module/util/BpfMap.java
index da04c73..72cb430 100644
--- a/staticlibs/device/com/android/net/module/util/BpfMap.java
+++ b/staticlibs/device/com/android/net/module/util/BpfMap.java
@@ -71,6 +71,11 @@
     private final int mKeySize;
     private final int mValueSize;
 
+    // The following is (ab)used by BpfBitmap.java
+    /* package */ int getFd() {
+        return mMapFd.getFd();
+    }
+
     private static ConcurrentHashMap<Pair<String, Integer>, ParcelFileDescriptor> sFdCache =
             new ConcurrentHashMap<>();
 
diff --git a/staticlibs/native/bpfmapjni/Android.bp b/staticlibs/native/bpfmapjni/Android.bp
index 9a58a93..2612ec5 100644
--- a/staticlibs/native/bpfmapjni/Android.bp
+++ b/staticlibs/native/bpfmapjni/Android.bp
@@ -20,7 +20,7 @@
 cc_library_static {
     name: "libnet_utils_device_common_bpfjni",
     srcs: [
-        "com_android_net_module_util_BpfMap.cpp",
+        "com_android_net_module_util_Bpf.cpp",
         "com_android_net_module_util_TcUtils.cpp",
     ],
     header_libs: [
diff --git a/staticlibs/native/bpfmapjni/com_android_net_module_util_BpfMap.cpp b/staticlibs/native/bpfmapjni/com_android_net_module_util_Bpf.cpp
similarity index 77%
rename from staticlibs/native/bpfmapjni/com_android_net_module_util_BpfMap.cpp
rename to staticlibs/native/bpfmapjni/com_android_net_module_util_Bpf.cpp
index 38cc92c..f0e7246 100644
--- a/staticlibs/native/bpfmapjni/com_android_net_module_util_BpfMap.cpp
+++ b/staticlibs/native/bpfmapjni/com_android_net_module_util_Bpf.cpp
@@ -16,6 +16,7 @@
 
 #include <errno.h>
 #include <linux/pfkeyv2.h>
+#include <stdint.h>
 #include <sys/socket.h>
 #include <jni.h>
 #include <nativehelper/JNIHelp.h>
@@ -143,10 +144,47 @@
     return 0;
 }
 
+static jint com_android_net_module_util_BpfBitmap_nativeGet(jint fd, jint index) {
+    if (index < 0) return -EINVAL;
+
+    const uint32_t key = index >> 6;
+    const uint32_t subkey = index & 63;
+    uint64_t value = 0;
+
+    // findMapEntry returns 0 on success, -1 on error.
+    // If the entry does not exist, it's not an error, the value is just 0.
+    if (bpf::findMapEntry(fd, &key, &value) && errno != ENOENT) return -errno;
+
+    return (value >> subkey) & 1;
+}
+
+static jint com_android_net_module_util_BpfBitmap_nativeSet(jint fd, jint index, jboolean set) {
+    if (index < 0) return -EINVAL;
+
+    const uint32_t key = index >> 6;
+    const uint32_t subkey = index & 63;
+    uint64_t value = 0;
+
+    // Read the existing value. It's okay if it doesn't exist, value will be 0.
+    if (bpf::findMapEntry(fd, &key, &value) && errno != ENOENT) return -errno;
+
+    const uint64_t mask = 1uLL << subkey;
+    if (set) {
+        value |= mask;
+    } else {
+        value &= ~mask;
+    }
+
+    // Write the updated value back. BPF_ANY will create or update as needed.
+    if (bpf::writeToMapEntry(fd, &key, &value, BPF_ANY)) return -errno;
+
+    return 0;
+}
+
 /*
  * JNI registration.
  */
-static const JNINativeMethod gMethods[] = {
+static const JNINativeMethod gBpfMapMethods[] = {
     /* name, signature, funcPtr */
     { "nativeBpfFdGet", "(Ljava/lang/String;III)I",
         (void*) com_android_net_module_util_BpfMap_nativeBpfFdGet },
@@ -165,9 +203,19 @@
 };
 
 int register_com_android_net_module_util_BpfMap(JNIEnv* env, char const* class_name) {
-    return jniRegisterNativeMethods(env,
-            class_name,
-            gMethods, NELEM(gMethods));
+    return jniRegisterNativeMethods(env, class_name, gBpfMapMethods, NELEM(gBpfMapMethods));
+}
+
+static const JNINativeMethod gBpfBitmapMethods[] = {
+    /* name, signature, funcPtr */
+    // CriticalNative
+    { "nativeGet", "(II)I", (void*)com_android_net_module_util_BpfBitmap_nativeGet },
+    // CriticalNative
+    { "nativeSet", "(IIZ)I", (void*)com_android_net_module_util_BpfBitmap_nativeSet },
+};
+
+int register_com_android_net_module_util_BpfBitmap(JNIEnv* env, char const* class_name) {
+    return jniRegisterNativeMethods(env, class_name, gBpfBitmapMethods, NELEM(gBpfBitmapMethods));
 }
 
 }; // namespace android
diff --git a/tests/common/tethering-jni-jarjar-rules.txt b/tests/common/tethering-jni-jarjar-rules.txt
index 593ba14..3dfd77a 100644
--- a/tests/common/tethering-jni-jarjar-rules.txt
+++ b/tests/common/tethering-jni-jarjar-rules.txt
@@ -5,6 +5,7 @@
 # jarjar rules the test is using, but this is a bit less realistic (using a different JNI library),
 # and complicates the test build. It would be necessary if TetheringUtils had a different package
 # name in test code though, as the JNI library name is deducted from the TetheringUtils package.
+rule com.android.net.module.util.BpfBitmap* com.android.networkstack.tethering.util.BpfBitmap@1
 rule com.android.net.module.util.BpfMap* com.android.networkstack.tethering.util.BpfMap@1
 rule com.android.net.module.util.BpfUtils* com.android.networkstack.tethering.util.BpfUtils@1
 rule com.android.net.module.util.TcUtils* com.android.networkstack.tethering.util.TcUtils@1
diff --git a/tests/unit/jni/android_net_frameworktests_util/onload.cpp b/tests/unit/jni/android_net_frameworktests_util/onload.cpp
index f70b04b..b79a230 100644
--- a/tests/unit/jni/android_net_frameworktests_util/onload.cpp
+++ b/tests/unit/jni/android_net_frameworktests_util/onload.cpp
@@ -23,6 +23,7 @@
 namespace android {
 
 int register_com_android_net_module_util_BpfMap(JNIEnv* env, char const* class_name);
+int register_com_android_net_module_util_BpfBitmap(JNIEnv* env, char const* class_name);
 int register_com_android_net_module_util_TcUtils(JNIEnv* env, char const* class_name);
 int register_com_android_net_module_util_ServiceConnectivityJni(JNIEnv *env,
                                                       char const *class_name);
@@ -37,6 +38,9 @@
     if (register_com_android_net_module_util_BpfMap(env,
             "android/net/frameworktests/util/BpfMap") < 0) return JNI_ERR;
 
+    if (register_com_android_net_module_util_BpfBitmap(env,
+            "android/net/frameworktests/util/BpfBitmap") < 0) return JNI_ERR;
+
     if (register_com_android_net_module_util_TcUtils(env,
             "android/net/frameworktests/util/TcUtils") < 0) return JNI_ERR;