Refactoring of the Converter for Binary Operations.
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index f1ef6fc..6877cd2 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -609,6 +609,7 @@
     srcs = [
         "convert/convert_graph.cc",
         "convert/convert_nodes.cc",
+        "convert/ops/binary_ops.cc",
         "convert/ops/data_format_vec_permute.cc",
         "convert/ops/einsum.cc",
         "convert/ops/fill_ops.cc",
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 478c2d8..7e7ebb6 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -3645,65 +3645,6 @@
   return Status::OK();
 }
 
-Status ConvertBinary(OpConverterParams* params) {
-  const auto& inputs = params->inputs;
-  const auto& node_def = params->node_def;
-  TFTRT_CHECK_INPUT_SIZE(inputs.size(), 2, node_def);
-
-  std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
-                                   DataType::DT_INT32};
-  TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
-
-  // Constant folding should have been done by TensorFlow
-  if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
-    return errors::Unimplemented(
-        "Constant folding is falled back to TensorFlow, binary op received "
-        "both input as constant");
-  }
-  const TRT_TensorOrWeights& operand_l = inputs.at(0);
-  const TRT_TensorOrWeights& operand_r = inputs.at(1);
-
-  auto op_pair = absl::c_find_if(
-      kBinaryOperations,
-      [&node_def](
-          const std::pair<std::string, nvinfer1::ElementWiseOperation>& x) {
-        return x.first == node_def.op();
-      });
-  if (op_pair == kBinaryOperations.end()) {
-    return errors::Unimplemented("Binary op ", node_def.op(), " not supported");
-  }
-
-  nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
-  TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
-      operand_l, operand_r, /*check_feasibility=*/true,
-      params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
-  ITensorProxyPtr tensor_l = nullptr;
-  ITensorProxyPtr tensor_r = nullptr;
-  // This will also convert constants to tensors.
-  TF_RETURN_IF_ERROR(PrepareTensorForShape(
-      params->converter, operand_l, broadcasted_dims_l, params->validation_only,
-      &tensor_l, node_def, /*op_instance=*/0));
-  TF_RETURN_IF_ERROR(PrepareTensorForShape(
-      params->converter, operand_r, broadcasted_dims_r, params->validation_only,
-      &tensor_r, node_def, /*op_instance=*/1));
-  if (params->validation_only) return Status::OK();
-
-  // Add ElementWise layer.
-  nvinfer1::ILayer* layer = params->converter->network()->addElementWise(
-      *tensor_l->trt_tensor(), *tensor_r->trt_tensor(), op_pair->second);
-  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
-
-  if (params->use_explicit_precision) {
-    layer->setPrecision(nvinfer1::DataType::kFLOAT);
-  }
-
-  params->converter->SetLayerName(layer, node_def);
-  ITensorProxyPtr trt_tensor = layer->getOutput(0);
-
-  params->outputs->push_back(TRT_TensorOrWeights(trt_tensor));
-  return Status::OK();
-}
-
 Status ConvertRsqrt(OpConverterParams* params) {
   const auto& inputs = params->inputs;
   const auto& node_def = params->node_def;
@@ -5972,8 +5913,6 @@
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertTopK, "TopKV2");
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertTranspose, "Transpose");
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertUnpack, "Unpack");
-REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertBinary,
-                                  GetOperationNames(kBinaryOperations));
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertActivation,
                                   GetOperationNames(*ActivationTypeMap()));
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertPool, {"MaxPool", "AvgPool"});
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index 063a19d..e79a1bc 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -507,27 +507,16 @@
                             nvinfer1::Dims* operand_l_new_dims,
                             nvinfer1::Dims* operand_r_new_dims);
 
-// Map of all supported UnaryOperations
-const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
-// Map of all supported ActivationTypes
-const std::unordered_map<string, nvinfer1::ActivationType>* ActivationTypeMap();
-// Map of all supported BinaryOperations
-const std::unordered_map<string, nvinfer1::ElementWiseOperation>*
-BinaryOperationMap();
-
-constexpr std::array<std::pair<const char*, nvinfer1::ElementWiseOperation>, 10>
-    kBinaryOperations = {{
-        {"Add", nvinfer1::ElementWiseOperation::kSUM},
-        {"AddV2", nvinfer1::ElementWiseOperation::kSUM},
-        {"Mul", nvinfer1::ElementWiseOperation::kPROD},
-        {"Sub", nvinfer1::ElementWiseOperation::kSUB},
-        {"Div", nvinfer1::ElementWiseOperation::kDIV},
-        {"FloorDiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
-        {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
-        {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
-        {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
-        {"Pow", nvinfer1::ElementWiseOperation::kPOW},
-    }};
+template <typename T>
+using operationMap = std::unordered_map<std::string, T>;
+// Map of all supported UnaryOperations.
+typedef operationMap<nvinfer1::UnaryOperation> unaryOperationMap;
+const unaryOperationMap* UnaryOperationMap();
+// Map of all supported ActivationTypes.
+const operationMap<nvinfer1::ActivationType>* ActivationTypeMap();
+// Map of all supported BinaryOperations.
+typedef operationMap<nvinfer1::ElementWiseOperation> binaryOperationMap;
+const binaryOperationMap* BinaryOperationMap();
 
 template <typename T>
 absl::InlinedVector<std::string, 10> GetOperationNames(const T& set) {
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index ebe450c..cc75cc5 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -28,6 +28,7 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+
 #include "absl/algorithm/container.h"
 #include "absl/base/call_once.h"
 #include "absl/strings/match.h"
@@ -36,8 +37,6 @@
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/nn_ops_internal.h"
@@ -67,6 +66,8 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
 #include "tensorflow/core/public/session.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "third_party/tensorrt/NvInfer.h"
 
 namespace tensorflow {
@@ -1795,12 +1796,12 @@
                        const Status& expected_runtime_status,
                        const Matcher<std::vector<float>>& matcher,
                        const std::vector<DataType>& out_tf_types = {}) {
-    RunValidationAndConversion(
-        node_def, expected_conversion_status, name,
-        std::vector<std::vector<int>>({expected_output_dims}));
+    const auto& exp_dims =
+        std::vector<std::vector<int>>({expected_output_dims});
+    RunValidationAndConversion(node_def, expected_conversion_status, name,
+                               exp_dims);
     if (expected_conversion_status.ok()) {
-      BuildAndRun(name, std::vector<std::vector<int>>({expected_output_dims}),
-                  expected_runtime_status,
+      BuildAndRun(name, exp_dims, expected_runtime_status,
                   std::vector<Matcher<std::vector<float>>>({matcher}),
                   out_tf_types);
     }
@@ -1813,6 +1814,76 @@
   DataVec input_data_;
 };
 
+template <typename T>
+class OpConverter_BinaryTest : public ParameterizedOpConverterTestBase {
+ public:
+  template <typename S>
+  void RunTests(
+      const operationMap<S>& map,
+      std::map<std::string,
+               std::pair<std::function<NodeDef(DataType)>, std::vector<T>>>&
+          op_test_info,
+      DataType tf_type) {
+    // Test combinations of tensor vs weight inputs (except when both inputs are
+    // weights).
+    bool expectedToFailTested = false;
+    for (const bool operand_1_is_tensor : {true, false}) {
+      for (const bool operand_2_is_tensor : {true, false}) {
+        const auto bothOperandsAreWeights =
+            !operand_1_is_tensor && !operand_2_is_tensor;
+        for (auto& iter : map) {
+          const string& op_name = iter.first;
+          SCOPED_TRACE(StrCat(op_name, "_", operand_1_is_tensor ? "T" : "W",
+                              operand_2_is_tensor ? "T" : "W"));
+
+          if (!op_test_info.count(op_name)) {
+            FAIL() << "Binary op test map does not contain op " << op_name;
+          }
+
+          if (!expectedToFailTested && bothOperandsAreWeights) {
+            runExpectedToFailTest(op_name);
+            expectedToFailTested = true;
+            break;
+          }
+
+          Reset();
+          if (operand_1_is_tensor) {
+            AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6});
+          } else {
+            AddTestWeights("input1", {1, 2}, std::vector<T>{3, 6}, tf_type);
+          }
+          if (operand_2_is_tensor) {
+            AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3});
+          } else {
+            AddTestWeights("input2", {2, 1}, std::vector<T>{2, 3}, tf_type);
+          }
+
+          const NodeDef& node_def = op_test_info[op_name].first(tf_type);
+          TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(),
+                          Status::OK(),
+                          ElementsAreArray(op_test_info[op_name].second));
+        }
+      }
+    }
+  }
+
+  void runExpectedToFailTest(const std::string& op_name) {
+    Reset();
+    AttrValue dtype;
+    dtype.set_type(tf_type_);
+
+    const auto& node_def = MakeNodeDef(
+        "my_oper", op_name, {"weights1", "weights2"}, {{"T", dtype}});
+    AddTestWeights("weights1", {1}, {1}, tf_type_);
+    AddTestWeights("weights2", {1}, {1}, tf_type_);
+    const string error =
+        "Constant folding is falled back to TensorFlow, "
+        "binary op '" +
+        op_name + "' received both input as constant";
+    RunValidationAndConversion(node_def, error::UNIMPLEMENTED, error);
+  }
+};
+
 // Op converter test in FP32 mode. While for debugging purposes it might make
 // sense to run over all possible combinations, normally a subset of them
 // would be sufficient:
@@ -1825,14 +1896,15 @@
 //   how TRT handles the precision inside the TRT network, but should not matter
 //   for the TF -> TRT conversion. Therefore it should be sufficient to test
 //   for FP32.
-class OpConverter_FP32_Test : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_FP32_Test;
 // Base class for tests that need to be tested for both FP32 and FP16.
-class OpConverter_FP32_FP16_Test : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_FP32_FP16_Test;
+// Base class for Binary tests that need to be tested
+typedef OpConverter_BinaryTest<float> OpConverter_FP32_FP16_BinaryTest;
 // Base class for tests that need to be tested for FP32, FP16, and INT32
-class OpConverter_FP32_FP16_INT32_Test
-    : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_FP32_FP16_INT32_Test;
 // Base class for tests that need to be tested for INT32
-class OpConverter_INT32_Test : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_INT32_Test;
 
 // Instantiate parameter combinations to OpConverter_<DT_X...>_Test
 INSTANTIATE_TEST_CASE_P(
@@ -1859,6 +1931,12 @@
                        ::testing::Values(DT_INT32),
                        ::testing::Values(TrtPrecisionMode::FP32)));
 
