Revert "Merge pull request #55655 from meena-at-work:meenakshiv/tftrt-specify-device-during-conversion"

[TFTRT]: Revert the device placement changes, as they cause issues on
graphs containing nodes that cannot run on GPUs.

This reverts commit 0842ad13ec13732c132f8b5bc8f6768adfdeef12, reversing
changes made to dd8983737842368dd4481a2e6ca5b6a66a1706c8.
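
The reverted change queried the currently active device during convert()
and pinned the imported frozen graph onto it. A minimal sketch of that
pattern, simplified from the removed code below (not the exact
implementation), is:

    import tensorflow as tf

    # Create an empty tensor and read its device to discover where the
    # conversion is running (mirrors the removed device-query logic).
    device = tf.zeros([]).device
    if not device:  # if no device is specified, default to GPU 0
        device = "gpu:0"
    if "gpu" not in device.lower():
        raise ValueError(f"Specified device is not a GPU: {device}")

    # The removed code then cleared per-node device assignments and
    # re-imported the whole frozen graph under tf.device(device), which
    # forces every node onto that GPU and breaks graphs that contain ops
    # without a GPU kernel -- hence this revert.
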
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index 962a531..ac1c18a 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -1131,7 +1131,6 @@
     self._calibration_input_fn = None
 
     self._converted = False
-    self._device = None
     self._build_called_once = False
     self._calibrated = False
 
@@ -1241,17 +1240,6 @@
     """
     assert not self._converted
 
-    if self._device is None:
-      # Creating an empty tensor to fetch queried device
-      self._device = array_ops.zeros([]).device
-      if not self._device:  # if device unspecified, defaulting to GPU 0
-        self._device = "gpu:0"
-      if "gpu" not in self._device.lower():
-        raise ValueError(f"Specified device is not a GPU: {self._device}")
-
-    logging.info(f"Placing imported graph from "
-                 f"`{self._input_saved_model_dir}` on device: {self._device}")
-
     if (self._need_calibration and not calibration_input_fn):
       raise ValueError("Should specify calibration_input_fn because INT8 "
                        "calibration is needed")
@@ -1263,50 +1251,39 @@
                                   self._input_saved_model_tags)
     func = self._saved_model.signatures[self._input_saved_model_signature_key]
     frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
+    grappler_meta_graph_def = saver.export_meta_graph(
+        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
 
-    frozen_graph_def = frozen_func.graph.as_graph_def()
+    # Add a collection 'train_op' so that Grappler knows the outputs.
+    fetch_collection = meta_graph_pb2.CollectionDef()
+    for array in frozen_func.inputs + frozen_func.outputs:
+      fetch_collection.node_list.value.append(array.name)
+    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+        fetch_collection)
 
-    # Clear any prior device assignments
-    for node in frozen_graph_def.node:
-      node.device = ""
-
-    with ops.Graph().as_default() as graph, ops.device(self._device):
-      importer.import_graph_def(frozen_graph_def, name="")
-      grappler_meta_graph_def = saver.export_meta_graph(
-          graph_def=graph.as_graph_def(), graph=graph)
-
-      # Add a collection 'train_op' so that Grappler knows the outputs.
-      fetch_collection = meta_graph_pb2.CollectionDef()
-      for array in frozen_func.inputs + frozen_func.outputs:
-        fetch_collection.node_list.value.append(array.name)
-      grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
-          fetch_collection)
-
-      # Run TRT optimizer in Grappler to convert the graph.
-      self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
-      # If a function is converted, then the TF context contains the original
-      # function while the converted_graph_def contains the converted function.
-      # Remove the original function from the TF context in this case.
-      for f in self._converted_graph_def.library.function:
-        while context.context().has_function(f.signature.name):
-          logging.info("Removing original function %s from the context",
-                       f.signature.name)
-          context.context().remove_function(f.signature.name)
-      # This also adds the converted functions to the context.
-      self._converted_func = wrap_function.function_from_graph_def(
-          self._converted_graph_def,
-          [tensor.name for tensor in frozen_func.inputs],
-          [tensor.name for tensor in frozen_func.outputs])
-      # Reconstruct the output signatures using the ones from original model.
-      self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
-          func.graph.structured_outputs,
-          self._converted_func.graph.structured_outputs)
-      # Copy structured input signature from original function (used during
-      # serialization)
-      self._converted_func.graph.structured_input_signature = (
-          func.structured_input_signature)
-
-    self._converted_func = self._rebuild_func(self._converted_func)
+    # Run TRT optimizer in Grappler to convert the graph.
+    self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
+    # If a function is converted, then the TF context contains the original
+    # function while the converted_graph_def contains the converted function.
+    # Remove the original function from the TF context in this case.
+    for f in self._converted_graph_def.library.function:
+      while context.context().has_function(f.signature.name):
+        logging.info("Removing original function %s from the context",
+                     f.signature.name)
+        context.context().remove_function(f.signature.name)
+    # This also adds the converted functions to the context.
+    self._converted_func = wrap_function.function_from_graph_def(
+        self._converted_graph_def,
+        [tensor.name for tensor in frozen_func.inputs],
+        [tensor.name for tensor in frozen_func.outputs])
+    # Reconstruct the output signatures using the ones from original model.
+    self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
+        func.graph.structured_outputs,
+        self._converted_func.graph.structured_outputs)
+    # Copy structured input signature from original function (used during
+    # serialization)
+    self._converted_func.graph.structured_input_signature = (
+        func.structured_input_signature)
 
     if self._need_calibration:
       # Execute calibration here only if not in dynamic shape mode.
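
With this revert, the meta graph is again exported directly from the frozen
function, and a tf.device(...) scope around convert() no longer controls
where the TRT engine op is placed. A hedged usage sketch of the converter
call (the SavedModel path is hypothetical, and the precision_mode
constructor argument is assumed to be available in this TF version):

    from tensorflow.python.compiler.tensorrt import trt_convert

    # Hypothetical SavedModel directory, for illustration only.
    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir="/tmp/saved_model",
        precision_mode=trt_convert.TrtPrecisionMode.FP32)

    # After the revert, wrapping this call in tf.device(...) does not pin
    # the imported graph; placement is left to the normal TF placer.
    converted_func = converter.convert()
    converter.save("/tmp/saved_model_trt")
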
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index a7d073e..4f7714c 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -26,7 +26,6 @@
 from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
 from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
 from tensorflow.core.framework import graph_pb2
-from tensorflow.python.framework import config
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.compiler.tensorrt import trt_convert
 from tensorflow.python.eager import def_function
@@ -1066,78 +1065,6 @@
       mock_save.save.assert_called_once_with(
           mock.ANY, mock.ANY, mock.ANY, options=options)
 
-  @parameterized.named_parameters([
-      ("NoDeviceAssignment", None),
-      ("GPU1", "GPU:1"),
-  ])
-  @test_util.run_v2_only
-  def testTrtGraphConverter_DevicePlacement(self, device_id):
-    """Test case for trt_convert.TrtGraphConverter()."""
-
-    gpus = config.list_physical_devices("GPU")
-    if len(gpus) < 2:
-      self.skipTest("Expected at least 2 GPUs but found {} GPUs".format(
-          len(gpus)))
-
-    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
-    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
-
-    # Create a model and save it.
-    input_saved_model_dir = self.mkdtemp()
-    root = self._GetModelForV2()
-    save.save(root, input_saved_model_dir,
-              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
-
-    converter = self._CreateConverterV2(
-        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
-
-    converted_model = None
-    # Specify device on which converted model should be placed
-    with ops.device(device_id):
-      converted_model = converter.convert()
-
-    # Verify that TRT engine op has the correct device.
-    self._CheckTrtOps(converter._converted_func)
-
-    actual_device_id = self._GetUniqueTRTEngineOp(
-        converter._converted_graph_def).device
-
-    expected_device_id = None
-    if device_id is not None:
-      expected_device_id = device_id
-    else:
-      expected_device_id = "GPU:0"
-
-    self.assertTrue(expected_device_id.lower() in actual_device_id.lower())
-
-    del converter
-    gc.collect()  # Force GC to destroy the TRT engine cache.
-
-  @test_util.run_v2_only
-  def testTrtGraphConverter_DevicePlacementOnCPU(self):
-    """Test case for trt_convert.TrtGraphConverter()."""
-
-    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
-    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
-
-    # Create a model and save it.
-    input_saved_model_dir = self.mkdtemp()
-    root = self._GetModelForV2()
-    save.save(root, input_saved_model_dir,
-              {_SAVED_MODEL_SIGNATURE_KEY: root.run})
-
-    # Run TRT conversion.
-    converter = self._CreateConverterV2(
-        input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
-
-    converted_model = None
-    # Specify device on which converted model should be placed
-    with self.assertRaisesRegex(ValueError, r"Specified device is not a GPU"):
-      with ops.device("CPU"):
-        converted_model = converter.convert()
-
-    del converter
-    gc.collect()  # Force GC to destroy the TRT engine cache.
 
 if __name__ == "__main__" and is_tensorrt_enabled():
   test.main()