Revert "Merge pull request #55655 from meena-at-work:meenakshiv/tftrt-specify-device-during-conversion"
[TFTRT]: Revert the device placement changes, as they cause
issues on graphs with nodes that are not runnable on GPUs.
This reverts commit 0842ad13ec13732c132f8b5bc8f6768adfdeef12, reversing
changes made to dd8983737842368dd4481a2e6ca5b6a66a1706c8.
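Note: as a minimal, hypothetical sketch of the failure mode (not part of
this patch), the reverted code imported the frozen graph under
`ops.device(self._device)`, which pins every node, including ops that only
have CPU kernels, to the GPU. Something like the following, with
illustrative names (`table`, `lookup` are made up for this note), shows the
pattern breaking on a CPU-only lookup op:

    # Hypothetical repro sketch; assumes a machine with at least one GPU.
    import tensorflow as tf

    # tf.lookup tables only have CPU kernels.
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(["a", "b"], [0, 1]),
        default_value=-1)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def lookup(keys):
      return table.lookup(keys)

    graph_def = lookup.get_concrete_function().graph.as_graph_def()

    # Mirrors the reverted import under `ops.device(self._device)`: every
    # node, including the CPU-only table ops, is pinned to the GPU, so
    # placing or running the resulting graph can fail with a
    # "no kernel registered for GPU"-style error.
    with tf.Graph().as_default(), tf.device("GPU:0"):
      tf.graph_util.import_graph_def(graph_def, name="")

With the revert, the frozen graph is exported without a forced device
scope, so the placer can keep CPU-only nodes on the host.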
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index 962a531..ac1c18a 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -1131,7 +1131,6 @@
self._calibration_input_fn = None
self._converted = False
- self._device = None
self._build_called_once = False
self._calibrated = False
@@ -1241,17 +1240,6 @@
"""
assert not self._converted
- if self._device is None:
- # Creating an empty tensor to fetch queried device
- self._device = array_ops.zeros([]).device
- if not self._device: # if device unspecified, defaulting to GPU 0
- self._device = "gpu:0"
- if "gpu" not in self._device.lower():
- raise ValueError(f"Specified device is not a GPU: {self._device}")
-
- logging.info(f"Placing imported graph from "
- f"`{self._input_saved_model_dir}` on device: {self._device}")
-
if (self._need_calibration and not calibration_input_fn):
raise ValueError("Should specify calibration_input_fn because INT8 "
"calibration is needed")
@@ -1263,50 +1251,39 @@
self._input_saved_model_tags)
func = self._saved_model.signatures[self._input_saved_model_signature_key]
frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
+ grappler_meta_graph_def = saver.export_meta_graph(
+ graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
- frozen_graph_def = frozen_func.graph.as_graph_def()
+ # Add a collection 'train_op' so that Grappler knows the outputs.
+ fetch_collection = meta_graph_pb2.CollectionDef()
+ for array in frozen_func.inputs + frozen_func.outputs:
+ fetch_collection.node_list.value.append(array.name)
+ grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+ fetch_collection)
- # Clear any prior device assignments
- for node in frozen_graph_def.node:
- node.device = ""
-
- with ops.Graph().as_default() as graph, ops.device(self._device):
- importer.import_graph_def(frozen_graph_def, name="")
- grappler_meta_graph_def = saver.export_meta_graph(
- graph_def=graph.as_graph_def(), graph=graph)
-
- # Add a collection 'train_op' so that Grappler knows the outputs.
- fetch_collection = meta_graph_pb2.CollectionDef()
- for array in frozen_func.inputs + frozen_func.outputs:
- fetch_collection.node_list.value.append(array.name)
- grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
- fetch_collection)
-
- # Run TRT optimizer in Grappler to convert the graph.
- self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
- # If a function is converted, then the TF context contains the original
- # function while the converted_graph_def contains the converted function.
- # Remove the original function from the TF context in this case.
- for f in self._converted_graph_def.library.function:
- while context.context().has_function(f.signature.name):
- logging.info("Removing original function %s from the context",
- f.signature.name)
- context.context().remove_function(f.signature.name)
- # This also adds the converted functions to the context.
- self._converted_func = wrap_function.function_from_graph_def(
- self._converted_graph_def,
- [tensor.name for tensor in frozen_func.inputs],
- [tensor.name for tensor in frozen_func.outputs])
- # Reconstruct the output signatures using the ones from original model.
- self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
- func.graph.structured_outputs,
- self._converted_func.graph.structured_outputs)
- # Copy structured input signature from original function (used during
- # serialization)
- self._converted_func.graph.structured_input_signature = (
- func.structured_input_signature)
-
- self._converted_func = self._rebuild_func(self._converted_func)
+ # Run TRT optimizer in Grappler to convert the graph.
+ self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
+ # If a function is converted, then the TF context contains the original
+ # function while the converted_graph_def contains the converted function.
+ # Remove the original function from the TF context in this case.
+ for f in self._converted_graph_def.library.function:
+ while context.context().has_function(f.signature.name):
+ logging.info("Removing original function %s from the context",
+ f.signature.name)
+ context.context().remove_function(f.signature.name)
+ # This also adds the converted functions to the context.
+ self._converted_func = wrap_function.function_from_graph_def(
+ self._converted_graph_def,
+ [tensor.name for tensor in frozen_func.inputs],
+ [tensor.name for tensor in frozen_func.outputs])
+ # Reconstruct the output signatures using the ones from original model.
+ self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
+ func.graph.structured_outputs,
+ self._converted_func.graph.structured_outputs)
+ # Copy structured input signature from original function (used during
+ # serialization)
+ self._converted_func.graph.structured_input_signature = (
+ func.structured_input_signature)
if self._need_calibration:
# Execute calibration here only if not in dynamic shape mode.
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
index a7d073e..4f7714c 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -26,7 +26,6 @@
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance # pylint: disable=g-importing-member
from tensorflow.core.framework import graph_pb2
-from tensorflow.python.framework import config
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.eager import def_function
@@ -1066,78 +1065,6 @@
mock_save.save.assert_called_once_with(
mock.ANY, mock.ANY, mock.ANY, options=options)
- @parameterized.named_parameters([
- ("NoDeviceAssignment", None),
- ("GPU1", "GPU:1"),
- ])
- @test_util.run_v2_only
- def testTrtGraphConverter_DevicePlacement(self, device_id):
- """Test case for trt_convert.TrtGraphConverter()."""
-
- gpus = config.list_physical_devices("GPU")
- if len(gpus) < 2:
- self.skipTest("Expected at least 2 GPUs but found {} GPUs".format(
- len(gpus)))
-
- np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
- np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
-
- # Create a model and save it.
- input_saved_model_dir = self.mkdtemp()
- root = self._GetModelForV2()
- save.save(root, input_saved_model_dir,
- {_SAVED_MODEL_SIGNATURE_KEY: root.run})
-
- converter = self._CreateConverterV2(
- input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
-
- converted_model = None
- # Specify device on which converted model should be placed
- with ops.device(device_id):
- converted_model = converter.convert()
-
- # Verify that TRT engine op has the correct device.
- self._CheckTrtOps(converter._converted_func)
-
- actual_device_id = self._GetUniqueTRTEngineOp(
- converter._converted_graph_def).device
-
- expected_device_id = None
- if device_id is not None:
- expected_device_id = device_id
- else:
- expected_device_id = "GPU:0"
-
- self.assertTrue(expected_device_id.lower() in actual_device_id.lower())
-
- del converter
- gc.collect() # Force GC to destroy the TRT engine cache.
-
- @test_util.run_v2_only
- def testTrtGraphConverter_DevicePlacementOnCPU(self):
- """Test case for trt_convert.TrtGraphConverter()."""
-
- np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
- np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
-
- # Create a model and save it.
- input_saved_model_dir = self.mkdtemp()
- root = self._GetModelForV2()
- save.save(root, input_saved_model_dir,
- {_SAVED_MODEL_SIGNATURE_KEY: root.run})
-
- # Run TRT conversion.
- converter = self._CreateConverterV2(
- input_saved_model_dir, precision_mode=trt_convert.TrtPrecisionMode.FP32)
-
- converted_model = None
- # Specify device on which converted model should be placed
- with self.assertRaisesRegex(ValueError, r"Specified device is not a GPU"):
- with ops.device("CPU"):
- converted_model = converter.convert()
-
- del converter
- gc.collect() # Force GC to destroy the TRT engine cache.
if __name__ == "__main__" and is_tensorrt_enabled():
test.main()