+INSTANTIATE_TEST_CASE_P(
+    OpConvTestInstantiation, OpConverter_FP32_FP16_BinaryTest,
+    ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
+                       ::testing::Values(DT_FLOAT, DT_HALF),
+                       ::testing::Values(TrtPrecisionMode::FP32)));
+
 template <typename T>
 void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
   out->Clear();
@@ -3155,22 +3233,7 @@
   return op.operation.node()->def();
 }
 
-TEST_P(OpConverter_FP32_FP16_Test, ConvertBinary) {
-  {
-    AttrValue dtype;
-    dtype.set_type(tf_type_);
-    // Both inputs are weights.
-    Reset();
-    NodeDef node_def =
-        MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}});
-    AddTestWeights<float>("weights1", {1}, {1});
-    AddTestWeights<float>("weights2", {1}, {1});
-    RunValidationAndConversion(
-        node_def, error::UNIMPLEMENTED,
-        "Constant folding is falled back to TensorFlow, binary op received "
-        "both input as constant");
-  }
-
+TEST_P(OpConverter_FP32_FP16_BinaryTest, ConvertBinary) {
   using OpFunc = std::function<NodeDef(DataType)>;
   std::map<std::string, std::pair<OpFunc, std::vector<float>>> op_test_info;
 #define ADD_OP(name, op, v1, v2, v3, v4, v5, v6, v7, v8) \
@@ -3188,39 +3251,7 @@
   ADD_OP("Maximum", ops::Maximum, {3, 6, 3, 6, 3, 6, 3, 6});
   ADD_OP("Pow", ops::Pow, {9, 36, 27, 216, 9, 36, 27, 216});
 #undef ADD_OP
-  // Test combinations of tensor vs weight inputs (except when both inputs are
-  // weights).
-  for (const bool operand_1_is_tensor : {true, false}) {
-    for (const bool operand_2_is_tensor : {true, false}) {
-      if (!operand_1_is_tensor && !operand_2_is_tensor) continue;
-      for (auto& iter : kBinaryOperations) {
-        string op_name = iter.first;
-        SCOPED_TRACE(StrCat(op_name, "_", operand_1_is_tensor ? "T" : "W",
-                            operand_2_is_tensor ? "T" : "W"));
-        Reset();
-        if (!op_test_info.count(op_name)) {
-          FAIL() << "Binary op test map does not contain op " << op_name;
-        }
-        NodeDef node_def = op_test_info[op_name].first(tf_type_);
-        std::vector<std::string> input_names;
-        std::vector<std::vector<int>> input_dims;
-        std::vector<std::vector<float>> input_values;
-        if (operand_1_is_tensor) {
-          AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6});
-        } else {
-          AddTestWeights("input1", {1, 2}, std::vector<float>{3, 6}, tf_type_);
-        }
-        if (operand_2_is_tensor) {
-          AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3});
-        } else {
-          AddTestWeights("input2", {2, 1}, std::vector<float>{2, 3}, tf_type_);
-        }
-        TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(),
-                        Status::OK(),
-                        ElementsAreArray(op_test_info[op_name].second));
-      }
-    }
-  }
+  RunTests(*BinaryOperationMap(), op_test_info, get_tf_type());
 }
 
 NodeDef GetAddNNodeDef(const std::vector<string>& input_names, DataType dtype) {
@@ -4041,10 +4072,10 @@
     }
     Reset();
     NodeDef node_def = op_map[op_name].first(tf_type_);
+    // std::exp in Softplus will overflow for input > 88.
     const std::vector<float> input = {-100, -2, -1, 0, 1, 88};
     AddTestTensor("input", p.input_dims, input);
 
-    // std::exp in Softplus will overflow for input > 88
     std::vector<float> output_values;
     std::transform(input.begin(), input.end(),
                    std::back_inserter(output_values), op_map[op_name].second);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/binary_ops.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/binary_ops.cc
new file mode 100644
index 0000000..7d03949
--- /dev/null
+++ b/tensorflow/compiler/tf2tensorrt/convert/ops/binary_ops.cc
@@ -0,0 +1,128 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+
+#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+const binaryOperationMap *BinaryOperationMap() {
+  static auto *const m = new binaryOperationMap({
+      {"Add", nvinfer1::ElementWiseOperation::kSUM},
+      {"AddV2", nvinfer1::ElementWiseOperation::kSUM},
+      {"Mul", nvinfer1::ElementWiseOperation::kPROD},
+      {"Sub", nvinfer1::ElementWiseOperation::kSUB},
+      {"Div", nvinfer1::ElementWiseOperation::kDIV},
+      {"FloorDiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
+      {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
+      {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
+      {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
+      {"Pow", nvinfer1::ElementWiseOperation::kPOW},
+  });
+  return m;
+}
+
+class ConvertBinaryImpl {
+ protected:
+  ConvertBinaryImpl(const binaryOperationMap *pOperMap) : pOperMap_(pOperMap) {}
+
+  Status ImplValidate(const OpConverterParams &params) {
+    const auto &node_def = params.node_def;
+    const auto op = node_def.op();
+    const auto op_pair = pOperMap_->find(op);
+    if (op_pair == pOperMap_->end()) {
+      return errors::Unimplemented("Binary op: ", op, " not supported");
+    }
+
+    // Constant folding should have been done by TensorFlow.
+    const auto &inputs = params.inputs;
+    if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
+      return errors::Unimplemented(
+          "Constant folding is falled back to TensorFlow, binary op '", op,
+          "' received both input as constant");
+    }
+
+    nvinfer1::Dims broadcasted_dims[2];
+    TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
+        inputs.at(0), inputs.at(1), true, params.use_implicit_batch,
+        broadcasted_dims, broadcasted_dims + 1));
+
+    for (int i = 0; i < 2; i++) {
+      tensor_[i] = nullptr;
+      // This will also convert constants to tensors.
+      TF_RETURN_IF_ERROR(PrepareTensorForShape(
+          params.converter, inputs.at(i), broadcasted_dims[i],
+          params.validation_only, tensor_ + i, node_def, i));
+    }
+    operation_ = op_pair->second;
+    return Status::OK();
+  }
+
+  Status ImplConvert(const OpConverterParams &params) {
+    const auto &node_def = params.node_def;
+    // Add ElementWise layer.
+    nvinfer1::ILayer *layer = params.converter->network()->addElementWise(
+        *tensor_[0]->trt_tensor(), *tensor_[1]->trt_tensor(), operation_);
+    TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+
+    if (params.use_explicit_precision) {
+      layer->setPrecision(nvinfer1::DataType::kFLOAT);
+    }
+
+    params.converter->SetLayerName(layer, node_def);
+    params.outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
+    return Status::OK();
+  }
+  static constexpr std::array<InputArgSpec, 2> InputSpec() {
+    return std::array<InputArgSpec, 2>{
+        InputArgSpec::Create("x", TrtInputArg::kBoth),
+        InputArgSpec::Create("y", TrtInputArg::kBoth)};
+  }
+
+ private:
+  const binaryOperationMap *pOperMap_;
+  ITensorProxyPtr tensor_[2];
+  nvinfer1::ElementWiseOperation operation_;
+};
+
+class ConvertBinary : public OpConverterBase<ConvertBinary>,
+                      protected ConvertBinaryImpl {
+ public:
+  explicit ConvertBinary(OpConverterParams *params)
+      : OpConverterBase<ConvertBinary>(params),
+        ConvertBinaryImpl(BinaryOperationMap()) {}
+
+  static constexpr std::array<DataType, 3> AllowedDataTypes() {
+    return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
+  }
+
+  static constexpr std::array<InputArgSpec, 2> InputSpec() {
+    return ConvertBinaryImpl::InputSpec();
+  }
+
+  Status Validate() { return ImplValidate(*params_); }
+  Status Convert() { return ImplConvert(*params_); }
+};
+
+REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertBinary>(),
+                                  GetOperationNames(*BinaryOperationMap()));
+
+}  // namespace convert
+}  // namespace tensorrt
+}  // namespace tensorflow
+#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT