TF-TRT: Support GatherV2 op

Convert 'indices' input to a GatherV2 op into a constant
layer if it is a constant.

Modify the unit tests to generate constant indices.

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 456696c..8d72b64 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -4767,7 +4767,7 @@
   // option for an input to be either tensor or weight.
   TF_RETURN_IF_ERROR(
       CheckInputsWeights(*params, {{"params", TrtInputArg::kBoth},
-                                   {"indices", TrtInputArg::kTensor},
+                                   {"indices", TrtInputArg::kBoth},
                                    {"axis", TrtInputArg::kWeight}}));
 
   const auto& params_input = inputs.at(0);
@@ -4794,20 +4794,24 @@
     return errors::Unimplemented(
         "The input axis must be zero when params is a weight.");
   }
-  if (params->use_implicit_batch && params_input.is_tensor() &&
-      indices_input.batch_size() != 1) {
+  if (params->use_implicit_batch &&
+      (params_input.is_tensor() == indices_input.is_tensor()) &&
+     (indices_input.batch_size() != 1 || params_input.batch_size() != 1)) {
     return errors::Unimplemented(
-        "Indices must have a batch size of 1 when params is a tensor.");
+        "Params and indices must have a batch size of 1 when params and indices"
+        " are both tensors or both constants.");
   }
+
+  auto get_rank = [params](const auto& input) {
+    return input.GetTrtDims().nbDims +
+           (params->use_implicit_batch && input.is_tensor() ? 1 : 0);
+  };
   // Both input are tensors, and the TF gather result will have rank:
   // (params.nbDims + 1) + (indices.nbDims + 1) - 1,
   // where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
   // the TF rank so we don't have to add + 1.
-  const int params_tf_rank =
-      params_input.GetTrtDims().nbDims +
-      (params->use_implicit_batch && params_input.is_tensor() ? 1 : 0);
-  const int indices_tf_rank =
-      indices_input.GetTrtDims().nbDims + (params->use_implicit_batch ? 1 : 0);
+  const int params_tf_rank = get_rank(params_input);
+  const int indices_tf_rank = get_rank(indices_input);
   const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
   if (tf_gather_output_rank >
       nvinfer1::Dims::MAX_DIMS + (params->use_implicit_batch ? 1 : 0)) {
@@ -4817,14 +4821,25 @@
   }
   if (params->validation_only) return Status::OK();
 
-  // Convert params to tensor is it is a weight.
-  ITensorProxyPtr params_tensor = nullptr;
-  if (params_input.is_weights()) {
-    params_tensor = params->converter->CreateConstantLayer(
-        params_input.weights(), params_input.GetTrtDims());
-  } else {
-    params_tensor = params_input.tensor();
-  }
+
+  // Convert input or indices to tensor if it is a constant.
+  auto populate_tensor =
+    [params](const auto& input) -> ITensorProxyPtr {
+
+    ITensorProxyPtr result_tensor = nullptr;
+
+    if (input.is_weights()) {
+      result_tensor = params->converter->CreateConstantLayer(
+          input.weights(), input.GetTrtDims());
+    } else {
+      result_tensor = input.tensor();
+    }
+
+    return result_tensor;
+  };
+
+  ITensorProxyPtr params_tensor = populate_tensor(params_input);
+  ITensorProxyPtr indices_tensor = populate_tensor(indices_input);
 
   // Note on how IGatherLayer works: if both the data and indices tensors have
   // a batch size dimension of size N, it performs:
@@ -4832,28 +4847,35 @@
   //   output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
   //       data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
   nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
-      *params_tensor->trt_tensor(), *indices_input.tensor()->trt_tensor(),
+      *params_tensor->trt_tensor(), *indices_tensor->trt_tensor(),
       trt_axis);
   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
   params->converter->SetLayerName(layer, node_def);
 
   ITensorProxyPtr output_tensor = layer->getOutput(0);
   nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions();
