[TFTRT] Add Dynamic Shape Testing for ConvertArgMinMax

diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index eed37cd..3718a18 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -5368,9 +5368,10 @@
 
   const auto get_matrix_op = [](nvinfer1::ITensor* in,
                                 bool transpose) -> nvinfer1::MatrixOperation {
-    return (in->getDimensions().nbDims < 2) ? nvinfer1::MatrixOperation::kVECTOR
-           : (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
-                         : nvinfer1::MatrixOperation::kNONE;
+    return (in->getDimensions().nbDims < 2)
+               ? nvinfer1::MatrixOperation::kVECTOR
+               : (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
+                             : nvinfer1::MatrixOperation::kNONE;
   };
 
   // If the MatMul operand is a constant, applies transposes at conversion-time
@@ -5550,7 +5551,7 @@
   int trt_axis;
   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
   TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
-                                 /*use_implicit_batch=*/true, &trt_axis));
+                                 params->use_implicit_batch, &trt_axis));
   nvinfer1::TopKOperation topk_op;
   if (node_def.op() == "ArgMin") {
     topk_op = nvinfer1::TopKOperation::kMIN;
@@ -5559,6 +5560,18 @@
   } else {
     return errors::InvalidArgument("Unsupported ArgMin/Max operation");
   }
+
+#if !IS_TRT_VERSION_GE(7, 0, 0, 11)
+  const nvinfer1::Dims trt_dims = params->inputs.at(0).GetTrtDims();
+  if (trt_dims.nbDims >= 4) {
+    string trt_dim_str = DebugString(trt_dims);
+
+    return errors::Unimplemented(node_def.op(), "op is not able to support",
+                                 " tensors with 4+ dimensions (excluding batch",
+                                 " size). Received: ", trt_dim_str);
+  }
+#endif
+
   if (params->validation_only) return Status::OK();
 
   // Use TopK with k = 1. Only indices output is needed (output 1).
@@ -5570,16 +5583,13 @@
   nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1);
 
   // Squeeze on axis.
-  std::vector<int> size(dims.d, dims.d + dims.nbDims);
-  size.erase(size.begin() + trt_axis);
-  nvinfer1::Dims new_dims;
-  TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &new_dims));
+  std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
+  input_dims[trt_axis] = 0;
   nvinfer1::ITensor* output_tensor = nullptr;
-  TF_RETURN_IF_ERROR(PrepareTensorForShape(
-      params->converter, TRT_TensorOrWeights(output_indices_tensor), new_dims,
-      /*validation_only=*/false, &output_tensor, node_def));
-
+  TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
+      output_indices_tensor, &input_dims, params, &output_tensor));
   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
+
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 65d5e2c..7a8d335 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -32,8 +32,6 @@
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/nn_ops_internal.h"
@@ -58,6 +56,8 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
 #include "tensorflow/core/public/session.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "third_party/tensorrt/NvInfer.h"
 
 namespace tensorflow {
@@ -6218,99 +6218,49 @@
   return arg.operation.node()->def();
 }
 
-template <typename OpType, DataType dtype>
-void TestConvertArgMinMax(OpConverterTest* test) {
-  typedef typename EnumToDataType<dtype>::Type CType;
+struct ArgMinMaxTestParams {
+  std::vector<int> input_shape;
+  std::vector<float> input_value;
+  int axis;
+  std::vector<int> expected_output_dims;
+  std::vector<int> expected_argmax_output;
+  std::vector<int> expected_argmin_output;
+  Status status;
+};
 
-  struct TestParams {
-    std::vector<int> input_shape;
-    std::vector<CType> input_value;
-    int axis;
-    std::vector<int> expected_output_dims;
-    std::vector<int> expected_argmax_output;
-    std::vector<int> expected_argmin_output;
-  };
+template <typename OpType>
+void TestConvertArgMinMax(ParameterizedOpConverterTestBase* test,
+                          DataType _tf_type, ArgMinMaxTestParams& p) {
+  test->Reset();
 
-  const std::vector<CType> common_input = InitTestVector<CType>(6);
-  std::vector<TestParams> params = {
-      {
-          /*input_shape=*/{2, 3},
-          /*input_value=*/common_input,
-          /*axis=*/2,
-          /*expected_output_dims=*/{2},
-          /*expected_argmax_output=*/{2, 2},
-          /*expected_argmin_output=*/{0, 0},
-      },
-      {
-          /*input_shape=*/{2, 3},
-          /*input_value=*/common_input,
-          /*axis=*/-2,
-          /*expected_output_dims=*/{3},
-          /*expected_argmax_output=*/{1, 1, 1},
-          /*expected_argmin_output=*/{0, 0, 0},
-      },
-      {
-          /*input_shape=*/{6},
-          /*input_value=*/common_input,
-          /*axis=*/1,
-          /*expected_output_dims=*/{},
-          /*expected_argmax_output=*/{5},
-          /*expected_argmin_output=*/{0},
-      },
-      {
-          /*input_shape=*/{10},
-          /*input_value=*/
-          {CType(-5), CType(3), CType(5), CType(1), CType(6), CType(-9),
-           CType(7), CType(1), CType(0), CType(-1)},
-          /*axis=*/-1,
-          /*expected_output_dims=*/{},
-          /*expected_argmax_output=*/{6},
-          /*expected_argmin_output=*/{5},
-      },
-  };
+  NodeDef node_def = GetArgMinMaxNodeDef<OpType>(_tf_type,
+                                                 /*output_dtype=*/DT_INT32);
 
-  for (int i = 0; i < params.size(); ++i) {
-    test->Reset();
-
-    NodeDef node_def = GetArgMinMaxNodeDef<OpType>(dtype, DT_INT32);
-    // Create inputs.
-    nvinfer1::DataType trt_type;
-    TF_ASSERT_OK(TfTypeToTrtType(dtype, &trt_type));
-    test->AddTestTensor("input", params[i].input_shape, /*batch_size=*/1,
-                        /*trt_dtype=*/trt_type);
-    test->AddTestWeights<int32>("dimension", {1}, {params[i].axis});
-    test->RunValidationAndConversion(node_def);
-
-    TRT_TensorOrWeights output;
-    TF_EXPECT_OK(test->GetTensorOrWeights("my_arg", &output));
-    EXPECT_TRUE(output.is_tensor());
-    ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
-                             output.tensor()->getDimensions());
-    // Create input data for tensors.
-    const DataVec input_data{
-        {"input", test->AsTensor<CType>(params[i].input_value)}};
-    DataVec output_data{
-        {"my_arg", test->ConstructTensor<int32>(
-                       params[i].expected_argmax_output.size())}};
-    TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
-
-    if (node_def.op() == "ArgMax") {
-      EXPECT_THAT(GetSpanForData<int32>(output_data[0]),
-                  ElementsAreArray(params[i].expected_argmax_output));
-    } else if (node_def.op() == "ArgMin") {
-      EXPECT_THAT(GetSpanForData<int32>(output_data[0]),
-                  ElementsAreArray(params[i].expected_argmin_output));
-    } else {
-      ASSERT_TRUE(false);
-    }
+  std::vector<int> expected_out;
+  if (node_def.op() == "ArgMax") {
+    expected_out = p.expected_argmax_output;
+  } else if (node_def.op() == "ArgMin") {
+    expected_out = p.expected_argmin_output;
+  } else {
+    ASSERT_TRUE(false);
   }
+
+  test->AddTestTensor("input", p.input_shape, _tf_type, p.input_value);
+  test->AddTestWeights("dimension", {1}, {p.axis}, DT_INT32);
+
+  test->TestOpConverter("my_arg", node_def, p.expected_output_dims,
+                        /*expected_conversion_status=*/p.status,
+                        /*expected_runtime_status=*/Status::OK(),
+                        /*matcher=*/ElementsAreArray(expected_out), {DT_INT32});
 }
 
-TEST_F(OpConverterTest, ConvertArgMinMax) {
+TEST_P(OpConverter_FP32_FP16_Test, ConvertArgMinMax) {
   {
     // Dimension is a tensor, should fail.
     Reset();
-    NodeDef node_def = GetArgMinMaxNodeDef<ops::ArgMax>(DT_FLOAT, DT_INT32);
+    NodeDef node_def =
+        GetArgMinMaxNodeDef<ops::ArgMax>(tf_type_,
+                                         /*output_dtype=*/DT_INT32);
     AddTestTensor("input", {1, 2, 3});
     AddTestTensor("dimension", {1});
     RunValidationAndConversion(
@@ -6320,32 +6270,112 @@
   {
     // Output type is INT64, should fail.
     Reset();
-    NodeDef node_def = GetArgMinMaxNodeDef<ops::ArgMax>(DT_FLOAT, DT_INT64);
+    NodeDef node_def =
+        GetArgMinMaxNodeDef<ops::ArgMax>(tf_type_,
+                                         /*output_dtype=*/DT_INT64);
     AddTestTensor("input", {1, 2, 3});
-    AddTestWeights<int32>("dimension", {1}, {3});
+    AddTestWeights("dimension", {1}, {3}, DT_INT32);
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                                "Output type int64 is not supported, at my_arg");
   }
-  {
-    // Axis is batch dimension, should fail
-    Reset();
-    NodeDef node_def = GetArgMinMaxNodeDef<ops::ArgMax>(DT_FLOAT, DT_INT32);
-    AddTestTensor("input", {1, 2, 3});
-    AddTestWeights<int32>("dimension", {1}, {0});
-    RunValidationAndConversion(
-        node_def, error::UNIMPLEMENTED,
-        "TensorRT does not allow manipulation of the batch dimension, at "
-        "my_arg");
-  }
 
-  TestConvertArgMinMax<ops::ArgMin, DT_FLOAT>(this);
-  TestConvertArgMinMax<ops::ArgMax, DT_FLOAT>(this);
-  TestConvertArgMinMax<ops::ArgMin, DT_HALF>(this);
-  TestConvertArgMinMax<ops::ArgMax, DT_HALF>(this);
-  // TRT does not support int32 for TopK layer which is used to implement ArgMin
-  // and ArgMax.
-  // TestConvertArgMinMax<ops::ArgMin, DT_INT32>(this);
-  // TestConvertArgMinMax<ops::ArgMax, DT_INT32>(this);
+  const std::vector<float> common_input = InitTestVector<float>(6);
+  std::vector<ArgMinMaxTestParams> params = {
+      {
+          /*input_shape=*/{2, 3},
+          /*input_value=*/common_input,
+          /*axis=*/0,
+          /*expected_output_dims=*/{3},
+          /*expected_argmax_output=*/{1, 1, 1},
+          /*expected_argmin_output=*/{0, 0, 0},
+          trt_mode_ == TrtTestMode::kImplicitBatch
+              ? errors::Unimplemented("TensorRT does not allow manipulation of "
+                                      "the batch dimension, at my_arg")
+              : Status::OK()
+      },
+      {
+          /*input_shape=*/{1, 6},
+          /*input_value=*/common_input,
+          /*axis=*/1,
+          /*expected_output_dims=*/{1},
+          /*expected_argmax_output=*/{5},
+          /*expected_argmin_output=*/{0},
+      },
+      {
+          /*input_shape=*/{1, 10},
+          /*input_value=*/
+          {-5.0f, 3.0f, 5.0f, 1.0f, 6.0f, -9.0f, 7.0f, 1.0f, 0.0f, -1.0f},
+          /*axis=*/-1,
+          /*expected_output_dims=*/{1},
+          /*expected_argmax_output=*/{6},
+          /*expected_argmin_output=*/{5},
+      },
+      {
+          /*input_shape=*/{1, 2, 3},
+          /*input_value=*/common_input,
+          /*axis=*/2,
+          /*expected_output_dims=*/{1, 2},
+          /*expected_argmax_output=*/{2, 2},
+          /*expected_argmin_output=*/{0, 0},
+      },
+      {
+          /*input_shape=*/{1, 2, 3},
+          /*input_value=*/common_input,
+          /*axis=*/-2,
+          /*expected_output_dims=*/{1, 3},
+          /*expected_argmax_output=*/{1, 1, 1},
+          /*expected_argmin_output=*/{0, 0, 0},
+      },
+      {
+          /*input_shape=*/{1, 2, 1, 3},
+          /*input_value=*/common_input,
+          /*axis=*/3,
+          /*expected_output_dims=*/{1, 2, 1},
+          /*expected_argmax_output=*/{2, 2},
+          /*expected_argmin_output=*/{0, 0},
+      },
+      {
+          /*input_shape=*/{1, 2, 1, 3},
+          /*input_value=*/common_input,
+          /*axis=*/-3,
+          /*expected_output_dims=*/{1, 1, 3},
+          /*expected_argmax_output=*/{1, 1, 1},
+          /*expected_argmin_output=*/{0, 0, 0},
+      },
+      {
+          /*input_shape=*/{1, 2, 1, 1, 3},
+          /*input_value=*/common_input,
+          /*axis=*/4,
+          /*expected_output_dims=*/{1, 2, 1, 1},
+          /*expected_argmax_output=*/{2, 2},
+          /*expected_argmin_output=*/{0, 0},
+#if !IS_TRT_VERSION_GE(7, 0, 0, 11)
+          errors::Unimplemented("op is not able to support tensors with 4+"
+                                " dimensions (excluding batch size)")
+#else
+          Status::OK()
+#endif
+      },
+      {
+          /*input_shape=*/{1, 2, 1, 1, 3},
+          /*input_value=*/common_input,
+          /*axis=*/-4,
+          /*expected_output_dims=*/{1, 1, 1, 3},
+          /*expected_argmax_output=*/{1, 1, 1},
+          /*expected_argmin_output=*/{0, 0, 0},
+#if !IS_TRT_VERSION_GE(7, 0, 0, 11)
+          errors::Unimplemented("op is not able to support tensors with 4+"
+                                " dimensions (excluding batch size)")
+#else
+          Status::OK()
+#endif
+      },
+  };
+
+  for (auto p : params) {
+    TestConvertArgMinMax<ops::ArgMin>(this, tf_type_, p);
+    TestConvertArgMinMax<ops::ArgMax>(this, tf_type_, p);
+  }
 }
 
 // Get the NodeDef for DepthToSpace or SpaceToSpace.