TF-TRT: Create execution context with device memory when needed

Starting with TensorRT 8.0, engines that have INT32 outputs cannot be
executed from a context created with
createExecutionContextWithoutDeviceMemory() (nvbugs/3390469). As a
workaround, ExecutionContext::Create() now falls back to
createExecutionContext(), which lets TensorRT allocate the device memory
internally, whenever the engine has an INT32 output and the TensorRT
version is at least 8.0. ExecutionContext records whether it already owns
device memory, and TRTEngineOp skips the per-call scratch allocation for
contexts that do. Two tests whose engines produce shape (INT32) outputs
are added to exercise this path.
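For background, the two TensorRT context-creation paths this change
chooses between look roughly as follows (a minimal sketch, not part of
this patch; assumes a built nvinfer1::ICudaEngine* engine and the CUDA
runtime headers):

  // Context that owns its device memory: TensorRT allocates the scratch
  // space needed by enqueue() internally when the context is created.
  nvinfer1::IExecutionContext* with_memory = engine->createExecutionContext();

  // Context without device memory: the caller must attach a scratch buffer
  // of at least engine->getDeviceMemorySize() bytes before enqueueing. In
  // TF-TRT this is what ContextDeviceMemory::AllocateDeviceMemory() does.
  nvinfer1::IExecutionContext* without_memory =
      engine->createExecutionContextWithoutDeviceMemory();
  void* scratch = nullptr;
  cudaMalloc(&scratch, engine->getDeviceMemorySize());
  without_memory->setDeviceMemory(scratch);

With this change, only contexts created through the second path cause
TRTEngineOp to allocate scratch memory before enqueueing.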
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 8b5f90a..7e2da11 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -859,8 +859,9 @@
// for it.
mutex_lock lock(engine_context->mu);
nvinfer1::IExecutionContext* execution_context;
- TF_RETURN_IF_ERROR(
- engine_context->GetExecutionContext(trt_context_idx, &execution_context));
+ bool has_device_memory;
+ TF_RETURN_IF_ERROR(engine_context->GetExecutionContext(
+ trt_context_idx, &execution_context, &has_device_memory));
if (VLOG_IS_ON(2)) {
VLOG(2) << "Selected execution context: " << trt_context_idx;
}
@@ -885,11 +886,12 @@
->GpuStreamMemberHack()));
ContextDeviceMemory context_device_memory;
- // Allocate device memory for the TensorRT engine execution. The device
- // memory will be released when context_device_memory goes out of scope.
- TF_RETURN_IF_ERROR(
- context_device_memory.AllocateDeviceMemory(execution_context, allocator));
-
+ if (!has_device_memory) {
+ // Allocate device memory for the TensorRT engine execution. The device
+ // memory will be released when context_device_memory goes out of scope.
+ TF_RETURN_IF_ERROR(context_device_memory.AllocateDeviceMemory(
+ execution_context, allocator));
+ }
// Enqueue the TensorRT engine for execution.
return TrtEnqueue(execution_context, buffers, *stream, use_implicit_batch_,
num_batch);
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc
index c9437f7..e058043 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc
@@ -37,9 +37,24 @@
using absl::StrCat;
ExecutionContext ExecutionContext::Create(nvinfer1::ICudaEngine* cuda_engine) {
+ bool has_int32_output = false;
+ for (int i = 0; i < cuda_engine->getNbBindings(); i++) {
+ if (!cuda_engine->bindingIsInput(i) &&
+ cuda_engine->getBindingDataType(i) == nvinfer1::DataType::kINT32) {
+ has_int32_output = true;
+ break;
+ }
+ }
+  if (IS_TRT_VERSION_GE(8, 0, 0, 0) && has_int32_output) {
+ // TODO(nvbugs/3390469): Remove this workaround when the bug is fixed.
+ nvinfer1::IExecutionContext* execution_context =
+ cuda_engine->createExecutionContext();
+ return ExecutionContext(execution_context, true);
+ }
+
nvinfer1::IExecutionContext* execution_context =
cuda_engine->createExecutionContextWithoutDeviceMemory();
- return ExecutionContext(execution_context);
+ return ExecutionContext(execution_context, false);
}
Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h b/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h
index 4ab8010..05b5cef 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h
@@ -26,9 +26,15 @@
// execution context when the object goes out of scope.
class ExecutionContext : public TrtUniquePtrType<nvinfer1::IExecutionContext> {
public:
- ExecutionContext(nvinfer1::IExecutionContext* context)
- : TrtUniquePtrType<nvinfer1::IExecutionContext>(context) {}
+ ExecutionContext(nvinfer1::IExecutionContext* context, bool has_memory)
+ : TrtUniquePtrType<nvinfer1::IExecutionContext>(context),
+ has_device_memory_(has_memory) {}
static ExecutionContext Create(nvinfer1::ICudaEngine* cuda_engine);
+
+ bool HasDeviceMemory() { return has_device_memory_; }
+
+ private:
+ bool has_device_memory_;
};
}; // namespace tensorrt
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
index 96ecfe3..758fe9a 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
@@ -133,7 +133,8 @@
mutex mu;
TrtUniquePtrType<nvinfer1::ICudaEngine> cuda_engine;
- Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx)
+ Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx,
+ bool* has_device_memory)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
if (idx >= execution_contexts.size()) {
return errors::Internal("Requested engine context with index ", idx,
@@ -141,6 +142,7 @@
"contexts are present.");
}
*exec_ctx = execution_contexts[idx].get();
+ *has_device_memory = execution_contexts[idx].HasDeviceMemory();
return Status::OK();
}
diff --git a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py
index 33bd527..eff1eb3 100644
--- a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py
@@ -21,6 +21,7 @@
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -64,5 +65,54 @@
return ["TRTEngineOp_0"]
+class ShapeOutputWithSingleInputProfile(ShapeOutputTest):
+  """Same as the previous test, but with a single input profile."""
+
+ def setUp(self):
+ super().setUp()
+ self.DisableNonTrtOptimizers()
+
+ def GetParams(self):
+ return self.BuildParamsWithMask(
+ self.GraphFn, dtypes.float32,
+ [[2, 2, 5, 3]], [[4]],
+ extra_inputs=[],
+ extra_outputs=[],
+ input_mask=[[False, True, True, True]],
+ output_mask=[[True]])
+
+
+class ShapeOutputWithSingleInputAndReshape(trt_test.TfTrtIntegrationTestBase):
+  """Similar to the previous test, but the ShapeOp output is reshaped to 2D,
+  which is not compatible with shape tensors."""
+
+ def setUp(self):
+ super().setUp()
+ self.DisableNonTrtOptimizers()
+
+ def GraphFn(self, x):
+ q = 2 * x + 1
+ q = array_ops.shape(q)
+ q = gen_array_ops.reshape(q, [2, 2])
+ q = math_ops.cast(q, dtypes.float32)
+ q = self.trt_incompatible_op(q)
+ q = q * 2 + q * q
+ return array_ops.identity(q, name="output_0")
+
+ def GetParams(self):
+ return self.BuildParamsWithMask(
+ self.GraphFn, dtypes.float32,
+ [[2, 2, 5, 3]], [[2, 2]],
+ extra_inputs=[],
+ extra_outputs=[],
+ input_mask=[[False, True, True, True]],
+ output_mask=[[True, True]])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Returns the expected engines to build."""
+ return ["TRTEngineOp_0", "TRTEngineOp_1"]
+
+
if __name__ == "__main__":
test.main()