-  // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
-  // and the other is for the output dimension that is squeezed by IGatherLayer
-  // because of the implicit batch dim in the indices (see the above note).
-  const int expected_trt_output_rank =
-      tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1);
-  if (params->use_implicit_batch &&
-      trt_gather_output_dims.nbDims != expected_trt_output_rank) {
-    return errors::Internal(
-        "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
-        expected_trt_output_rank,
-        ", actual nbDims: ", trt_gather_output_dims.nbDims);
+
+  if (params->use_implicit_batch) {
+    // Note for the "- 2": one is for the output batch dim encapsulated by
+    // TF-TRT, and the other is for the output dimension that is squeezed by
+    // IGatherLayer because of the implicit batch dim in the indices (see the
+    // above note).
+    const int expected_trt_output_rank = tf_gather_output_rank -
+                                        (params_input.is_tensor() ? 1 : 0) -
+                                        (indices_input.is_tensor() ? 1 : 0);
+
+    if (trt_gather_output_dims.nbDims != expected_trt_output_rank) {
+      return errors::Internal(
+          "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
+          expected_trt_output_rank,
+          ", actual nbDims: ", trt_gather_output_dims.nbDims);
+    }
   }
   // Reshape the output so after adding the implicit batch dim it'll match the
   // output shape of TF GatherV2.
-  if (params->use_implicit_batch && params_input.is_tensor()) {
+  if (params->use_implicit_batch &&
+      params_input.is_tensor() &&
+      indices_input.is_tensor()) {
     for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
       trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
     }
@@ -4866,6 +4888,26 @@
         /*validation_only=*/false, &output_tensor, node_def));
   }
 
+  // When input and indices are both constants, for the supported cases, reshape
+  // output so that after removing the implicit batch dim it will match the
+  // output shape of TF GatherV2 op.
+  if (params->use_implicit_batch &&
+      params_input.is_weights() &&
+      indices_input.is_weights()) {
+    for (int i = trt_axis; i < trt_gather_output_dims.nbDims - 1; ++i) {
+      trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i + 1];
+    }
+
+    // Squeeze the implicit batch dimension out. Note: this works only
+    // when batch size for both inputs and indices are 1.
+    --trt_gather_output_dims.nbDims;
+
+    TF_RETURN_IF_ERROR(PrepareTensorForShape(
+        params->converter, TRT_TensorOrWeights(output_tensor),
+        trt_gather_output_dims,
+        /*validation_only=*/false, &output_tensor, node_def));
+  }
+
   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
   return Status::OK();
 }
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 10fe1ec..12db165 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -5755,74 +5755,97 @@
     std::vector<int> expected_output_shape;
     std::vector<int> expected_output;
     bool params_is_tensor;
-    Status status;
+    bool indices_is_tensor;
+    Status conversion_status;
     Status runtime_status;
     Status add_index_status;
   };
 
   // Input is the same {1, 2, 3, 4, 5, 6} for all cases.
   const std::vector<int> params_input = {1, 2, 3, 4, 5, 6};
+
   std::vector<TestParams> test_params = {
       // Axis is batch dimension, should fail in implicit batch mode.
-      TestParams{/*params_shape=*/{2, 1, 1, 3},
-                 /*indices_shape=*/{2},
-                 /*indices=*/{1, 0},
-                 /*axis=*/0,
-                 /*expected_output_shape=*/{2, 1, 1, 3},
-                 /*expected_output=*/{4, 5, 6, 1, 2, 3},
-                 /*params_is_tensor=*/true,
-                 trt_mode_ == TrtTestMode::kImplicitBatch
-                     ? Status{error::UNIMPLEMENTED,
-                              "TensorRT does not allow manipulation of the"
-                              " batch dimension, at my_gather"}
-                     : Status::OK()},
-      // Batch size of indices is not 1 when params is a tensor.
-      TestParams{/*params_shape=*/{2, 1, 3},
-                 /*indices_shape=*/{2, 1},
-                 /*indices=*/{2, 0},
-                 /*axis=*/2,
-                 /*expected_output_shape=*/{2, 1, 2, 1},
-                 /*expected_output=*/{3, 1, 6, 4},
-                 /*params_is_tensor=*/true,
-                 trt_mode_ == TrtTestMode::kImplicitBatch
-                     ? Status{error::UNIMPLEMENTED,
-                              "Indices must have a batch size of 1 when params"
-                              " is a tensor."}
-                     : Status::OK()},
+      TestParams{
+          /*params_shape=*/{2, 1, 1, 3},
+          /*indices_shape=*/{2},
+          /*indices=*/{1, 0},
+          /*axis=*/0,
+          /*expected_output_shape=*/{2, 1, 1, 3},
+          /*expected_output=*/{4, 5, 6, 1, 2, 3},
+          /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
+          /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+              ? Status{error::UNIMPLEMENTED, "TensorRT does not allow "
+                       "manipulation of the batch dimension, at my_gather"}
+              : Status::OK()
+      },
+      // Batch size of indices is not 1 when params and indices are tensors.
+      TestParams{
+          /*params_shape=*/{2, 1, 3},
+          /*indices_shape=*/{2, 1},
+          /*indices=*/{2, 0},
+          /*axis=*/2,
+          /*expected_output_shape=*/{2, 1, 2, 1},
+          /*expected_output=*/{3, 1, 6, 4},
+          /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
+          /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+              ? Status{error::UNIMPLEMENTED, "Params and indices must have a"
+              " batch size of 1 when params and indices are both tensors or both"
+              " constants."}
+              : Status::OK()
+      },
+      // Batch size of indices is not 1 when params is tensor and indices are
+      // constant.
+      TestParams{
+          /*params_shape=*/{2, 1, 3},
+          /*indices_shape=*/{2, 1},
+          /*indices=*/{2, 0},
+          /*axis=*/2,
+          /*expected_output_shape=*/{2, 1, 2, 1},
+          /*expected_output=*/{3, 1, 6, 4},
+          /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/false,
+          /*conversion_status=*/Status::OK()
+      },
       // Axis is not zero when params is a weight, should fail in implicit batch
       // mode.
-      TestParams{/*params_shape=*/{2, 1, 3},
-                 /*indices_shape=*/{2},
-                 /*indices=*/{1, 2},
-                 /*axis=*/2,
-                 /*expected_output_shape=*/{2, 1, 2},
-                 /*expected_output=*/{2, 3, 5, 6},
-                 /*params_is_tensor=*/false,
-                 trt_mode_ == TrtTestMode::kImplicitBatch
-                     ? Status{error::UNIMPLEMENTED,
-                              "The input axis must be zero when params is a"
-                              " weight."}
-                     : Status::OK()},
+      TestParams{
+          /*params_shape=*/{2, 1, 3},
+          /*indices_shape=*/{2},
+          /*indices=*/{1, 2},
+          /*axis=*/2,
+          /*expected_output_shape=*/{2, 1, 2},
+          /*expected_output=*/{2, 3, 5, 6},
+          /*params_is_tensor=*/false,
+          /*indices_is_tensor=*/true,
+          /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+              ? Status{error::UNIMPLEMENTED, "The input axis must be zero when "
+                       "params is a weight."}
+              : Status::OK()
+      },
       // Params with only batch dimension.
-      TestParams{/*params_shape=*/{6},
-                 /*indices_shape=*/{2},
-                 /*indices=*/{1, 3},
-                 /*axis=*/0,
-                 /*expected_output_shape=*/{2},
-                 /*expected_output=*/{2, 4},
-                 /*params_is_tensor=*/true,
-                 trt_mode_ == TrtTestMode::kImplicitBatch  // conversion_status
-                     ? Status{error::UNIMPLEMENTED,
-                              "TensorRT does not allow manipulation of the "
-                              "batch dimension, at my_gather"}
-                     : Status::OK(),
-                 Status::OK(),                             // runtime_status
-                 trt_mode_ == TrtTestMode::kImplicitBatch  // add_index_status
-                     ? Status{error::INVALID_ARGUMENT,
-                              "Batch size doesn't match for tensor indices: "
-                              "Provided batch size does not match converter "
-                              "batch size: 2 vs 6"}
-                     : Status::OK()},
+      TestParams{
+          /*params_shape=*/{6},
+          /*indices_shape=*/{2},
+          /*indices=*/{1, 3},
+          /*axis=*/0,
+          /*expected_output_shape=*/{2},
+          /*expected_output=*/{2, 4},
+          /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
+          /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+              ? Status{error::UNIMPLEMENTED, "TensorRT does not allow "
+                       "manipulation of the batch dimension, at my_gather"}
+              : Status::OK(),
+          /*runtime_status=*/Status::OK(),
+          /*add_index_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+              ? Status{error::INVALID_ARGUMENT, "Batch size doesn't match for "
+                       "tensor indices: Provided batch size does not match "
+                       "converter batch size: 2 vs 6"}
+              : Status::OK()
+      },
       // Vector indices, and output rank is rank(params).
       TestParams{
           /*params_shape=*/{1, 1, 2, 3},
@@ -5832,6 +5855,7 @@
           /*expected_output_shape=*/{1, 1, 2, 1},
           /*expected_output=*/{1, 4},
           /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{1, 1, 2, 3},
@@ -5841,6 +5865,7 @@
           /*expected_output_shape=*/{1, 1, 1, 3},
           /*expected_output=*/{4, 5, 6},
           /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
       },
       // Indices with rank>1, and output rank is rank(params) + rank(indices) -
       // 1
@@ -5852,6 +5877,7 @@
           /*expected_output_shape=*/{1, 1, 2, 1, 1},
           /*expected_output=*/{1, 4},
           /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{1, 1, 2, 3},
@@ -5861,6 +5887,7 @@
           /*expected_output_shape=*/{1, 1, 2, 1, 1},
           /*expected_output=*/{2, 5},
           /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{1, 1, 2, 3},
@@ -5870,6 +5897,7 @@
           /*expected_output_shape=*/{1, 1, 2, 1, 1},
           /*expected_output=*/{3, 6},
           /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{1, 1, 2, 3},
@@ -5879,6 +5907,7 @@
           /*expected_output_shape=*/{1, 1, 2, 1, 3},
           /*expected_output=*/{3, 1, 2, 6, 4, 5},
           /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{1, 3, 2},
@@ -5888,6 +5917,7 @@
           /*expected_output_shape=*/{1, 3, 1, 2, 2},
           /*expected_output=*/{1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5},
           /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{1, 2, 3},
@@ -5897,6 +5927,7 @@
           /*expected_output_shape=*/{1, 2, 3},
           /*expected_output=*/{1, 2, 3, 4, 5, 6},
           /*params_is_tensor=*/false,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{3, 2},
@@ -5906,6 +5937,7 @@
           /*expected_output_shape=*/{1, 2, 2},
           /*expected_output=*/{1, 2, 3, 4},
           /*params_is_tensor=*/false,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{2, 3},
@@ -5915,6 +5947,7 @@
           /*expected_output_shape=*/{1, 1, 2, 3},
           /*expected_output=*/{1, 2, 3, 4, 5, 6},
           /*params_is_tensor=*/false,
+          /*indices_is_tensor=*/true,
       },
       TestParams{
           /*params_shape=*/{3, 2},
@@ -5924,22 +5957,75 @@
           /*expected_output_shape=*/{2, 2, 2},
           /*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2},
           /*params_is_tensor=*/false,
+          /*indices_is_tensor=*/true,
+      },
+      // Test cases in which indices constant
+      TestParams{
+          /*params_shape=*/{1, 1, 2, 3},
+          /*indices_shape=*/{1, 1},
+          /*indices=*/{0},
+          /*axis=*/3,
+          /*expected_output_shape=*/{1, 1, 2, 1, 1},
+          /*expected_output=*/{1, 4},
+          /*params_is_tensor=*/true,
+          /*indices_is_tensor=*/false,
+      },
+      // Test cases in which both input and indices constant
+      TestParams{
+          /*params_shape=*/{1, 2, 3},
+          /*indices_shape=*/{1},
+          /*indices=*/{0},
+          /*axis=*/0,
+          /*expected_output_shape=*/{1, 2, 3},
+          /*expected_output=*/{1, 2, 3, 4, 5, 6},
+          /*params_is_tensor=*/false,
+          /*indices_is_tensor=*/false,
+          /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+              ? Status{error::UNIMPLEMENTED, "Params and indices must have a"
+              " batch size of 1 when params and indices are both tensors or both"
+              " constants."}
+              : Status::OK()
+
+      },
+      TestParams{
+          /*params_shape=*/{3, 2},
+          /*indices_shape=*/{2, 2},
+          /*indices=*/{0, 2, 1, 0},
+          /*axis=*/0,
+          /*expected_output_shape=*/{2, 2, 2},
+          /*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2},
+          /*params_is_tensor=*/false,
+          /*indices_is_tensor=*/false,
+          /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+              ? Status{error::UNIMPLEMENTED, "Params and indices must have a"
+              " batch size of 1 when params and indices are both tensors or both"
+              " constants."}
+              : Status::OK()
       },
   };
 
   for (auto p : test_params) {
-    Reset();
-    if (p.params_is_tensor) {
-      AddTestTensor("params", p.params_shape, params_input);
-    } else {
-      AddTestWeights("params", p.params_shape, params_input, tf_type_);
+      Reset();
+
+      if (p.params_is_tensor) {
+        AddTestTensor("params", p.params_shape, params_input);
+      } else {
+        AddTestWeights("params", p.params_shape, params_input, tf_type_);
+      }
+
+      if (p.indices_is_tensor) {
+        AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {},
+                      p.add_index_status);
+      } else {
+        std::vector<int> indices_shape(p.indices_shape);
+        AddTestWeights("indices", indices_shape, p.indices, DT_INT32);
+      }
+
+      AddTestWeights<int32>("axis", {1}, {p.axis});
+      TestOpConverter("my_gather", node_def, p.expected_output_shape,
+                      p.conversion_status, p.runtime_status,
+                      ElementsAreArray(p.expected_output));
     }
-    AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {},
-                  p.add_index_status);
-    AddTestWeights<int32>("axis", {1}, {p.axis});
-    TestOpConverter("my_gather", node_def, p.expected_output_shape, p.status,
-                    p.runtime_status, ElementsAreArray(p.expected_output));
-  }
 }
 
 template <typename OpType>