# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import os
import tempfile
import numpy as np
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util.lazy_loader import LazyLoader
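# Signature key under which the test models are saved and loaded back below.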
_SAVED_MODEL_SIGNATURE_KEY = "mypredict"
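# Load the TRT op bindings lazily, so that merely importing this module does
# not require TensorRT to be available.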
gen_trt_ops = LazyLoader(
"gen_trt_ops", globals(),
"tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
class TrtConvertTest(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration python API."""
# Use a small max_workspace_size for tests so they don't consume too much GPU
# memory.
  _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20  # 2 MiB.
def testGetTensorrtRewriterConfig(self):
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
if not is_tensorrt_enabled():
return
conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
max_batch_size=128,
max_workspace_size_bytes=1234,
precision_mode="INT8",
minimum_segment_size=10,
is_dynamic_op=True,
maximum_cached_engines=2)
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
conversion_params=conversion_params)
self.assertEqual(["constfold", "layout", "constfold"],
rewriter_cfg.optimizers)
self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
rewriter_cfg.meta_optimizer_iterations)
    trt_optimizer = None
    for optimizer in rewriter_cfg.custom_optimizers:
      if optimizer.name == "TensorRTOptimizer":
        self.assertIsNone(trt_optimizer)  # There should be exactly one.
        trt_optimizer = optimizer
    self.assertIsNotNone(trt_optimizer)
    for key in [
        "minimum_segment_size", "max_batch_size", "is_dynamic_op",
        "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines"
    ]:
      self.assertIn(key, trt_optimizer.parameter_map)
self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
self.assertEqual(1234,
trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
self.assertEqual(
trt_convert._to_bytes("INT8"),
trt_optimizer.parameter_map["precision_mode"].s)
self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
def _GetConfigProto(self):
"""Get ConfigProto for session creation."""
config = config_pb2.ConfigProto(
gpu_options=config_pb2.GPUOptions(allow_growth=True))
return config
@classmethod
def _GetGraph(cls, inp, var):
"""Get the graph for testing."""
    # The graph computes (input + 1)^2 when v1 == 1:
    #
    #   output = (input * (input + v1)) + (input + v1)
    #
    #   input (Placeholder)   v1 (Variable)
    #        |    \              /
    #        |     \            /
    #        |      + ("add")
    #        |     /       \
    #        * ("mul")      |
    #             \         |
    #              + ("add_1")
    #              |
    #       output (Identity)
    add = inp + var  # Node "add".
    mul = inp * add  # Node "mul".
    add = mul + add  # Node "add_1".
    out = array_ops.identity(add, name="output")
return out
def _GetModelForV2(self):
class SimpleModel(tracking.AutoTrackable):
def __init__(self):
self.v = None
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
])
      def run(self, inp):
        if self.v is None:
          # A tf.function must not create new variables on each call, so the
          # variable is created lazily on the first trace only.
          self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
return TrtConvertTest._GetGraph(inp, self.v)
return SimpleModel()
def _GetGraphForV1(self):
g = ops.Graph()
with g.as_default():
with g.device("/GPU:0"):
inp = array_ops.placeholder(
dtype=dtypes.float32, shape=[None, 1, 1], name="input")
var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
out = TrtConvertTest._GetGraph(inp, var)
return g, var, inp, out
def _GetGraphDef(self):
"""Get the graph def for testing."""
g, var, _, _ = self._GetGraphForV1()
with self.session(graph=g, config=self._GetConfigProto()) as sess:
sess.run(var.initializer)
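      # Freeze the variable into a Const node so that the resulting GraphDef
      # is self-contained.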
graph_def = graph_util.convert_variables_to_constants(
sess, g.as_graph_def(add_shapes=True), ["output"])
node_name_to_op = {node.name: node.op for node in graph_def.node}
self.assertEqual(
{
"v1": "Const",
"add/ReadVariableOp": "Identity",
"input": "Placeholder",
"add": "AddV2",
"mul": "Mul",
"add_1": "AddV2",
"output": "Identity"
}, node_name_to_op)
return graph_def
def _WriteInputSavedModel(self, input_saved_model_dir):
"""Write the saved model as an input for testing."""
g, var, inp, out = self._GetGraphForV1()
signature_def = signature_def_utils.build_signature_def(
inputs={"myinput": utils.build_tensor_info(inp)},
outputs={"myoutput": utils.build_tensor_info(out)},
method_name=signature_constants.PREDICT_METHOD_NAME)
saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
with self.session(graph=g, config=self._GetConfigProto()) as sess:
sess.run(var.initializer)
saved_model_builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map={_SAVED_MODEL_SIGNATURE_KEY: signature_def})
saved_model_builder.save()
def _ConvertGraph(self,
input_saved_model_dir=None,
output_saved_model_dir=None,
need_calibration=False,
max_batch_size=1,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
use_function_backup=False):
"""Helper method to convert a GraphDef or SavedModel using TF-TRT."""
converter = trt_convert.TrtGraphConverter(
input_saved_model_dir=input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
input_graph_def=None if input_saved_model_dir else self._GetGraphDef(),
nodes_blacklist=None if input_saved_model_dir else ["output"],
session_config=self._GetConfigProto(),
max_batch_size=max_batch_size,
max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
else trt_convert.TrtPrecisionMode.FP32),
minimum_segment_size=minimum_segment_size,
is_dynamic_op=is_dynamic_op,
maximum_cached_engines=maximum_cached_engines,
use_function_backup=use_function_backup)
output_graph_def = converter.convert()
if need_calibration:
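      # Feed a different value on every run so that INT8 calibration sees a
      # range of input values.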
class CalibrationData(object):
def __init__(self):
self._data = 0
def next(self):
self._data += 1
return {"input:0": [[[self._data]]]}
output_graph_def = converter.calibrate(
fetch_names=["output:0"],
num_runs=10,
feed_dict_fn=CalibrationData().next)
if output_saved_model_dir is not None:
converter.save(output_saved_model_dir=output_saved_model_dir)
return output_graph_def
def _TestTrtGraphConverter(self,
input_saved_model_dir=None,
output_saved_model_dir=None,
need_calibration=False,
is_dynamic_op=False):
"""General method to test trt_convert.TrtGraphConverter()."""
output_graph_def = self._ConvertGraph(
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
need_calibration=need_calibration,
is_dynamic_op=is_dynamic_op,
use_function_backup=need_calibration)
graph_defs_to_verify = [output_graph_def]
if output_saved_model_dir:
saved_model_graph_def = saved_model_utils.get_meta_graph_def(
output_saved_model_dir, tag_constants.SERVING).graph_def
self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
graph_defs_to_verify.append(saved_model_graph_def)
for graph_def in graph_defs_to_verify:
node_name_to_op = {node.name: node.op for node in graph_def.node}
self.assertEqual(
{
"input": "Placeholder",
"TRTEngineOp_0": "TRTEngineOp",
"output": "Identity"
}, node_name_to_op)
if need_calibration:
trt_engine_nodes = [
node for node in graph_def.node if node.op == "TRTEngineOp"
]
self.assertNotEmpty(trt_engine_nodes)
for node in trt_engine_nodes:
          self.assertNotEmpty(node.attr["calibration_data"].s)
# Run the calibrated graph.
# TODO(laigd): consider having some input where the answer is different.
with ops.Graph().as_default():
importer.import_graph_def(graph_def, name="")
with self.session(config=self._GetConfigProto()) as sess:
for test_data in range(10):
self.assertEqual(
(test_data + 1.0)**2,
sess.run("output:0", feed_dict={"input:0": [[[test_data]]]}))
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_BasicConversion(self):
"""Test case for trt_convert.TrtGraphConverter()."""
if not is_tensorrt_enabled():
return
tmp_dir = self.get_temp_dir()
input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
self._WriteInputSavedModel(input_saved_model_dir)
for need_calibration in [False, True]:
# Use GraphDef as input.
self._TestTrtGraphConverter()
# Use SavedModel as input.
output_saved_model_dir = os.path.join(
tmp_dir, "out_dir1%s" % ("_int8" if need_calibration else ""))
self._TestTrtGraphConverter(
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
need_calibration=need_calibration)
def _CreateConverterV2(self, input_saved_model_dir):
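    # The V2 converter builds engines lazily at run time, hence
    # is_dynamic_op=True and a cache of multiple engines.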
return trt_convert.TrtGraphConverterV2(
input_saved_model_dir=input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode=trt_convert.TrtPrecisionMode.FP32,
is_dynamic_op=True,
maximum_cached_engines=2,
use_function_backup=False))
@test_util.run_v2_only
def testTrtGraphConverter_BasicConversion_v2(self):
"""Test case for trt_convert.TrtGraphConverter()."""
if not is_tensorrt_enabled():
return
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
# Create a model and save it.
input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
root = self._GetModelForV2()
expected_output = root.run(np_input)
save.save(root, input_saved_model_dir,
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
# Run TRT conversion.
converter = self._CreateConverterV2(input_saved_model_dir)
converted_func = converter.convert()
def _check_trt_ops(graph_def):
trt_op_names = [
node.name for node in graph_def.node if node.op == "TRTEngineOp"
]
for func in graph_def.library.function:
for node in func.node_def:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
self.assertEqual(1, len(trt_op_names))
self.assertIn("TRTEngineOp_0", trt_op_names[0])
# Verify the converted GraphDef and ConcreteFunction.
self.assertIsInstance(converted_func, def_function.Function)
converted_concrete_func = converted_func.get_concrete_function(
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
_check_trt_ops(converted_concrete_func.graph.as_graph_def())
# Save the converted model without any TRT engine cache.
output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
converter.save(output_saved_model_dir)
unexpected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
self.assertFalse(os.path.exists(unexpected_asset_file))
# Run the converted function to populate the engine cache.
output_with_trt = converted_func(np_input)
self.assertEqual(1, len(output_with_trt))
self.assertAllClose(
expected_output, output_with_trt[0], atol=1e-6, rtol=1e-6)
# Save the converted model again with serialized engine cache.
output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
converter.save(output_saved_model_dir)
expected_asset_file = os.path.join(
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
self.assertTrue(os.path.exists(expected_asset_file))
self.assertTrue(os.path.getsize(expected_asset_file))
# Load and verify the converted model.
#
    # TODO(laigd): the name of the new input_signature of the
    # `root_with_trt.run` function is an empty string (originally was None),
    # investigate why.
root_with_trt = load.load(output_saved_model_dir)
# TODO(laigd): `root_with_trt.run` is still using the original graph without
# trt. Consider changing that.
# _check_trt_ops(
# root_with_trt.run.get_concrete_function().graph.as_graph_def())
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
_check_trt_ops(converted_signature.graph.as_graph_def())
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
    # The output of running the converted signature is a dict, for
    # compatibility with the V1 SavedModel signature mechanism.
    # Note: dict.keys() is not indexable in Python 3, so materialize it first.
    output_with_trt = output_with_trt[list(output_with_trt.keys())[0]]
self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
@test_util.run_v2_only
def testTrtGraphConverter_DestroyEngineCache(self):
"""Test case for trt_convert.TrtGraphConverter()."""
if not is_tensorrt_enabled():
return
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
# Create a model and save it.
input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
root = self._GetModelForV2()
save.save(root, input_saved_model_dir,
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
# Run TRT conversion.
converter = self._CreateConverterV2(input_saved_model_dir)
converted_func = converter.convert()
converted_func(np_input) # Populate the TRT engine cache.
output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
converter.save(output_saved_model_dir)
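    # Deletes the TRT engine cache resource; raises NotFoundError if the cache
    # does not exist, which the assertions below depend on.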
def _destroy_cache():
with ops.device("GPU:0"):
handle = gen_trt_ops.create_trt_engine_cache_handle(
container=trt_convert._TRT_ENGINE_CACHE_CONTAINER_NAME,
resource_name="TRTEngineOp_0")
gen_resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=False)
with self.assertRaisesRegexp(errors.NotFoundError,
r"Resource .* does not exist."):
_destroy_cache()
# Load the converted model and make sure the engine cache is populated by
# default.
root = load.load(output_saved_model_dir)
_destroy_cache()
with self.assertRaisesRegexp(errors.NotFoundError,
r"Resource .* does not exist."):
_destroy_cache()
# Load the converted model again and make sure the engine cache is destroyed
# when the model goes out of scope.
root = load.load(output_saved_model_dir)
del root
gc.collect() # Force GC to destroy the TRT engine cache.
with self.assertRaisesRegexp(errors.NotFoundError,
r"Resource .* does not exist."):
_destroy_cache()
def _TestRun(self,
sess,
batch_size,
use_function_backup=False,
expect_engine_is_run=True):
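    """Run the graph in `sess` and check the result or the expected failure."""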
try:
result = sess.run(
"output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
self.assertAllEqual([[[4.0]]] * batch_size, result)
except errors.OpError as e:
      # This should happen only when the fallback path is disabled and the
      # TRT engine fails to run.
self.assertTrue(not use_function_backup and not expect_engine_is_run)
self.assertIn("Fallback path is disabled, for TRTEngineOp_0", str(e))
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_MinimumSegmentSize(self):
if not is_tensorrt_enabled():
return
output_graph_def = self._ConvertGraph(minimum_segment_size=5)
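    # The graph has fewer candidate nodes than minimum_segment_size, so no TRT
    # segment is formed and the original nodes are kept.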
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
self.assertEqual(
{
"add/ReadVariableOp": "Const",
"input": "Placeholder",
"add": "AddV2",
"mul": "Mul",
"add_1": "AddV2",
"output": "Identity"
}, node_name_to_op)
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_DynamicOp(self):
if not is_tensorrt_enabled():
return
tmp_dir = self.get_temp_dir()
input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
self._WriteInputSavedModel(input_saved_model_dir)
output_graph_def = self._ConvertGraph(
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
is_dynamic_op=True,
maximum_cached_engines=2,
use_function_backup=False) # Disallow fallback.
# Test the output GraphDef.
with ops.Graph().as_default():
importer.import_graph_def(output_graph_def, name="")
with self.session(config=self._GetConfigProto()) as sess:
# Run with batch size 1, a new engine is created and cached.
self._TestRun(sess, 1)
# Run with batch size 2, a new engine is created and cached.
self._TestRun(sess, 2)
# Run with batch size 3, since the number of cached engines has reached
# the max, it should evict an old engine and create a new one.
self._TestRun(sess, 3)
# Test the output SavedModel
with ops.Graph().as_default():
with self.session(config=self._GetConfigProto()) as sess:
loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
# Run with batch size 1, a new engine is created and cached.
self._TestRun(sess, 1)
# Run with batch size 2, a new engine is created and cached.
self._TestRun(sess, 2)
# Run with batch size 3, since the number of cached engines has reached
# the max, it should evict an old engine and create a new one.
self._TestRun(sess, 3)
def _TestStaticOp(self, use_function_backup):
if not is_tensorrt_enabled():
return
tmp_dir = self.get_temp_dir()
input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
self._WriteInputSavedModel(input_saved_model_dir)
output_graph_def = self._ConvertGraph(
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
        maximum_cached_engines=2,  # A no-op for static engines; added just for testing.
use_function_backup=use_function_backup)
# Test the output GraphDef.
with ops.Graph().as_default():
importer.import_graph_def(output_graph_def, name="")
with self.session(config=self._GetConfigProto()) as sess:
# Run with batch size 1, the default engine embedded in the graphdef
# will be used.
self._TestRun(
sess,
1,
use_function_backup=use_function_backup,
expect_engine_is_run=True)
        # Run with batch size 2, which exceeds max_batch_size; the op should
        # try to fall back to the TF function.
self._TestRun(
sess,
2,
use_function_backup=use_function_backup,
expect_engine_is_run=False)
# Test the output SavedModel
with ops.Graph().as_default():
with self.session(config=self._GetConfigProto()) as sess:
loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
# Run with batch size 1, the default engine embedded in the graphdef
# will be used.
self._TestRun(
sess,
1,
use_function_backup=use_function_backup,
expect_engine_is_run=True)
        # Run with batch size 2, which exceeds max_batch_size; the op should
        # try to fall back to the TF function.
self._TestRun(
sess,
2,
use_function_backup=use_function_backup,
expect_engine_is_run=False)
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_StaticOp_NoFallback(self):
self._TestStaticOp(use_function_backup=False)
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_StaticOp_WithFallback(self):
self._TestStaticOp(use_function_backup=True)
if __name__ == "__main__":
test.main()