Support tf1 models for 'convert_with_tensorrt'.
PiperOrigin-RevId: 301891524
Change-Id: If6c3f692a4763cf171c6e585c4986f52c732ee1a
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 8b951be..6e60e58 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -772,16 +772,33 @@
# not installed
from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top
- params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
- max_workspace_size_bytes=args.max_workspace_size_bytes,
- precision_mode=args.precision_mode,
- minimum_segment_size=args.minimum_segment_size)
- converter = trt.TrtGraphConverterV2(
- input_saved_model_dir=args.dir,
- input_saved_model_tags=args.tag_set.split(','),
- conversion_params=params)
- converter.convert()
- converter.save(output_saved_model_dir=args.output_dir)
+ if not args.convert_tf1_model:
+ params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
+ max_workspace_size_bytes=args.max_workspace_size_bytes,
+ precision_mode=args.precision_mode,
+ minimum_segment_size=args.minimum_segment_size)
+ converter = trt.TrtGraphConverterV2(
+ input_saved_model_dir=args.dir,
+ input_saved_model_tags=args.tag_set.split(','),
+ conversion_params=params)
+ try:
+ converter.convert()
+ except Exception as e:
+ raise RuntimeError(
+ '{}. Try passing "--convert_tf1_model=True".'.format(e))
+ converter.save(output_saved_model_dir=args.output_dir)
+ else:
+ trt.create_inference_graph(
+ None,
+ None,
+ max_batch_size=1,
+ max_workspace_size_bytes=args.max_workspace_size_bytes,
+ precision_mode=args.precision_mode,
+ minimum_segment_size=args.minimum_segment_size,
+ is_dynamic_op=True,
+ input_saved_model_dir=args.dir,
+ input_saved_model_tags=args.tag_set.split(','),
+ output_saved_model_dir=args.output_dir)
def aot_compile_cpu(args):
@@ -1010,6 +1027,11 @@
default=3,
help=('the minimum number of nodes required for a subgraph to be replaced'
'in a TensorRT node'))
+ parser_convert_with_tensorrt.add_argument(
+ '--convert_tf1_model',
+ type=bool,
+ default=False,
+ help='support TRT conversion for TF1 models')
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)