TF-TRT: Create execution context with its own device memory when needed (TRT < 8.0 int32-output workaround)
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 8b5f90a..7e2da11 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -859,8 +859,9 @@
   // for it.
   mutex_lock lock(engine_context->mu);
   nvinfer1::IExecutionContext* execution_context;
-  TF_RETURN_IF_ERROR(
-      engine_context->GetExecutionContext(trt_context_idx, &execution_context));
+  bool has_device_memory;
+  TF_RETURN_IF_ERROR(engine_context->GetExecutionContext(
+      trt_context_idx, &execution_context, &has_device_memory));
 
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "Selected execution context: " << trt_context_idx;
@@ -885,11 +886,12 @@
                                                 ->GpuStreamMemberHack()));
 
   ContextDeviceMemory context_device_memory;
-  // Allocate device memory for the TensorRT engine execution. The device
-  // memory will be released when context_device_memory goes out of scope.
-  TF_RETURN_IF_ERROR(
-      context_device_memory.AllocateDeviceMemory(execution_context, allocator));
-
+  if (!has_device_memory) {
+    // Allocate device memory for the TensorRT engine execution. The device
+    // memory will be released when context_device_memory goes out of scope.
+    TF_RETURN_IF_ERROR(context_device_memory.AllocateDeviceMemory(
+        execution_context, allocator));
+  }
   // Enqueue the TensorRT engine for execution.
   return TrtEnqueue(execution_context, buffers, *stream, use_implicit_batch_,
                     num_batch);
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc
index c9437f7..e058043 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc
@@ -37,9 +37,24 @@
 using absl::StrCat;
 
 ExecutionContext ExecutionContext::Create(nvinfer1::ICudaEngine* cuda_engine) {
+  bool has_int32_output = false;
+  for (int i = 0; i < cuda_engine->getNbBindings(); i++) {
+    if (!cuda_engine->bindingIsInput(i) &&
+        cuda_engine->getBindingDataType(i) == nvinfer1::DataType::kINT32) {
+      has_int32_output = true;
+      break;
+    }
+  }
+  if (!IS_TRT_VERSION_GE(8, 0, 0, 0) && has_int32_output) {
+    // TODO(nvbugs/3390469): Remove this workaround when the bug is fixed.
+    nvinfer1::IExecutionContext* execution_context =
+        cuda_engine->createExecutionContext();
+    return ExecutionContext(execution_context, true);
+  }
+
   nvinfer1::IExecutionContext* execution_context =
       cuda_engine->createExecutionContextWithoutDeviceMemory();
-  return ExecutionContext(execution_context);
+  return ExecutionContext(execution_context, false);
 }
 
 Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h b/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h
index 4ab8010..05b5cef 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h
@@ -26,9 +26,15 @@
 // execution context when the object goes out of scope.
 class ExecutionContext : public TrtUniquePtrType<nvinfer1::IExecutionContext> {
  public:
-  ExecutionContext(nvinfer1::IExecutionContext* context)
-      : TrtUniquePtrType<nvinfer1::IExecutionContext>(context) {}
+  ExecutionContext(nvinfer1::IExecutionContext* context, bool has_memory)
+      : TrtUniquePtrType<nvinfer1::IExecutionContext>(context),
+        has_device_memory_(has_memory) {}
   static ExecutionContext Create(nvinfer1::ICudaEngine* cuda_engine);
+
+  bool HasDeviceMemory() const { return has_device_memory_; }
+
+ private:
+  bool has_device_memory_;
 };
 
 };  // namespace tensorrt
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
index 96ecfe3..758fe9a 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
@@ -133,7 +133,8 @@
   mutex mu;
   TrtUniquePtrType<nvinfer1::ICudaEngine> cuda_engine;
 
-  Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx)
+  Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx,
+                             bool* has_device_memory)
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
     if (idx >= execution_contexts.size()) {
       return errors::Internal("Requested engine context with index ", idx,
@@ -141,6 +142,7 @@
                               "contexts are present.");
     }
     *exec_ctx = execution_contexts[idx].get();
+    *has_device_memory = execution_contexts[idx].HasDeviceMemory();
     return Status::OK();
   }
 
diff --git a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py
index 33bd527..eff1eb3 100644
--- a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py
@@ -21,6 +21,7 @@
 from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
@@ -64,5 +65,54 @@
       return ["TRTEngineOp_0"]
 
 
+class ShapeOutputWithSingleInputProfile(ShapeOutputTest):
+  """Same as the previous test, but with a single input profile."""
+
+  def setUp(self):
+    super().setUp()
+    self.DisableNonTrtOptimizers()
+
+
+  def GetParams(self):
+    return self.BuildParamsWithMask(
+        self.GraphFn, dtypes.float32,
+        [[2, 2, 5, 3]], [[4]],
+        extra_inputs=[],
+        extra_outputs=[],
+        input_mask=[[False, True, True, True]],
+        output_mask=[[True]])
+
+
+class ShapeOutputWithSingleInputAndReshape(trt_test.TfTrtIntegrationTestBase):
+  """Similar to the previous test, but the ShapeOp output is reshaped to 2D,
+  which is not compatible with shape tensors."""
+
+  def setUp(self):
+    super().setUp()
+    self.DisableNonTrtOptimizers()
+
+  def GraphFn(self, x):
+    q = 2 * x + 1
+    q = array_ops.shape(q)
+    q = gen_array_ops.reshape(q, [2, 2])
+    q = math_ops.cast(q, dtypes.float32)
+    q = self.trt_incompatible_op(q)
+    q = q * 2 + q * q
+    return array_ops.identity(q, name="output_0")
+
+  def GetParams(self):
+    return self.BuildParamsWithMask(
+        self.GraphFn, dtypes.float32,
+        [[2, 2, 5, 3]], [[2, 2]],
+        extra_inputs=[],
+        extra_outputs=[],
+        input_mask=[[False, True, True, True]],
+        output_mask=[[True, True]])
+
+  def ExpectedEnginesToBuild(self, run_params):
+    """Returns the expected engines to build."""
+    return ["TRTEngineOp_0", "TRTEngineOp_1"]
+
+
 if __name__ == "__main__":
   test.main()