TF-TRT: Support constant indices for the GatherV2 op
Convert the 'indices' input of a GatherV2 op into a constant
layer when it is a constant (weight).
Extend the unit tests to cover both tensor and constant indices.
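
For reference, the effect at the TensorRT level is that constant indices are
now materialized through a constant layer whose output tensor is fed to
IGatherLayer. A minimal standalone TensorRT C++ sketch of that pattern is
shown below; the function name, shapes, and values are illustrative only and
are not part of this change.

    #include <cstdint>
    #include "NvInfer.h"

    // Build a gather whose indices come from a constant layer, roughly the
    // layer pattern TF-TRT emits for a GatherV2 node with constant indices.
    nvinfer1::IGatherLayer* AddGatherWithConstIndices(
        nvinfer1::INetworkDefinition* network) {
      // Runtime input "params" of shape [2, 3].
      nvinfer1::ITensor* params = network->addInput(
          "params", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2{2, 3});

      // Constant indices {0, 2}; the buffer must outlive engine building.
      static const int32_t kIndices[] = {0, 2};
      nvinfer1::Weights indices_weights{nvinfer1::DataType::kINT32, kIndices,
                                        2};
      nvinfer1::Dims indices_dims{};
      indices_dims.nbDims = 1;
      indices_dims.d[0] = 2;
      nvinfer1::IConstantLayer* indices_const =
          network->addConstant(indices_dims, indices_weights);

      // Gather along axis 1 of "params" using the constant indices tensor.
      return network->addGather(*params, *indices_const->getOutput(0),
                                /*axis=*/1);
    }
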
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 456696c..8d72b64 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -4767,7 +4767,7 @@
// option for an input to be either tensor or weight.
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"params", TrtInputArg::kBoth},
- {"indices", TrtInputArg::kTensor},
+ {"indices", TrtInputArg::kBoth},
{"axis", TrtInputArg::kWeight}}));
const auto& params_input = inputs.at(0);
@@ -4794,20 +4794,24 @@
return errors::Unimplemented(
"The input axis must be zero when params is a weight.");
}
- if (params->use_implicit_batch && params_input.is_tensor() &&
- indices_input.batch_size() != 1) {
+ if (params->use_implicit_batch &&
+ (params_input.is_tensor() == indices_input.is_tensor()) &&
+ (indices_input.batch_size() != 1 || params_input.batch_size() != 1)) {
return errors::Unimplemented(
- "Indices must have a batch size of 1 when params is a tensor.");
+ "Params and indices must have a batch size of 1 when params and indices"
+ " are both tensors or both constants.");
}
+
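+  // Rank of the corresponding TF tensor: TRT tensors gain one extra dim for
+  // the implicit batch in implicit batch mode, while weights do not.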
+ auto get_rank = [params](const auto& input) {
+ return input.GetTrtDims().nbDims +
+ (params->use_implicit_batch && input.is_tensor() ? 1 : 0);
+ };
// Both input are tensors, and the TF gather result will have rank:
// (params.nbDims + 1) + (indices.nbDims + 1) - 1,
// where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
// the TF rank so we don't have to add + 1.
- const int params_tf_rank =
- params_input.GetTrtDims().nbDims +
- (params->use_implicit_batch && params_input.is_tensor() ? 1 : 0);
- const int indices_tf_rank =
- indices_input.GetTrtDims().nbDims + (params->use_implicit_batch ? 1 : 0);
+ const int params_tf_rank = get_rank(params_input);
+ const int indices_tf_rank = get_rank(indices_input);
const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
if (tf_gather_output_rank >
nvinfer1::Dims::MAX_DIMS + (params->use_implicit_batch ? 1 : 0)) {
@@ -4817,14 +4821,25 @@
}
if (params->validation_only) return Status::OK();
- // Convert params to tensor is it is a weight.
- ITensorProxyPtr params_tensor = nullptr;
- if (params_input.is_weights()) {
- params_tensor = params->converter->CreateConstantLayer(
- params_input.weights(), params_input.GetTrtDims());
- } else {
- params_tensor = params_input.tensor();
- }
+
+  // Convert params or indices to a tensor if it is a constant (weight).
+  auto populate_tensor = [params](const auto& input) -> ITensorProxyPtr {
+    ITensorProxyPtr result_tensor = nullptr;
+    if (input.is_weights()) {
+      result_tensor = params->converter->CreateConstantLayer(
+          input.weights(), input.GetTrtDims());
+    } else {
+      result_tensor = input.tensor();
+    }
+    return result_tensor;
+  };
+
+ ITensorProxyPtr params_tensor = populate_tensor(params_input);
+ ITensorProxyPtr indices_tensor = populate_tensor(indices_input);
// Note on how IGatherLayer works: if both the data and indices tensors have
// a batch size dimension of size N, it performs:
@@ -4832,28 +4847,35 @@
// output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
// data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
- *params_tensor->trt_tensor(), *indices_input.tensor()->trt_tensor(),
+ *params_tensor->trt_tensor(), *indices_tensor->trt_tensor(),
trt_axis);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
params->converter->SetLayerName(layer, node_def);
ITensorProxyPtr output_tensor = layer->getOutput(0);
nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions();
- // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
- // and the other is for the output dimension that is squeezed by IGatherLayer
- // because of the implicit batch dim in the indices (see the above note).
- const int expected_trt_output_rank =
- tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1);
- if (params->use_implicit_batch &&
- trt_gather_output_dims.nbDims != expected_trt_output_rank) {
- return errors::Internal(
- "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
- expected_trt_output_rank,
- ", actual nbDims: ", trt_gather_output_dims.nbDims);
+
+ if (params->use_implicit_batch) {
+    // Note on the rank reduction: one dimension is dropped for the output
+    // batch dim encapsulated by TF-TRT when params is a tensor, and another
+    // for the output dimension that IGatherLayer squeezes because of the
+    // implicit batch dim in the indices when indices is a tensor (see the
+    // above note).
+ const int expected_trt_output_rank = tf_gather_output_rank -
+ (params_input.is_tensor() ? 1 : 0) -
+ (indices_input.is_tensor() ? 1 : 0);
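+    // For example, if params is a TRT tensor of rank 3 (TF rank 4 with the
+    // implicit batch dim) and indices is a TRT tensor of rank 2 (TF rank 3),
+    // then tf_gather_output_rank is 4 + 3 - 1 = 6 while IGatherLayer outputs
+    // 3 + 2 - 1 = 4 dims, matching 6 - 1 - 1 = 4.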
+
+ if (trt_gather_output_dims.nbDims != expected_trt_output_rank) {
+ return errors::Internal(
+ "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
+ expected_trt_output_rank,
+ ", actual nbDims: ", trt_gather_output_dims.nbDims);
+ }
}
// Reshape the output so after adding the implicit batch dim it'll match the
// output shape of TF GatherV2.
- if (params->use_implicit_batch && params_input.is_tensor()) {
+ if (params->use_implicit_batch &&
+ params_input.is_tensor() &&
+ indices_input.is_tensor()) {
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
}
@@ -4866,6 +4888,26 @@
/*validation_only=*/false, &output_tensor, node_def));
}
+  // When params and indices are both constants, squeeze the implicit batch
+  // dim out of the output so that, once TF-TRT adds the implicit batch dim
+  // back, the result matches the output shape of the TF GatherV2 op.
+ if (params->use_implicit_batch &&
+ params_input.is_weights() &&
+ indices_input.is_weights()) {
+ for (int i = trt_axis; i < trt_gather_output_dims.nbDims - 1; ++i) {
+ trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i + 1];
+ }
+
+    // Squeeze the implicit batch dimension out. Note: this works only when
+    // the batch size of both params and indices is 1.
+ --trt_gather_output_dims.nbDims;
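+    // For example, with trt_axis == 0 an IGatherLayer output of dims
+    // {1, 2, 3} becomes {2, 3}; TF-TRT then prepends the implicit batch dim
+    // of size 1 to recover the TF GatherV2 output shape {1, 2, 3}.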
+
+ TF_RETURN_IF_ERROR(PrepareTensorForShape(
+ params->converter, TRT_TensorOrWeights(output_tensor),
+ trt_gather_output_dims,
+ /*validation_only=*/false, &output_tensor, node_def));
+ }
+
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 10fe1ec..12db165 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -5755,74 +5755,97 @@
std::vector<int> expected_output_shape;
std::vector<int> expected_output;
bool params_is_tensor;
- Status status;
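+    // True if "indices" is provided to the converter as a TRT tensor; false
+    // if it is provided as a constant weight.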
+    bool indices_is_tensor;
+    Status conversion_status;
Status runtime_status;
Status add_index_status;
};
// Input is the same {1, 2, 3, 4, 5, 6} for all cases.
const std::vector<int> params_input = {1, 2, 3, 4, 5, 6};
+
std::vector<TestParams> test_params = {
// Axis is batch dimension, should fail in implicit batch mode.
- TestParams{/*params_shape=*/{2, 1, 1, 3},
- /*indices_shape=*/{2},
- /*indices=*/{1, 0},
- /*axis=*/0,
- /*expected_output_shape=*/{2, 1, 1, 3},
- /*expected_output=*/{4, 5, 6, 1, 2, 3},
- /*params_is_tensor=*/true,
- trt_mode_ == TrtTestMode::kImplicitBatch
- ? Status{error::UNIMPLEMENTED,
- "TensorRT does not allow manipulation of the"
- " batch dimension, at my_gather"}
- : Status::OK()},
- // Batch size of indices is not 1 when params is a tensor.
- TestParams{/*params_shape=*/{2, 1, 3},
- /*indices_shape=*/{2, 1},
- /*indices=*/{2, 0},
- /*axis=*/2,
- /*expected_output_shape=*/{2, 1, 2, 1},
- /*expected_output=*/{3, 1, 6, 4},
- /*params_is_tensor=*/true,
- trt_mode_ == TrtTestMode::kImplicitBatch
- ? Status{error::UNIMPLEMENTED,
- "Indices must have a batch size of 1 when params"
- " is a tensor."}
- : Status::OK()},
+ TestParams{
+ /*params_shape=*/{2, 1, 1, 3},
+ /*indices_shape=*/{2},
+ /*indices=*/{1, 0},
+ /*axis=*/0,
+ /*expected_output_shape=*/{2, 1, 1, 3},
+ /*expected_output=*/{4, 5, 6, 1, 2, 3},
+ /*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
+ /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+ ? Status{error::UNIMPLEMENTED, "TensorRT does not allow "
+ "manipulation of the batch dimension, at my_gather"}
+ : Status::OK()
+ },
+ // Batch size of indices is not 1 when params and indices are tensors.
+ TestParams{
+ /*params_shape=*/{2, 1, 3},
+ /*indices_shape=*/{2, 1},
+ /*indices=*/{2, 0},
+ /*axis=*/2,
+ /*expected_output_shape=*/{2, 1, 2, 1},
+ /*expected_output=*/{3, 1, 6, 4},
+ /*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
+ /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+ ? Status{error::UNIMPLEMENTED, "Params and indices must have a"
+ " batch size of 1 when params and indices are both tensors or both"
+ " constants."}
+ : Status::OK()
+ },
+      // Batch size of indices is not 1 when params is a tensor and indices
+      // is a constant.
+ TestParams{
+ /*params_shape=*/{2, 1, 3},
+ /*indices_shape=*/{2, 1},
+ /*indices=*/{2, 0},
+ /*axis=*/2,
+ /*expected_output_shape=*/{2, 1, 2, 1},
+ /*expected_output=*/{3, 1, 6, 4},
+ /*params_is_tensor=*/true,
+ /*indices_is_tensor=*/false,
+ /*conversion_status=*/Status::OK()
+ },
// Axis is not zero when params is a weight, should fail in implicit batch
// mode.
- TestParams{/*params_shape=*/{2, 1, 3},
- /*indices_shape=*/{2},
- /*indices=*/{1, 2},
- /*axis=*/2,
- /*expected_output_shape=*/{2, 1, 2},
- /*expected_output=*/{2, 3, 5, 6},
- /*params_is_tensor=*/false,
- trt_mode_ == TrtTestMode::kImplicitBatch
- ? Status{error::UNIMPLEMENTED,
- "The input axis must be zero when params is a"
- " weight."}
- : Status::OK()},
+ TestParams{
+ /*params_shape=*/{2, 1, 3},
+ /*indices_shape=*/{2},
+ /*indices=*/{1, 2},
+ /*axis=*/2,
+ /*expected_output_shape=*/{2, 1, 2},
+ /*expected_output=*/{2, 3, 5, 6},
+ /*params_is_tensor=*/false,
+ /*indices_is_tensor=*/true,
+ /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+ ? Status{error::UNIMPLEMENTED, "The input axis must be zero when "
+ "params is a weight."}
+ : Status::OK()
+ },
// Params with only batch dimension.
- TestParams{/*params_shape=*/{6},
- /*indices_shape=*/{2},
- /*indices=*/{1, 3},
- /*axis=*/0,
- /*expected_output_shape=*/{2},
- /*expected_output=*/{2, 4},
- /*params_is_tensor=*/true,
- trt_mode_ == TrtTestMode::kImplicitBatch // conversion_status
- ? Status{error::UNIMPLEMENTED,
- "TensorRT does not allow manipulation of the "
- "batch dimension, at my_gather"}
- : Status::OK(),
- Status::OK(), // runtime_status
- trt_mode_ == TrtTestMode::kImplicitBatch // add_index_status
- ? Status{error::INVALID_ARGUMENT,
- "Batch size doesn't match for tensor indices: "
- "Provided batch size does not match converter "
- "batch size: 2 vs 6"}
- : Status::OK()},
+ TestParams{
+ /*params_shape=*/{6},
+ /*indices_shape=*/{2},
+ /*indices=*/{1, 3},
+ /*axis=*/0,
+ /*expected_output_shape=*/{2},
+ /*expected_output=*/{2, 4},
+ /*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
+ /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+ ? Status{error::UNIMPLEMENTED, "TensorRT does not allow "
+ "manipulation of the batch dimension, at my_gather"}
+ : Status::OK(),
+ /*runtime_status=*/Status::OK(),
+ /*add_index_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+ ? Status{error::INVALID_ARGUMENT, "Batch size doesn't match for "
+ "tensor indices: Provided batch size does not match "
+ "converter batch size: 2 vs 6"}
+ : Status::OK()
+ },
// Vector indices, and output rank is rank(params).
TestParams{
/*params_shape=*/{1, 1, 2, 3},
@@ -5832,6 +5855,7 @@
/*expected_output_shape=*/{1, 1, 2, 1},
/*expected_output=*/{1, 4},
/*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{1, 1, 2, 3},
@@ -5841,6 +5865,7 @@
/*expected_output_shape=*/{1, 1, 1, 3},
/*expected_output=*/{4, 5, 6},
/*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
},
// Indices with rank>1, and output rank is rank(params) + rank(indices) -
// 1
@@ -5852,6 +5877,7 @@
/*expected_output_shape=*/{1, 1, 2, 1, 1},
/*expected_output=*/{1, 4},
/*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{1, 1, 2, 3},
@@ -5861,6 +5887,7 @@
/*expected_output_shape=*/{1, 1, 2, 1, 1},
/*expected_output=*/{2, 5},
/*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{1, 1, 2, 3},
@@ -5870,6 +5897,7 @@
/*expected_output_shape=*/{1, 1, 2, 1, 1},
/*expected_output=*/{3, 6},
/*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{1, 1, 2, 3},
@@ -5879,6 +5907,7 @@
/*expected_output_shape=*/{1, 1, 2, 1, 3},
/*expected_output=*/{3, 1, 2, 6, 4, 5},
/*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{1, 3, 2},
@@ -5888,6 +5917,7 @@
/*expected_output_shape=*/{1, 3, 1, 2, 2},
/*expected_output=*/{1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5},
/*params_is_tensor=*/true,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{1, 2, 3},
@@ -5897,6 +5927,7 @@
/*expected_output_shape=*/{1, 2, 3},
/*expected_output=*/{1, 2, 3, 4, 5, 6},
/*params_is_tensor=*/false,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{3, 2},
@@ -5906,6 +5937,7 @@
/*expected_output_shape=*/{1, 2, 2},
/*expected_output=*/{1, 2, 3, 4},
/*params_is_tensor=*/false,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{2, 3},
@@ -5915,6 +5947,7 @@
/*expected_output_shape=*/{1, 1, 2, 3},
/*expected_output=*/{1, 2, 3, 4, 5, 6},
/*params_is_tensor=*/false,
+ /*indices_is_tensor=*/true,
},
TestParams{
/*params_shape=*/{3, 2},
@@ -5924,22 +5957,75 @@
/*expected_output_shape=*/{2, 2, 2},
/*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2},
/*params_is_tensor=*/false,
+ /*indices_is_tensor=*/true,
+ },
+      // Test case in which indices are constant.
+ TestParams{
+ /*params_shape=*/{1, 1, 2, 3},
+ /*indices_shape=*/{1, 1},
+ /*indices=*/{0},
+ /*axis=*/3,
+ /*expected_output_shape=*/{1, 1, 2, 1, 1},
+ /*expected_output=*/{1, 4},
+ /*params_is_tensor=*/true,
+ /*indices_is_tensor=*/false,
+ },
+      // Test cases in which both params and indices are constants.
+ TestParams{
+ /*params_shape=*/{1, 2, 3},
+ /*indices_shape=*/{1},
+ /*indices=*/{0},
+ /*axis=*/0,
+ /*expected_output_shape=*/{1, 2, 3},
+ /*expected_output=*/{1, 2, 3, 4, 5, 6},
+ /*params_is_tensor=*/false,
+ /*indices_is_tensor=*/false,
+ /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+ ? Status{error::UNIMPLEMENTED, "Params and indices must have a"
+ " batch size of 1 when params and indices are both tensors or both"
+ " constants."}
+ : Status::OK()
+ },
+ TestParams{
+ /*params_shape=*/{3, 2},
+ /*indices_shape=*/{2, 2},
+ /*indices=*/{0, 2, 1, 0},
+ /*axis=*/0,
+ /*expected_output_shape=*/{2, 2, 2},
+ /*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2},
+ /*params_is_tensor=*/false,
+ /*indices_is_tensor=*/false,
+ /*conversion_status=*/trt_mode_ == TrtTestMode::kImplicitBatch
+ ? Status{error::UNIMPLEMENTED, "Params and indices must have a"
+ " batch size of 1 when params and indices are both tensors or both"
+ " constants."}
+ : Status::OK()
},
};
for (auto p : test_params) {
- Reset();
- if (p.params_is_tensor) {
- AddTestTensor("params", p.params_shape, params_input);
- } else {
- AddTestWeights("params", p.params_shape, params_input, tf_type_);
-    }
+    Reset();
+
+    if (p.params_is_tensor) {
+      AddTestTensor("params", p.params_shape, params_input);
+    } else {
+      AddTestWeights("params", p.params_shape, params_input, tf_type_);
+    }
+
+    if (p.indices_is_tensor) {
+      AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {},
+                    p.add_index_status);
+    } else {
+      AddTestWeights("indices", p.indices_shape, p.indices, DT_INT32);
+    }
+
+    AddTestWeights<int32>("axis", {1}, {p.axis});
+    TestOpConverter("my_gather", node_def, p.expected_output_shape,
+                    p.conversion_status, p.runtime_status,
+                    ElementsAreArray(p.expected_output));
+  }
- AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {},
- p.add_index_status);
- AddTestWeights<int32>("axis", {1}, {p.axis});
- TestOpConverter("my_gather", node_def, p.expected_output_shape, p.status,
- p.runtime_status, ElementsAreArray(p.expected_output));
- }
}
template <typename OpType>