Use ML model for the Back Gesture in EdgeBackGestureHandler.

Change-Id: I2fd5255e903932c03a35ae463b0eff3840dc81bd
Merged-In: I2fd5255e903932c03a35ae463b0eff3840dc81bd
Test: manual model test and getting the results
Bug: 150170384
diff --git a/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java b/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java
index b4cd145..3c5eba7 100644
--- a/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java
+++ b/core/java/com/android/internal/config/sysui/SystemUiDeviceConfigFlags.java
@@ -428,6 +428,17 @@
      */
     public static final String SCREENSHOT_KEYCHORD_DELAY = "screenshot_keychord_delay";
 
+    /**
+     * (boolean) Whether to use an ML model for the Back Gesture.
+     */
+    public static final String USE_BACK_GESTURE_ML_MODEL = "use_back_gesture_ml_model";
+
+    /**
+     * (float) Threshold for Back Gesture ML model prediction.
+     */
+    public static final String BACK_GESTURE_ML_MODEL_THRESHOLD = "back_gesture_ml_model_threshold";
+
+
     private SystemUiDeviceConfigFlags() {
     }
 }
diff --git a/packages/SystemUI/src/com/android/systemui/SystemUIFactory.java b/packages/SystemUI/src/com/android/systemui/SystemUIFactory.java
index 5674fdd..4a77de2 100644
--- a/packages/SystemUI/src/com/android/systemui/SystemUIFactory.java
+++ b/packages/SystemUI/src/com/android/systemui/SystemUIFactory.java
@@ -18,6 +18,7 @@
 
 import android.annotation.NonNull;
 import android.content.Context;
+import android.content.res.AssetManager;
 import android.content.res.Resources;
 import android.os.Handler;
 import android.os.Looper;
@@ -39,6 +40,7 @@
 import com.android.systemui.statusbar.NotificationListener;
 import com.android.systemui.statusbar.NotificationMediaManager;
 import com.android.systemui.statusbar.notification.NotificationWakeUpCoordinator;
+import com.android.systemui.statusbar.phone.BackGestureTfClassifierProvider;
 import com.android.systemui.statusbar.phone.DozeParameters;
 import com.android.systemui.statusbar.phone.KeyguardBouncer;
 import com.android.systemui.statusbar.phone.KeyguardBypassController;
@@ -182,4 +184,13 @@
             return mContext;
         }
     }
+
+    /**
+     * Creates an instance of BackGestureTfClassifierProvider.
+     * This method is overridden in vendor specific implementation of Sys UI.
+     */
+    public BackGestureTfClassifierProvider createBackGestureTfClassifierProvider(
+            AssetManager am) {
+        return new BackGestureTfClassifierProvider();
+    }
 }
diff --git a/packages/SystemUI/src/com/android/systemui/statusbar/phone/BackGestureTfClassifierProvider.java b/packages/SystemUI/src/com/android/systemui/statusbar/phone/BackGestureTfClassifierProvider.java
new file mode 100644
index 0000000..0af79c38
--- /dev/null
+++ b/packages/SystemUI/src/com/android/systemui/statusbar/phone/BackGestureTfClassifierProvider.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.systemui.statusbar.phone;
+
+import android.content.res.AssetManager;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * This class can be overridden by a vendor-specific sys UI implementation,
+ * in order to provide classification models for the Back Gesture.
+ */
+public class BackGestureTfClassifierProvider {
+    private static final String TAG = "BackGestureTfClassifierProvider";
+
+    /**
+     * Default implementation that returns an empty map.
+     * This method is overridden in vendor-specific Sys UI implementation.
+     *
+     * @param am       An AssetManager to get the vocab file.
+    */
+    public Map<String, Integer> loadVocab(AssetManager am) {
+        return new HashMap<String, Integer>();
+    }
+
+    /**
+     * This method is overridden in vendor-specific Sys UI implementation.
+     *
+     * @param featuresVector   List of input features.
+     *
+    */
+    public float predict(Object[] featuresVector) {
+        return -1;
+    }
+
+    /**
+     * Interpreter owns resources. This method releases the resources after
+     * use to avoid memory leak.
+     * This method is overridden in vendor-specific Sys UI implementation.
+     *
+     */
+    public void release() {}
+
+    /**
+     * Returns whether to use the ML model for Back Gesture.
+     * This method is overridden in vendor-specific Sys UI implementation.
+     *
+     */
+    public boolean isActive() {
+        return false;
+    }
+}
diff --git a/packages/SystemUI/src/com/android/systemui/statusbar/phone/EdgeBackGestureHandler.java b/packages/SystemUI/src/com/android/systemui/statusbar/phone/EdgeBackGestureHandler.java
index 00a932cb..2641bca 100644
--- a/packages/SystemUI/src/com/android/systemui/statusbar/phone/EdgeBackGestureHandler.java
+++ b/packages/SystemUI/src/com/android/systemui/statusbar/phone/EdgeBackGestureHandler.java
@@ -56,6 +56,7 @@
 import com.android.internal.policy.GestureNavigationSettingsObserver;
 import com.android.systemui.Dependency;
 import com.android.systemui.R;
+import com.android.systemui.SystemUIFactory;
 import com.android.systemui.broadcast.BroadcastDispatcher;
 import com.android.systemui.bubbles.BubbleController;
 import com.android.systemui.model.SysUiState;
@@ -76,6 +77,7 @@
 import java.io.PrintWriter;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Executor;
 
 /**
@@ -117,8 +119,31 @@
         public void onTaskStackChanged() {
             mGestureBlockingActivityRunning = isGestureBlockingActivityRunning();
         }
+        @Override
+        public void onTaskCreated(int taskId, ComponentName componentName) {
+            if (componentName != null) {
+                mPackageName = componentName.getPackageName();
+            } else {
+                mPackageName = "_UNKNOWN";
+            }
+        }
     };
 
+    private DeviceConfig.OnPropertiesChangedListener mOnPropertiesChangedListener =
+            new DeviceConfig.OnPropertiesChangedListener() {
+                @Override
+                public void onPropertiesChanged(DeviceConfig.Properties properties) {
+                    if (DeviceConfig.NAMESPACE_SYSTEMUI.equals(properties.getNamespace())
+                            && (properties.getKeyset().contains(
+                                    SystemUiDeviceConfigFlags.BACK_GESTURE_ML_MODEL_THRESHOLD)
+                            || properties.getKeyset().contains(
+                                    SystemUiDeviceConfigFlags.USE_BACK_GESTURE_ML_MODEL))) {
+                        updateMLModelState();
+                    }
+                }
+            };
+
+
     private final Context mContext;
     private final OverviewProxyService mOverviewProxyService;
     private final Runnable mStateChangeCallback;
@@ -173,6 +198,13 @@
     private int mRightInset;
     private int mSysUiFlags;
 
+    // For Tf-Lite model.
+    private BackGestureTfClassifierProvider mBackGestureTfClassifierProvider;
+    private Map<String, Integer> mVocab;
+    private boolean mUseMLModel;
+    private float mMLModelThreshold;
+    private String mPackageName;
+
     private final GestureNavigationSettingsObserver mGestureNavigationSettingsObserver;
 
     private final NavigationEdgeBackPlugin.BackCallback mBackCallback =
@@ -230,7 +262,6 @@
                 Log.e(TAG, "Failed to add gesture blocking activities", e);
             }
         }
-
         mLongPressTimeout = Math.min(MAX_LONG_PRESS_TIMEOUT,
                 ViewConfiguration.getLongPressTimeout());
 
@@ -344,6 +375,7 @@
             mContext.getSystemService(DisplayManager.class).unregisterDisplayListener(this);
             mPluginManager.removePluginListener(this);
             ActivityManagerWrapper.getInstance().unregisterTaskStackListener(mTaskStackListener);
+            DeviceConfig.removeOnPropertiesChangedListener(mOnPropertiesChangedListener);
 
             try {
                 WindowManagerGlobal.getWindowManagerService()
@@ -359,6 +391,9 @@
             mContext.getSystemService(DisplayManager.class).registerDisplayListener(this,
                     mContext.getMainThreadHandler());
             ActivityManagerWrapper.getInstance().registerTaskStackListener(mTaskStackListener);
+            DeviceConfig.addOnPropertiesChangedListener(DeviceConfig.NAMESPACE_SYSTEMUI,
+                    runnable -> (mContext.getMainThreadHandler()).post(runnable),
+                    mOnPropertiesChangedListener);
 
             try {
                 WindowManagerGlobal.getWindowManagerService()
@@ -379,6 +414,8 @@
             mPluginManager.addPluginListener(
                     this, NavigationEdgeBackPlugin.class, /*allowMultiple=*/ false);
         }
+        // Update the ML model resources.
+        updateMLModelState();
     }
 
     @Override
@@ -431,27 +468,88 @@
         }
     }
 
+    private void updateMLModelState() {
+        boolean newState = mIsEnabled && DeviceConfig.getBoolean(DeviceConfig.NAMESPACE_SYSTEMUI,
+                SystemUiDeviceConfigFlags.USE_BACK_GESTURE_ML_MODEL, false);
+
+        if (newState == mUseMLModel) {
+            return;
+        }
+
+        if (newState) {
+            mBackGestureTfClassifierProvider = SystemUIFactory.getInstance()
+                    .createBackGestureTfClassifierProvider(mContext.getAssets());
+            mMLModelThreshold = DeviceConfig.getFloat(DeviceConfig.NAMESPACE_SYSTEMUI,
+                    SystemUiDeviceConfigFlags.BACK_GESTURE_ML_MODEL_THRESHOLD, 0.9f);
+            if (mBackGestureTfClassifierProvider.isActive()) {
+                mVocab = mBackGestureTfClassifierProvider.loadVocab(mContext.getAssets());
+                mUseMLModel = true;
+                return;
+            }
+        }
+
+        mUseMLModel = false;
+        if (mBackGestureTfClassifierProvider != null) {
+            mBackGestureTfClassifierProvider.release();
+            mBackGestureTfClassifierProvider = null;
+        }
+    }
+
+    private float getBackGesturePredictionsCategory(int x, int y) {
+        if (!mVocab.containsKey(mPackageName)) {
+            return -1;
+        }
+
+        int distanceFromEdge;
+        int location;
+        if (x <= mDisplaySize.x / 2.0) {
+            location = 1;  // left
+            distanceFromEdge = x;
+        } else {
+            location = 2;  // right
+            distanceFromEdge = mDisplaySize.x - x;
+        }
+
+        Object[] featuresVector = {
+            new long[]{(long) mDisplaySize.x},
+            new long[]{(long) distanceFromEdge},
+            new long[]{(long) location},
+            new long[]{(long) mVocab.get(mPackageName)},
+            new long[]{(long) y},
+        };
+
+        final float results = mBackGestureTfClassifierProvider.predict(featuresVector);
+        if (results == -1) return -1;
+
+        return results >= mMLModelThreshold ? 1 : 0;
+    }
+
     private boolean isWithinTouchRegion(int x, int y) {
-        // Disallow if we are in the bottom gesture area
-        if (y >= (mDisplaySize.y - mBottomGestureHeight)) {
-            return false;
-        }
+        boolean withinRange = false;
+        float results = -1;
 
-        // If the point is way too far (twice the margin), it is
-        // not interesting to us for logging purposes, nor we
-        // should process it.  Simply return false and keep
-        // mLogGesture = false.
-        if (x > 2 * (mEdgeWidthLeft + mLeftInset)
-                && x < (mDisplaySize.x - 2 * (mEdgeWidthRight + mRightInset))) {
-            return false;
+        if (mUseMLModel &&  (results = getBackGesturePredictionsCategory(x, y)) != -1) {
+            withinRange = results == 1 ? true : false;
+        } else {
+            // Disallow if we are in the bottom gesture area
+            if (y >= (mDisplaySize.y - mBottomGestureHeight)) {
+                return false;
+            }
+            // If the point is way too far (twice the margin), it is
+            // not interesting to us for logging purposes, nor we
+            // should process it.  Simply return false and keep
+            // mLogGesture = false.
+            if (x > 2 * (mEdgeWidthLeft + mLeftInset)
+                    && x < (mDisplaySize.x - 2 * (mEdgeWidthRight + mRightInset))) {
+                return false;
+            }
+            // Denotes whether we should proceed with the gesture.
+            // Even if it is false, we may want to log it assuming
+            // it is not invalid due to exclusion.
+            withinRange = x <= mEdgeWidthLeft + mLeftInset
+                    || x >= (mDisplaySize.x - mEdgeWidthRight - mRightInset);
         }
 
-        // Denotes whether we should proceed with the gesture.
-        // Even if it is false, we may want to log it assuming
-        // it is not invalid due to exclusion.
-        boolean withinRange = x <= mEdgeWidthLeft + mLeftInset
-                || x >= (mDisplaySize.x - mEdgeWidthRight - mRightInset);
-
         // Always allow if the user is in a transient sticky immersive state
         if (mIsNavBarShownTransiently) {
             mLogGesture = true;
@@ -648,6 +746,11 @@
         ActivityManager.RunningTaskInfo runningTask =
                 ActivityManagerWrapper.getInstance().getRunningTask();
         ComponentName topActivity = runningTask == null ? null : runningTask.topActivity;
+        if (topActivity != null) {
+            mPackageName = topActivity.getPackageName();
+        } else {
+            mPackageName = "_UNKNOWN";
+        }
         return topActivity != null && mGestureBlockingActivities.contains(topActivity);
     }