Refactoring of the Converter for Binary Operations.
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index f1ef6fc..6877cd2 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -609,6 +609,7 @@
srcs = [
"convert/convert_graph.cc",
"convert/convert_nodes.cc",
+ "convert/ops/binary_ops.cc",
"convert/ops/data_format_vec_permute.cc",
"convert/ops/einsum.cc",
"convert/ops/fill_ops.cc",
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 478c2d8..7e7ebb6 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -3645,65 +3645,6 @@
return Status::OK();
}
-Status ConvertBinary(OpConverterParams* params) {
- const auto& inputs = params->inputs;
- const auto& node_def = params->node_def;
- TFTRT_CHECK_INPUT_SIZE(inputs.size(), 2, node_def);
-
- std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
- DataType::DT_INT32};
- TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
-
- // Constant folding should have been done by TensorFlow
- if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
- return errors::Unimplemented(
- "Constant folding is falled back to TensorFlow, binary op received "
- "both input as constant");
- }
- const TRT_TensorOrWeights& operand_l = inputs.at(0);
- const TRT_TensorOrWeights& operand_r = inputs.at(1);
-
- auto op_pair = absl::c_find_if(
- kBinaryOperations,
- [&node_def](
- const std::pair<std::string, nvinfer1::ElementWiseOperation>& x) {
- return x.first == node_def.op();
- });
- if (op_pair == kBinaryOperations.end()) {
- return errors::Unimplemented("Binary op ", node_def.op(), " not supported");
- }
-
- nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
- TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
- operand_l, operand_r, /*check_feasibility=*/true,
- params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
- ITensorProxyPtr tensor_l = nullptr;
- ITensorProxyPtr tensor_r = nullptr;
- // This will also convert constants to tensors.
- TF_RETURN_IF_ERROR(PrepareTensorForShape(
- params->converter, operand_l, broadcasted_dims_l, params->validation_only,
- &tensor_l, node_def, /*op_instance=*/0));
- TF_RETURN_IF_ERROR(PrepareTensorForShape(
- params->converter, operand_r, broadcasted_dims_r, params->validation_only,
- &tensor_r, node_def, /*op_instance=*/1));
- if (params->validation_only) return Status::OK();
-
- // Add ElementWise layer.
- nvinfer1::ILayer* layer = params->converter->network()->addElementWise(
- *tensor_l->trt_tensor(), *tensor_r->trt_tensor(), op_pair->second);
- TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
-
- if (params->use_explicit_precision) {
- layer->setPrecision(nvinfer1::DataType::kFLOAT);
- }
-
- params->converter->SetLayerName(layer, node_def);
- ITensorProxyPtr trt_tensor = layer->getOutput(0);
-
- params->outputs->push_back(TRT_TensorOrWeights(trt_tensor));
- return Status::OK();
-}
-
Status ConvertRsqrt(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
@@ -5972,8 +5913,6 @@
REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertTopK, "TopKV2");
REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertTranspose, "Transpose");
REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertUnpack, "Unpack");
-REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertBinary,
- GetOperationNames(kBinaryOperations));
REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertActivation,
GetOperationNames(*ActivationTypeMap()));
REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertPool, {"MaxPool", "AvgPool"});
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index 063a19d..e79a1bc 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -507,27 +507,16 @@
nvinfer1::Dims* operand_l_new_dims,
nvinfer1::Dims* operand_r_new_dims);
-// Map of all supported UnaryOperations
-const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
-// Map of all supported ActivationTypes
-const std::unordered_map<string, nvinfer1::ActivationType>* ActivationTypeMap();
-// Map of all supported BinaryOperations
-const std::unordered_map<string, nvinfer1::ElementWiseOperation>*
-BinaryOperationMap();
-
-constexpr std::array<std::pair<const char*, nvinfer1::ElementWiseOperation>, 10>
- kBinaryOperations = {{
- {"Add", nvinfer1::ElementWiseOperation::kSUM},
- {"AddV2", nvinfer1::ElementWiseOperation::kSUM},
- {"Mul", nvinfer1::ElementWiseOperation::kPROD},
- {"Sub", nvinfer1::ElementWiseOperation::kSUB},
- {"Div", nvinfer1::ElementWiseOperation::kDIV},
- {"FloorDiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
- {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
- {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
- {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
- {"Pow", nvinfer1::ElementWiseOperation::kPOW},
- }};
+template <typename T>
+using operationMap = std::unordered_map<std::string, T>;
+// Map of all supported UnaryOperations.
+typedef operationMap<nvinfer1::UnaryOperation> unaryOperationMap;
+const unaryOperationMap* UnaryOperationMap();
+// Map of all supported ActivationTypes.
+const operationMap<nvinfer1::ActivationType>* ActivationTypeMap();
+// Map of all supported BinaryOperations.
+typedef operationMap<nvinfer1::ElementWiseOperation> binaryOperationMap;
+const binaryOperationMap* BinaryOperationMap();
template <typename T>
absl::InlinedVector<std::string, 10> GetOperationNames(const T& set) {
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index ebe450c..cc75cc5 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -28,6 +28,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+
#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "absl/strings/match.h"
@@ -36,8 +37,6 @@
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
@@ -67,6 +66,8 @@
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
#include "tensorflow/core/public/session.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
@@ -1795,12 +1796,12 @@
const Status& expected_runtime_status,
const Matcher<std::vector<float>>& matcher,
const std::vector<DataType>& out_tf_types = {}) {
- RunValidationAndConversion(
- node_def, expected_conversion_status, name,
- std::vector<std::vector<int>>({expected_output_dims}));
+ const auto& exp_dims =
+ std::vector<std::vector<int>>({expected_output_dims});
+ RunValidationAndConversion(node_def, expected_conversion_status, name,
+ exp_dims);
if (expected_conversion_status.ok()) {
- BuildAndRun(name, std::vector<std::vector<int>>({expected_output_dims}),
- expected_runtime_status,
+ BuildAndRun(name, exp_dims, expected_runtime_status,
std::vector<Matcher<std::vector<float>>>({matcher}),
out_tf_types);
}
@@ -1813,6 +1814,76 @@
DataVec input_data_;
};
+template <typename T>
+class OpConverter_BinaryTest : public ParameterizedOpConverterTestBase {
+ public:
+ template <typename S>
+ void RunTests(
+ const operationMap<S>& map,
+ std::map<std::string,
+ std::pair<std::function<NodeDef(DataType)>, std::vector<T>>>&
+ op_test_info,
+ DataType tf_type) {
+ // Test combinations of tensor vs weight inputs (except when both inputs are
+ // weights).
+ bool expectedToFailTested = false;
+ for (const bool operand_1_is_tensor : {true, false}) {
+ for (const bool operand_2_is_tensor : {true, false}) {
+ const auto bothOperandsAreWeights =
+ !operand_1_is_tensor && !operand_2_is_tensor;
+ for (auto& iter : map) {
+ const string& op_name = iter.first;
+ SCOPED_TRACE(StrCat(op_name, "_", operand_1_is_tensor ? "T" : "W",
+ operand_2_is_tensor ? "T" : "W"));
+
+ if (!op_test_info.count(op_name)) {
+ FAIL() << "Binary op test map does not contain op " << op_name;
+ }
+
+ if (!expectedToFailTested && bothOperandsAreWeights) {
+ runExpectedToFailTest(op_name);
+ expectedToFailTested = true;
+ break;
+ }
+
+ Reset();
+ if (operand_1_is_tensor) {
+ AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6});
+ } else {
+ AddTestWeights("input1", {1, 2}, std::vector<T>{3, 6}, tf_type);
+ }
+ if (operand_2_is_tensor) {
+ AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3});
+ } else {
+ AddTestWeights("input2", {2, 1}, std::vector<T>{2, 3}, tf_type);
+ }
+
+ const NodeDef& node_def = op_test_info[op_name].first(tf_type);
+ TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(),
+ Status::OK(),
+ ElementsAreArray(op_test_info[op_name].second));
+ }
+ }
+ }
+ }
+
+ void runExpectedToFailTest(const std::string& op_name) {
+ Reset();
+ AttrValue dtype;
+ dtype.set_type(tf_type_);
+
+ const auto& node_def = MakeNodeDef(
+ "my_oper", op_name, {"weights1", "weights2"}, {{"T", dtype}});
+ AddTestWeights("weights1", {1}, {1}, tf_type_);
+ AddTestWeights("weights2", {1}, {1}, tf_type_);
+ const string error =
+ "Constant folding is falled back to TensorFlow, "
+ "binary op '" +
+ op_name + "' received both input as constant";
+ RunValidationAndConversion(node_def, error::UNIMPLEMENTED, error);
+ }
+};
+
// Op converter test in FP32 mode. While for debugging purposes it might make
// sense to run over all possible combinations, normally a subset of them
// would be sufficient:
@@ -1825,14 +1896,15 @@
// how TRT handles the precision inside the TRT network, but should not matter
// for the TF -> TRT conversion. Therefore it should be sufficient to test
// for FP32.
-class OpConverter_FP32_Test : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_FP32_Test;
// Base class for tests that need to be tested for both FP32 and FP16.
-class OpConverter_FP32_FP16_Test : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_FP32_FP16_Test;
+// Test fixture alias for binary-op tests that must run for both FP32 and FP16.
+typedef OpConverter_BinaryTest<float> OpConverter_FP32_FP16_BinaryTest;
// Base class for tests that need to be tested for FP32, FP16, and INT32
-class OpConverter_FP32_FP16_INT32_Test
- : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_FP32_FP16_INT32_Test;
// Base class for tests that need to be tested for INT32
-class OpConverter_INT32_Test : public ParameterizedOpConverterTestBase {};
+typedef ParameterizedOpConverterTestBase OpConverter_INT32_Test;
// Instantiate parameter combinations to OpConverter_<DT_X...>_Test
INSTANTIATE_TEST_CASE_P(
@@ -1859,6 +1931,12 @@
::testing::Values(DT_INT32),
::testing::Values(TrtPrecisionMode::FP32)));
+INSTANTIATE_TEST_CASE_P(
+ OpConvTestInstantiation, OpConverter_FP32_FP16_BinaryTest,
+ ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
+ ::testing::Values(DT_FLOAT, DT_HALF),
+ ::testing::Values(TrtPrecisionMode::FP32)));
+
template <typename T>
void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
out->Clear();
@@ -3155,22 +3233,7 @@
return op.operation.node()->def();
}
-TEST_P(OpConverter_FP32_FP16_Test, ConvertBinary) {
- {
- AttrValue dtype;
- dtype.set_type(tf_type_);
- // Both inputs are weights.
- Reset();
- NodeDef node_def =
- MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}});
- AddTestWeights<float>("weights1", {1}, {1});
- AddTestWeights<float>("weights2", {1}, {1});
- RunValidationAndConversion(
- node_def, error::UNIMPLEMENTED,
- "Constant folding is falled back to TensorFlow, binary op received "
- "both input as constant");
- }
-
+TEST_P(OpConverter_FP32_FP16_BinaryTest, ConvertBinary) {
using OpFunc = std::function<NodeDef(DataType)>;
std::map<std::string, std::pair<OpFunc, std::vector<float>>> op_test_info;
#define ADD_OP(name, op, v1, v2, v3, v4, v5, v6, v7, v8) \
@@ -3188,39 +3251,7 @@
ADD_OP("Maximum", ops::Maximum, {3, 6, 3, 6, 3, 6, 3, 6});
ADD_OP("Pow", ops::Pow, {9, 36, 27, 216, 9, 36, 27, 216});
#undef ADD_OP
- // Test combinations of tensor vs weight inputs (except when both inputs are
- // weights).
- for (const bool operand_1_is_tensor : {true, false}) {
- for (const bool operand_2_is_tensor : {true, false}) {
- if (!operand_1_is_tensor && !operand_2_is_tensor) continue;
- for (auto& iter : kBinaryOperations) {
- string op_name = iter.first;
- SCOPED_TRACE(StrCat(op_name, "_", operand_1_is_tensor ? "T" : "W",
- operand_2_is_tensor ? "T" : "W"));
- Reset();
- if (!op_test_info.count(op_name)) {
- FAIL() << "Binary op test map does not contain op " << op_name;
- }
- NodeDef node_def = op_test_info[op_name].first(tf_type_);
- std::vector<std::string> input_names;
- std::vector<std::vector<int>> input_dims;
- std::vector<std::vector<float>> input_values;
- if (operand_1_is_tensor) {
- AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6});
- } else {
- AddTestWeights("input1", {1, 2}, std::vector<float>{3, 6}, tf_type_);
- }
- if (operand_2_is_tensor) {
- AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3});
- } else {
- AddTestWeights("input2", {2, 1}, std::vector<float>{2, 3}, tf_type_);
- }
- TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(),
- Status::OK(),
- ElementsAreArray(op_test_info[op_name].second));
- }
- }
- }
+ RunTests(*BinaryOperationMap(), op_test_info, get_tf_type());
}
NodeDef GetAddNNodeDef(const std::vector<string>& input_names, DataType dtype) {
@@ -4041,10 +4072,10 @@
}
Reset();
NodeDef node_def = op_map[op_name].first(tf_type_);
+ // std::exp in Softplus will overflow for input > 88.
const std::vector<float> input = {-100, -2, -1, 0, 1, 88};
AddTestTensor("input", p.input_dims, input);
- // std::exp in Softplus will overflow for input > 88
std::vector<float> output_values;
std::transform(input.begin(), input.end(),
std::back_inserter(output_values), op_map[op_name].second);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/binary_ops.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/binary_ops.cc
new file mode 100644
index 0000000..7d03949
--- /dev/null
+++ b/tensorflow/compiler/tf2tensorrt/convert/ops/binary_ops.cc
@@ -0,0 +1,128 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+
+#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+const binaryOperationMap *BinaryOperationMap() {
+ static auto *const m = new binaryOperationMap({
+ {"Add", nvinfer1::ElementWiseOperation::kSUM},
+ {"AddV2", nvinfer1::ElementWiseOperation::kSUM},
+ {"Mul", nvinfer1::ElementWiseOperation::kPROD},
+ {"Sub", nvinfer1::ElementWiseOperation::kSUB},
+ {"Div", nvinfer1::ElementWiseOperation::kDIV},
+ {"FloorDiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
+ {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
+ {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
+ {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
+ {"Pow", nvinfer1::ElementWiseOperation::kPOW},
+ });
+ return m;
+}
+
+class ConvertBinaryImpl {
+ protected:
+ ConvertBinaryImpl(const binaryOperationMap *pOperMap) : pOperMap_(pOperMap) {}
+
+ Status ImplValidate(const OpConverterParams ¶ms) {
+ const auto &node_def = params.node_def;
+ const auto op = node_def.op();
+ const auto op_pair = pOperMap_->find(op);
+ if (op_pair == pOperMap_->end()) {
+ return errors::Unimplemented("Binary op: ", op, " not supported");
+ }
+
+ // Constant folding should have been done by TensorFlow.
+ const auto &inputs = params.inputs;
+ if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
+ return errors::Unimplemented(
+ "Constant folding is falled back to TensorFlow, binary op '", op,
+ "' received both input as constant");
+ }
+
+ nvinfer1::Dims broadcasted_dims[2];
+ TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
+ inputs.at(0), inputs.at(1), true, params.use_implicit_batch,
+ broadcasted_dims, broadcasted_dims + 1));
+
+ for (int i = 0; i < 2; i++) {
+ tensor_[i] = nullptr;
+ // This will also convert constants to tensors.
+ TF_RETURN_IF_ERROR(PrepareTensorForShape(
+ params.converter, inputs.at(i), broadcasted_dims[i],
+ params.validation_only, tensor_ + i, node_def, i));
+ }
+ operation_ = op_pair->second;
+ return Status::OK();
+ }
+
+ Status ImplConvert(const OpConverterParams ¶ms) {
+ const auto &node_def = params.node_def;
+ // Add ElementWise layer.
+ nvinfer1::ILayer *layer = params.converter->network()->addElementWise(
+ *tensor_[0]->trt_tensor(), *tensor_[1]->trt_tensor(), operation_);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+
+ if (params.use_explicit_precision) {
+ layer->setPrecision(nvinfer1::DataType::kFLOAT);
+ }
+
+ params.converter->SetLayerName(layer, node_def);
+ params.outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
+ return Status::OK();
+ }
+ static constexpr std::array<InputArgSpec, 2> InputSpec() {
+ return std::array<InputArgSpec, 2>{
+ InputArgSpec::Create("x", TrtInputArg::kBoth),
+ InputArgSpec::Create("y", TrtInputArg::kBoth)};
+ }
+
+ private:
+ const binaryOperationMap *pOperMap_;
+ ITensorProxyPtr tensor_[2];
+ nvinfer1::ElementWiseOperation operation_;
+};
+
+class ConvertBinary : public OpConverterBase<ConvertBinary>,
+ protected ConvertBinaryImpl {
+ public:
+ explicit ConvertBinary(OpConverterParams *params)
+ : OpConverterBase<ConvertBinary>(params),
+ ConvertBinaryImpl(BinaryOperationMap()) {}
+
+ static constexpr std::array<DataType, 3> AllowedDataTypes() {
+ return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
+ }
+
+ static constexpr std::array<InputArgSpec, 2> InputSpec() {
+ return ConvertBinaryImpl::InputSpec();
+ }
+
+ Status Validate() { return ImplValidate(*params_); }
+ Status Convert() { return ImplConvert(*params_); }
+};
+
+REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertBinary>(),
+ GetOperationNames(*BinaryOperationMap()));
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT