Merge internal changes manually am: 08374c5630
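
Replace the boolean useNNApi flag with a TfLiteBackend enum (CPU, NNAPI,
GPU) across the Java and JNI layers, add a GPU delegate path gated by
NN_BENCHMARK_ENABLE_GPU, expose a hasNnApiDevice() JNI helper for querying
NNAPI accelerators by name, and make runBenchmarkCompleteInputSet() derive
the set repeat count from a minimum number of inferences.

Callers now select a backend explicitly, for example (names illustrative):

  Processor processor = new Processor(context, callback, testList);
  processor.setTfLiteBackend(TfLiteBackend.NNAPI);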

Original change: https://android-review.googlesource.com/c/platform/test/mlts/benchmark/+/1449198

Change-Id: Iff96e0f28d994b82f21706f445d321318a58cb03
diff --git a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
index 0ce2fe8..3811222 100644
--- a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
+++ b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
@@ -30,6 +30,7 @@
 import com.android.nn.benchmark.core.BenchmarkResult;
 import com.android.nn.benchmark.core.Processor;
 import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.benchmark.core.TfLiteBackend;
 
 import java.util.List;
 import java.util.Random;
@@ -77,7 +78,7 @@
 
     public void doBenchmark() {
         mProcessor = new Processor(this, this, randomModelList());
-        mProcessor.setUseNNApi(true);
+        mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
         mProcessor.setToggleLong(true);
         mProcessor.setMaxRunIterations(1);
         processorRunner.submit(mProcessor);
diff --git a/jni/benchmark_jni.cpp b/jni/benchmark_jni.cpp
index a1f8097..21bf2d4 100644
--- a/jni/benchmark_jni.cpp
+++ b/jni/benchmark_jni.cpp
@@ -28,7 +28,37 @@
 #include <android/log.h>
 #include <android/sharedmem.h>
 #include <sys/mman.h>
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
 
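+// Called from NNTestBase.hasNnApiDevice(): returns true when an NNAPI device
+// with the given name is registered with the NNAPI runtime.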
+extern "C" JNIEXPORT jboolean JNICALL
+Java_com_android_nn_benchmark_core_NNTestBase_hasNnApiDevice(
+    JNIEnv *env, jobject /* this */, jstring _nnApiDeviceName) {
+  bool foundDevice = false;
+  const char *nnApiDeviceName =
+      _nnApiDeviceName == NULL ? NULL
+                               : env->GetStringUTFChars(_nnApiDeviceName, NULL);
+  if (nnApiDeviceName != NULL) {
+    std::string device_name(nnApiDeviceName);
+    uint32_t numDevices = 0;
+    NnApiImplementation()->ANeuralNetworks_getDeviceCount(&numDevices);
+
+    for (uint32_t i = 0; i < numDevices; i++) {
+      ANeuralNetworksDevice *device = nullptr;
+      const char *buffer = nullptr;
+      NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
+      NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer);
+      if (device_name == buffer) {
+        foundDevice = true;
+        break;
+      }
+    }
+    env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
+  }
+
+  return foundDevice;
+}
 
 extern "C"
 JNIEXPORT jlong
@@ -37,7 +65,7 @@
         JNIEnv *env,
         jobject /* this */,
         jstring _modelFileName,
-        jboolean _useNnApi,
+        jint _tfliteBackend,
         jboolean _enableIntermediateTensorsDump,
         jstring _nnApiDeviceName,
         jboolean _mmapModel,
@@ -53,14 +81,14 @@
             : env->GetStringUTFChars(_nnApiCacheDir, NULL);
     int nnapiErrno = 0;
     void *handle = BenchmarkModel::create(
-        modelFileName, _useNnApi, _enableIntermediateTensorsDump, &nnapiErrno,
+        modelFileName, _tfliteBackend, _enableIntermediateTensorsDump, &nnapiErrno,
         nnApiDeviceName, _mmapModel, nnApiCacheDir);
     env->ReleaseStringUTFChars(_modelFileName, modelFileName);
     if (_nnApiDeviceName != NULL) {
         env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
     }
 
-    if (_useNnApi && nnapiErrno != 0) {
+    if (_tfliteBackend == TFLITE_NNAPI && nnapiErrno != 0) {
       jclass nnapiFailureClass = env->FindClass(
           "com/android/nn/benchmark/core/NnApiDelegationFailure");
       jmethodID constructor =
@@ -197,7 +225,6 @@
                     jobject creator = env->GetObjectField(inout, inout_inputCreator);
                     if (creator == nullptr) { return false; }
                     env->CallVoidMethod(creator, createInput_method, byteBuffer);
-                    env->DeleteLocalRef(byteBuffer);
                     if (env->ExceptionCheck()) { return false; }
                     return true;
                 };
diff --git a/jni/run_tflite.cpp b/jni/run_tflite.cpp
index 8362af0..dfc3c82 100644
--- a/jni/run_tflite.cpp
+++ b/jni/run_tflite.cpp
@@ -28,9 +28,10 @@
 #include <fstream>
 
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
-#include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
 
+#include "tensorflow/lite/kernels/register.h"
+
 #define LOG_TAG "NN_BENCHMARK"
 
 #define FATAL(fmt, ...)                                                  \
@@ -69,12 +70,12 @@
 
 }  // namespace
 
-BenchmarkModel* BenchmarkModel::create(const char* modelfile, bool use_nnapi,
+BenchmarkModel* BenchmarkModel::create(const char* modelfile, int tfliteBackend,
                                        bool enable_intermediate_tensors_dump, int* nnapiErrno,
                                        const char* nnapi_device_name, bool mmapModel,
                                        const char* nnapi_cache_dir) {
   BenchmarkModel* model = new BenchmarkModel();
-  if (!model->init(modelfile, use_nnapi, enable_intermediate_tensors_dump, nnapiErrno,
+  if (!model->init(modelfile, tfliteBackend, enable_intermediate_tensors_dump, nnapiErrno,
                    nnapi_device_name, mmapModel, nnapi_cache_dir)) {
     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to init model %s", modelfile);
     delete model;
@@ -83,12 +84,13 @@
   return model;
 }
 
-bool BenchmarkModel::init(const char* modelfile, bool use_nnapi,
+bool BenchmarkModel::init(const char* modelfile, int tfliteBackend,
                           bool enable_intermediate_tensors_dump, int* nnapiErrno,
                           const char* nnapi_device_name, bool mmapModel,
                           const char* nnapi_cache_dir) {
+  __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "BenchmarkModel %s",
+                      modelfile);
   mModelFile = modelfile;
-  mUseNnApi = use_nnapi;
   if (nnapi_cache_dir) {
     mCacheDir = nnapi_cache_dir;
   }
@@ -135,26 +137,59 @@
   // Allow Fp16 precision for all models
   mTfliteInterpreter->SetAllowFp16PrecisionForFp32(true);
 
-  if (use_nnapi) {
-    tflite::StatefulNnApiDelegate::Options nnapi_options;
-    nnapi_options.accelerator_name = nnapi_device_name;
-    mTfliteNnapiDelegate = std::make_unique<tflite::StatefulNnApiDelegate>(nnapi_options);
-    int delegationStatus = mTfliteInterpreter->ModifyGraphWithDelegate(mTfliteNnapiDelegate.get());
-    *nnapiErrno = mTfliteNnapiDelegate->GetNnApiErrno();
-    if (delegationStatus != kTfLiteOk ||
-        *nnapiErrno != ANEURALNETWORKS_NO_ERROR) {
-      __android_log_print(
-          ANDROID_LOG_ERROR, LOG_TAG,
-          "Failed to initialize NNAPI Delegate for model %s, nnapi_errno is %d",
-          modelfile, *nnapiErrno);
-      return false;
-    }
+  mTfliteBackend = tfliteBackend;
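+  // Select a delegate for the requested backend; CPU (the default case
+  // below) needs no delegate.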
+  switch (mTfliteBackend) {
+    case TFLITE_NNAPI: {
+      tflite::StatefulNnApiDelegate::Options nnapi_options;
+      nnapi_options.accelerator_name = nnapi_device_name;
+      mTfliteNnapiDelegate = std::make_unique<tflite::StatefulNnApiDelegate>(nnapi_options);
+      int delegationStatus = mTfliteInterpreter->ModifyGraphWithDelegate(mTfliteNnapiDelegate.get());
+      *nnapiErrno = mTfliteNnapiDelegate->GetNnApiErrno();
+      if (delegationStatus != kTfLiteOk ||
+          *nnapiErrno != ANEURALNETWORKS_NO_ERROR) {
+        __android_log_print(
+            ANDROID_LOG_ERROR, LOG_TAG,
+            "Failed to initialize NNAPI Delegate for model %s, nnapi_errno is %d",
+            modelfile, *nnapiErrno);
+        return false;
+      }
+    } break;
+    case TFLITE_GPU: {
+#if defined(NN_BENCHMARK_ENABLE_GPU)
+      mGpuDelegate = TfLiteGpuDelegateV2Create(/*default options=*/nullptr);
+      if (mTfliteInterpreter->ModifyGraphWithDelegate(mGpuDelegate) !=
+          kTfLiteOk) {
+        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
+                            "Failed to initialize GPU Delegate");
+        return false;
+      }
+#else  // !defined(NN_BENCHMARK_ENABLE_GPU)
+      __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
+                          "GPU delegate requested but not enabled with "
+                          "NN_BENCHMARK_ENABLE_GPU");
+      return false;
+#endif  // defined(NN_BENCHMARK_ENABLE_GPU)
+    } break;
+    default:
+      break;
   }
   return true;
 }
 
-BenchmarkModel::BenchmarkModel() {}
-BenchmarkModel::~BenchmarkModel() {}
+BenchmarkModel::~BenchmarkModel() {
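+  // mTfliteNnapiDelegate is a std::unique_ptr and releases itself; only the
+  // raw GPU delegate pointer needs explicit deletion here.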
+  switch (mTfliteBackend) {
+    case TFLITE_GPU: {
+#if defined(NN_BENCHMARK_ENABLE_GPU)
+      TfLiteGpuDelegateV2Delete(mGpuDelegate);
+#endif  // defined(NN_BENCHMARK_ENABLE_GPU)
+    } break;
+    default:
+      break;
+  }
+}
 
 bool BenchmarkModel::setInput(const uint8_t* dataPtr, size_t length) {
   int input = mTfliteInterpreter->inputs()[0];
@@ -362,7 +394,7 @@
   // Allow Fp16 precision for all models
   interpreter->SetAllowFp16PrecisionForFp32(true);
 
-  if (mUseNnApi) {
+  if (mTfliteBackend == TFLITE_NNAPI) {
     tflite::StatefulNnApiDelegate::Options nnapi_options;
     nnapi_options.accelerator_name = mNnApiDeviceName.empty() ? nullptr : mNnApiDeviceName.c_str();
     if (cacheDir) {
diff --git a/jni/run_tflite.h b/jni/run_tflite.h
index 9e0f086..71c7747 100644
--- a/jni/run_tflite.h
+++ b/jni/run_tflite.h
@@ -17,9 +17,12 @@
 #ifndef COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
 #define COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
 
+#include "tensorflow/lite/delegates/gpu/delegate.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/model.h"
 
+#include <memory>
 #include <unistd.h>
 #include <vector>
 
@@ -83,11 +86,16 @@
   PREPARE_FROM_CACHE,
 };
 
+/** TFLite backend. Must match TfLiteBackend enum ordinals on the Java side. */
+constexpr int TFLITE_CPU = 0;
+constexpr int TFLITE_NNAPI = 1;
+constexpr int TFLITE_GPU = 2;
+
 class BenchmarkModel {
  public:
   ~BenchmarkModel();
 
-  static BenchmarkModel* create(const char* modelfile, bool use_nnapi,
+  static BenchmarkModel* create(const char* modelfile, int tfliteBackend,
                                 bool enable_intermediate_tensors_dump,
                                 int* nnapiErrno, const char* nnapi_device_name,
                                 bool mmapModel, const char* nnapi_cache_dir);
@@ -109,8 +117,8 @@
                      const std::vector<InferenceInOutSequence>& inOutData);
 
  private:
-  BenchmarkModel();
-  bool init(const char* modelfile, bool use_nnapi,
+  BenchmarkModel() = default;
+  bool init(const char* modelfile, int tfliteBackend,
             bool enable_intermediate_tensors_dump,
             int* nnapiErrno, const char* nnapi_device_name,
             /* flag to choose between memory mapping the model and initializing
@@ -139,9 +147,12 @@
 
   // Parameters for compilation
   std::string mModelFile;
-  bool mUseNnApi;
   std::optional<std::string> mCacheDir;
   std::string mNnApiDeviceName;
+#if defined(NN_BENCHMARK_ENABLE_GPU)
+  TfLiteDelegate* mGpuDelegate = nullptr;
+#endif  // defined(NN_BENCHMARK_ENABLE_GPU)
+  int mTfliteBackend = TFLITE_CPU;
 };
 
 #endif  // COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
diff --git a/src/com/android/nn/benchmark/app/NNBenchmark.java b/src/com/android/nn/benchmark/app/NNBenchmark.java
index a87760d..24878d8 100644
--- a/src/com/android/nn/benchmark/app/NNBenchmark.java
+++ b/src/com/android/nn/benchmark/app/NNBenchmark.java
@@ -27,6 +27,7 @@
 import com.android.nn.benchmark.core.BenchmarkResult;
 import com.android.nn.benchmark.core.Processor;
 import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import com.android.nn.benchmark.core.TfLiteBackend;
 import java.io.IOException;
 import java.time.Duration;
 import java.util.concurrent.ExecutorService;
@@ -59,7 +60,7 @@
     }
 
     public void setUseNNApi(boolean useNNApi) {
-        mProcessor.setUseNNApi(useNNApi);
+        mProcessor.setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
     }
 
     public void setCompleteInputSet(boolean completeInputSet) {
@@ -124,7 +125,8 @@
             mProcessor = new Processor(this, this, mTestList);
             mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
             mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
-            mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false));
+            mProcessor.setTfLiteBackend(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false)
+                    ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
             mProcessor.setMaxRunIterations(i.getIntExtra(EXTRA_MAX_ITERATIONS, 0));
             executorService.submit(mProcessor);
         } else {
diff --git a/src/com/android/nn/benchmark/core/BenchmarkResult.java b/src/com/android/nn/benchmark/core/BenchmarkResult.java
index 8df486e..a765a39 100644
--- a/src/com/android/nn/benchmark/core/BenchmarkResult.java
+++ b/src/com/android/nn/benchmark/core/BenchmarkResult.java
@@ -27,6 +27,7 @@
 import java.util.List;
 
 public class BenchmarkResult implements Parcelable {
+    // Used by CTS tests.
     public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI";
     public final static String BACKEND_TFLITE_CPU = "TFLite_CPU";
 
diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java
index aeda0b2..2e2e5e3 100644
--- a/src/com/android/nn/benchmark/core/NNTestBase.java
+++ b/src/com/android/nn/benchmark/core/NNTestBase.java
@@ -53,10 +53,12 @@
      * @return False if any error occurred, true otherwise
      */
     private static native boolean getAcceleratorNames(List<String> resultList);
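+    /** Checks whether an NNAPI device (accelerator) with the given name exists. */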
+    public static native boolean hasNnApiDevice(String nnApiDeviceName);
 
     private synchronized native long initModel(
             String modelFileName,
-            boolean useNNApi,
+            int tfliteBackend,
             boolean enableIntermediateTensorsDump,
             String nnApiDeviceName,
             boolean mmapModel,
@@ -119,7 +120,7 @@
     private final EvaluatorConfig mEvaluatorConfig;
     private EvaluatorInterface mEvaluator;
     private boolean mHasGoldenOutputs;
-    private boolean mUseNNApi = false;
+    private TfLiteBackend mTfLiteBackend;
     private boolean mEnableIntermediateTensorsDump = false;
     private final int mMinSdkVersion;
     private Optional<String> mNNApiDeviceName = Optional.empty();
@@ -152,12 +153,8 @@
         mSampleResults = false;
     }
 
-    public void useNNApi() {
-        useNNApi(true);
-    }
-
-    public void useNNApi(boolean value) {
-        mUseNNApi = value;
+    public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
+        mTfLiteBackend = tfLiteBackend;
     }
 
     public void enableIntermediateTensorsDump() {
@@ -168,8 +165,12 @@
         mEnableIntermediateTensorsDump = value;
     }
 
+    public void useNNApi() {
+        setTfLiteBackend(TfLiteBackend.NNAPI);
+    }
+
     public void setNNApiDeviceName(String value) {
-        if (!mUseNNApi) {
+        if (mTfLiteBackend != TfLiteBackend.NNAPI) {
             Log.e(TAG, "Setting device name has no effect when not using NNAPI");
         }
         mNNApiDeviceName = Optional.ofNullable(value);
@@ -187,13 +188,15 @@
         mTemporaryModelFilePath = copyAssetToFile();
         String nnApiCacheDir = mContext.getCodeCacheDir().toString();
         mModelHandle = initModel(
-                mTemporaryModelFilePath, mUseNNApi, mEnableIntermediateTensorsDump,
+                mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
                 mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir);
         if (mModelHandle == 0) {
             Log.e(TAG, "Failed to init the model");
             return false;
         }
-        resizeInputTensors(mModelHandle, mInputShape);
+        if (!resizeInputTensors(mModelHandle, mInputShape)) {
+            return false;
+        }
 
         if (mEvaluatorConfig != null) {
             mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
@@ -310,30 +313,31 @@
         return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
     }
 
-    /** Run through whole input set (once or mutliple times). */
+    /** Run through whole input set (once or multiple times). */
     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
-            int setRepeat,
+            int minInferences,
             float timeoutSec)
             throws IOException, BenchmarkException {
         int flags = getDefaultFlags();
         List<InferenceInOutSequence> ios = getInputOutputAssets();
-        int totalSequenceInferencesCount = ios.size() * setRepeat;
-        int extpectedResults = 0;
+        int setInferences = 0;
         for (InferenceInOutSequence iosSeq : ios) {
-            extpectedResults += iosSeq.size();
+            setInferences += iosSeq.size();
         }
-        extpectedResults *= setRepeat;
+        int setRepeat = (minInferences + setInferences - 1) / setInferences; // ceil.
+        int totalSequenceInferencesCount = ios.size() * setRepeat;
+        int expectedResults = setInferences * setRepeat;
 
         Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                 runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                         flags);
-        if (result.second.size() != extpectedResults) {
+        if (result.second.size() != expectedResults) {
             // We reached a timeout or failed to evaluate whole set for other reason, abort.
             @SuppressLint("DefaultLocale")
             final String errorMsg = String.format(
                     "Failed to evaluate complete input set, in %f seconds expected: %d, received:"
                             + " %d",
-                    timeoutSec, extpectedResults, result.second.size());
+                    timeoutSec, expectedResults, result.second.size());
             Log.w(TAG, errorMsg);
             throw new IllegalStateException(errorMsg);
         }
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 1cf605c..f8aaa04 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -53,7 +53,7 @@
 
     private Processor.Callback mCallback;
 
-    private boolean mUseNNApi;
+    private TfLiteBackend mBackend;
     private boolean mMmapModel;
     private boolean mCompleteInputSet;
     private boolean mToggleLong;
@@ -82,10 +82,15 @@
         mRunModelCompilationOnly = false;
         mMaxRunIterations = 0;
         mBenchmarkCompilationCaching = false;
+        mBackend = TfLiteBackend.CPU;
     }
 
     public void setUseNNApi(boolean useNNApi) {
-        mUseNNApi = useNNApi;
+        setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
+    }
+
+    public void setTfLiteBackend(TfLiteBackend backend) {
+        mBackend = backend;
     }
 
     public void setCompleteInputSet(boolean completeInputSet) {
@@ -156,7 +161,7 @@
     public static boolean isTestModelSupportedByAccelerator(Context context,
             TestModels.TestModelEntry testModelEntry, String acceleratorName)
             throws NnApiDelegationFailure {
-        try (NNTestBase tb = testModelEntry.createNNTestBase(/*useNnnapi=*/ true,
+        try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI,
                 /*enableIntermediateTensorsDump=*/false,
                 /*mmapModel=*/ false)) {
             tb.setNNApiDeviceName(acceleratorName);
@@ -183,9 +188,9 @@
             // Make sure we don't leak memory.
             oldTestBase.destroy();
         }
-        NNTestBase tb = t.createNNTestBase(mUseNNApi, /*enableIntermediateTensorsDump=*/false,
+        NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/false,
                 mMmapModel);
-        if (mUseNNApi) {
+        if (mBackend == TfLiteBackend.NNAPI) {
             tb.setNNApiDeviceName(mAcceleratorName);
         }
         if (!tb.setupModel(mContext)) {
@@ -212,9 +217,7 @@
             }
             return BenchmarkResult.fromInferenceResults(
                     mTest.getTestInfo(),
-                    mUseNNApi
-                            ? BenchmarkResult.BACKEND_TFLITE_NNAPI
-                            : BenchmarkResult.BACKEND_TFLITE_CPU,
+                    mBackend.toString(),
                     results.first,
                     results.second,
                     mTest.getEvaluator());
@@ -274,7 +277,7 @@
         }
 
         // Compilation benchmark
-        if (mUseNNApi && mBenchmarkCompilationCaching) {
+        if (mBenchmarkCompilationCaching) {
             runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds,
                     mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r);
         }
@@ -371,9 +374,8 @@
 
             if (mRunModelCompilationOnly) {
                 mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName,
-                        mUseNNApi
-                                ? BenchmarkResult.BACKEND_TFLITE_NNAPI
-                                : BenchmarkResult.BACKEND_TFLITE_CPU, Collections.emptyList(),
+                        mBackend.toString(),
+                        Collections.emptyList(),
                         Collections.emptyList(), null);
             } else {
                 // Run the test
diff --git a/src/com/android/nn/benchmark/core/TestModels.java b/src/com/android/nn/benchmark/core/TestModels.java
index cd386ba..9eb90f5 100644
--- a/src/com/android/nn/benchmark/core/TestModels.java
+++ b/src/com/android/nn/benchmark/core/TestModels.java
@@ -76,14 +76,20 @@
                     mEvaluator, mMinSdkVersion);
         }
 
-        public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump) {
-            return createNNTestBase(useNNApi, enableIntermediateTensorsDump, /*mmapModel=*/false);
+        public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump) {
+            return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false);
         }
 
-        public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump,
+        // Used by CTS tests.
+        public NNTestBase createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump) {
+            TfLiteBackend tfLiteBackend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU;
+            return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false);
+        }
+
+        public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump,
                 boolean mmapModel) {
             NNTestBase test = createNNTestBase();
-            test.useNNApi(useNNApi);
+            test.setTfLiteBackend(tfLiteBackend);
             test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
             test.setMmapModel(mmapModel);
             return test;
diff --git a/src/com/android/nn/benchmark/core/TfLiteBackend.java b/src/com/android/nn/benchmark/core/TfLiteBackend.java
new file mode 100644
index 0000000..cb385c2
--- /dev/null
+++ b/src/com/android/nn/benchmark/core/TfLiteBackend.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.nn.benchmark.core;
+
+public enum TfLiteBackend {
+  // The order of the values should be consistent with the constants
+  // in jni/run_tflite.h
+  CPU("TFLite_CPU"),
+  NNAPI("TFLite_NNAPI"),
+  GPU("TFLite_GPU");
+
+  private final String mName;
+
+  private TfLiteBackend(String name) { mName = name; }
+
+  public static TfLiteBackend parseString(String s) {
+    for (TfLiteBackend backend : TfLiteBackend.values()) {
+      if (backend.toString().equals(s)) {
+        return backend;
+      }
+    }
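+    // Unrecognized names fall back to the CPU backend.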
+    return CPU;
+  }
+
+  @Override
+  public String toString() { return mName; }
+}
diff --git a/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java b/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
index 59ea610..5c4197c 100644
--- a/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
+++ b/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
@@ -23,6 +23,7 @@
 import com.android.nn.benchmark.core.NNTestBase;
 import com.android.nn.benchmark.core.TestModels;
 import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import com.android.nn.benchmark.core.TfLiteBackend;
 
 import java.io.File;
 
@@ -44,6 +45,7 @@
     public static final String EXTRA_MODEL_NAME = "modelName";
     public static final String EXTRA_INPUT_ASSET_INDEX = "inputAssetIndex";
     public static final String EXTRA_INPUT_ASSET_SIZE = "inputAssetSize";
+    public static final String EXTRA_TFLITE_BACKEND = "tfLiteBackend";
     public static final String DUMP_DIR = "intermediate";
     public static final String CPU_DIR = "cpu";
     public static final String NNAPI_DIR = "nnapi";
@@ -89,9 +91,10 @@
                 // Run in CPU and NNAPI mode
                 for (final boolean useNNAPI : new boolean[]{false, true}) {
                     String useNNAPIDir = useNNAPI ? NNAPI_DIR : CPU_DIR;
+                    TfLiteBackend backend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU;
                     TestModelEntry modelEntry = TestModels.getModelByName(modelName);
                     try (NNTestBase testBase = modelEntry.createNNTestBase(
-                            useNNAPI, /*enableIntermediateTensorsDump*/true, /*mmapModel*/false)) {
+                            backend, /*enableIntermediateTensorsDump*/true, /*mmapModel*/false)) {
                         testBase.setupModel(this);
                         File outputDir = new File(getFilesDir() + "/" + DUMP_DIR +
                                 "/" + modelName, useNNAPIDir);
diff --git a/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java b/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
index e4da19a..68e35d9 100644
--- a/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
+++ b/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
@@ -27,6 +27,7 @@
 import com.android.nn.benchmark.core.NnApiDelegationFailure;
 import com.android.nn.benchmark.core.Processor;
 import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.benchmark.core.TfLiteBackend;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -122,7 +123,7 @@
                         public void onStatusUpdate(int testNumber, int numTests, String modelName) {
                         }
                     }, new int[0]);
-            mProcessor.setUseNNApi(true);
+            mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
             mProcessor.setCompleteInputSet(false);
             mProcessor.setNnApiAcceleratorName(acceleratorName);
             mTestModelEntry = testModelEntry;
@@ -152,4 +153,4 @@
             return true;
         }
     }
-}
\ No newline at end of file
+}
diff --git a/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java b/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
index 038d554..c203441 100644
--- a/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
+++ b/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
@@ -25,6 +25,7 @@
 import com.android.nn.benchmark.core.BenchmarkResult;
 import com.android.nn.benchmark.core.Processor;
 import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.benchmark.core.TfLiteBackend;
 import com.android.nn.crashtest.app.AcceleratorSpecificTestSupport;
 import com.android.nn.crashtest.core.CrashTest;
 import com.android.nn.crashtest.core.CrashTestCoordinator;
@@ -226,7 +227,7 @@
             final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start) {
         return () -> {
             Processor benchmarkProcessor = new Processor(mContext, mNoOpCallback, new int[0]);
-            benchmarkProcessor.setUseNNApi(true);
+            benchmarkProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
             benchmarkProcessor.setNnApiAcceleratorName(mAcceleratorName);
             if (start != null) {
                 start.countDown();
@@ -252,7 +253,7 @@
             mStart = start;
             mTestModelEntry = testModelEntry;
             mProcessor = new Processor(context, mNoOpCallback, new int[0]);
-            mProcessor.setUseNNApi(true);
+            mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
             mProcessor.setNnApiAcceleratorName(acceleratorName);
             mProcessor.setRunModelCompilationOnly(true);
             mRun = true;
diff --git a/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java b/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java
index 5fe204b..9f01128 100644
--- a/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java
+++ b/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java
@@ -26,6 +26,7 @@
 import com.android.nn.benchmark.core.Processor;
 import com.android.nn.crashtest.core.CrashTest;
 import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer;
+import com.android.nn.benchmark.core.TfLiteBackend;
 
 import java.time.Duration;
 import java.util.ArrayList;
@@ -131,7 +132,7 @@
             public void onStatusUpdate(int testNumber, int numTests, String modelName) {
             }
         }, testList);
-        result.setUseNNApi(true);
+        result.setTfLiteBackend(TfLiteBackend.NNAPI);
         result.setCompleteInputSet(false);
         result.setNnApiAcceleratorName(mAcceleratorName);
         result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels);