| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Utilities to test TF-TensorRT integration.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| from collections import namedtuple |
| import errno |
| import gc |
| import itertools |
| import os |
| import re |
| import shutil |
| import tempfile |
| import warnings |
| |
| import numpy as np |
| import six |
| |
| from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled |
| from tensorflow.core.framework import graph_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python.compiler.tensorrt import trt_convert |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.framework import graph_io |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.saved_model import builder |
| from tensorflow.python.saved_model import load |
| from tensorflow.python.saved_model import loader |
| from tensorflow.python.saved_model import save |
| from tensorflow.python.saved_model import signature_constants |
| from tensorflow.python.saved_model import signature_def_utils |
| from tensorflow.python.saved_model import tag_constants |
| from tensorflow.python.saved_model import utils |
| from tensorflow.python.tools import saved_model_utils |
| from tensorflow.python.training.tracking import tracking |
| from tensorflow.python.util import nest |
| |
| TfTrtIntegrationTestParams = namedtuple( |
| "TfTrtIntegrationTestParams", |
| [ |
| # A function that creates the TF graph for testing. |
| "graph_fn", |
| # A list of specifications for input tensors. |
| "input_specs", |
| # A list of specifications for output tensors. |
| "output_specs", |
| # A list of lists of input shapes. Each shape must match the |
| # corresponding element in `input_specs`. |
| "input_dims", |
| # A list of lists of expected output shapes. Each shape must match the |
| # corresponding element in `output_specs`. |
| "expected_output_dims" |
| ]) |
| |
| RunParams = namedtuple( |
| "RunParams", |
| [ |
| # Whether to run the conversion online with RewriterConfig, or offline |
| # with TrtGraphConverter. |
| "convert_online", |
| "precision_mode", |
| "dynamic_engine", |
| "use_calibration", |
| "test_name", |
| # Is this test for TF 2.0? |
| "is_v2", |
| ]) |
| |
| FP32 = "FP32" |
| FP16 = "FP16" |
| INT8 = "INT8" |
| PRECISION_MODES = [FP32, FP16, INT8] |
| |
| |
| def IsQuantizationMode(mode): |
| return mode == "INT8" |
| |
| |
| def IsQuantizationWithCalibration(params): |
| return IsQuantizationMode(params.precision_mode) and params.use_calibration |
| |
| |
| class GraphState(object): |
| ORIGINAL = 0 |
| CALIBRATE = 1 |
| INFERENCE = 2 |
| |
| |
| class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): |
| """Class to test Tensorflow-TensorRT integration.""" |
| |
| @property |
| def trt_incompatible_op(self): |
| return math_ops.erf |
| |
| @property |
| def precision_modes(self): |
| return ["FP32", "FP16", "INT8"] |
| |
| # str is bytes in py2, but unicode in py3. |
| def _ToUnicode(self, s): |
| if six.PY2: |
| if isinstance(s, unicode): |
| return s |
| return s.decode("utf-8") |
| else: |
| if isinstance(s, str): |
| return s |
| return s.decode("utf-8") |
| |
| def _ToBytes(self, s): |
| if six.PY2: |
| if isinstance(s, unicode): |
| return s.encode("utf-8") |
| return s |
| else: |
| if isinstance(s, str): |
| return s.encode("utf-8") |
| return s |
| |
| def _ToString(self, s): |
| if six.PY2: |
| if isinstance(s, unicode): |
| return s.encode("utf-8") |
| return s |
| else: |
| if isinstance(s, str): |
| return s |
| return s.decode("utf-8") |
| |
| def __init__(self, methodName="runTest"): # pylint: disable=invalid-name |
| super(TfTrtIntegrationTestBase, self).__init__(methodName) |
| self._trt_test_params = None |
| |
| def setUp(self): |
| """Setup method.""" |
| super(TfTrtIntegrationTestBase, self).setUp() |
| warnings.simplefilter("always") |
| |
| def _GetTensorSpec(self, shape, mask, dtype, name): |
| # Set dimension i to None if mask[i] == False |
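| # For example, shape=[8, 2, 4] with mask=[False, True, True] produces a |
| # TensorSpec with shape [None, 2, 4]. |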
| assert len(shape) == len(mask) |
| new_shape = [s if m else None for s, m in zip(shape, mask)] |
| return tensor_spec.TensorSpec(new_shape, dtype, name) |
| |
| def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes): |
| """Build test parameters. |
| |
| The input_shapes and output_shapes arguments are known (static) shapes that |
| can be used to generate test data. To define the model, we also specify |
| corresponding input/output TensorSpecs. These are defined using the shape |
| arguments. For each input tensor we define: |
| |
| input_spec = [None] + input_shape[1:] |
| |
| and similarly for output shapes. This means that we leave the first (batch) |
| dimension unknown, while the remaining dimensions are copied from the shape |
| arguments. |
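| |
| For example, input_shapes=[[8, 2, 4]] produces an input spec with shape |
| [None, 2, 4], while the test data keeps the fully specified [8, 2, 4] shape. |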
| |
| Args: |
| graph_fn: The function to build the graph. |
| dtype: The element type. |
| input_shapes: The input shapes. |
| output_shapes: The output shapes. |
| |
| Returns: |
| The test parameters. |
| """ |
| |
| input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes] |
| output_mask = [ |
| [False] + [True] * (len(shape) - 1) for shape in output_shapes |
| ] |
| |
| return self.BuildParamsWithMask(graph_fn, dtype, input_shapes, |
| output_shapes, input_mask, output_mask, [], |
| []) |
| |
| def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes, |
| input_mask, output_mask, extra_inputs, extra_outputs): |
| """Build test parameters with static or dynamic input shapes. |
| |
| To define dynamic shapes, provide a boolean mask that describes which |
| dimensions to treat as known. The values in input_mask are interpreted the |
| following way: |
| - True: known dim (use the corresponding value from input_shapes) |
| - False: unknown dim (replace the corresponding value from input_shapes |
| with None) |
| For example, to define the first two dimensions with unknown size, use |
| input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]]. |
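| The resulting input spec is then [None, None, 1, 8]. |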
| |
| Args: |
| graph_fn: The function to build the graph. |
| dtype: The element type. |
| input_shapes: The input shapes. |
| output_shapes: The output shapes. |
| input_mask: The input shape masks. |
| output_mask: The output shape masks. |
| extra_inputs: A list of additional input shapes. |
| extra_outputs: A list of additional output shapes. |
| |
| Returns: |
| The test parameters. |
| """ |
| |
| def _ValidateShapes(shapes): |
| # Make sure all the shapes are fully specified. |
| for shape in shapes: |
| assert all(shape) |
| |
| _ValidateShapes(input_shapes) |
| _ValidateShapes(output_shapes) |
| |
| assert len(input_mask) == len(input_shapes) |
| assert len(output_mask) == len(output_shapes) |
| for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs): |
| assert len(input_shapes) == len(extra_in_shape) |
| assert len(output_shapes) == len(extra_out_shape) |
| |
| return TfTrtIntegrationTestParams( |
| graph_fn=graph_fn, |
| input_specs=[ |
| self._GetTensorSpec(shape, mask, dtype, "input_%d" % i) |
| for i, (shape, mask) in enumerate(zip(input_shapes, input_mask)) |
| ], |
| output_specs=[ |
| self._GetTensorSpec(shape, mask, dtype, "output_%d" % i) |
| for i, (shape, mask) in enumerate(zip(output_shapes, output_mask)) |
| ], |
| input_dims=[input_shapes] + extra_inputs, |
| expected_output_dims=[output_shapes] + extra_outputs) |
| |
| def GetParams(self): |
| """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" |
| raise NotImplementedError() |
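| |
| # A minimal sketch of a subclass implementation (illustrative only; note |
| # that `dtypes` is not imported by this module): |
| # |
| # def GetParams(self): |
| #   def graph_fn(inp): |
| #     # Name the output tensor so it matches the generated output_specs. |
| #     return array_ops.identity(inp, name="output_0") |
| #   return self.BuildParams(graph_fn, dtypes.float32, [[2, 8]], [[2, 8]]) |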
| |
| def GetConversionParams(self, run_params): |
| """Return a TrtConversionParams for test.""" |
| batch_list = [] |
| for dims_list in self._GetParamsCached().input_dims: |
| assert dims_list |
| # Each list of shapes should have the same batch size. |
| input_batches = [dims[0] for dims in dims_list] |
| assert max(input_batches) == min(input_batches) |
| batch_list.append(input_batches[0]) |
| conversion_params = trt_convert.TrtConversionParams( |
| rewriter_config_template=None, |
| max_workspace_size_bytes=1 << 25, |
| precision_mode=run_params.precision_mode, |
| minimum_segment_size=2, |
| is_dynamic_op=run_params.dynamic_engine, |
| maximum_cached_engines=1, |
| use_calibration=run_params.use_calibration, |
| # We use the minimum of all the batch sizes, so when multiple different |
| # input shapes are provided it'll always create new engines in the |
| # cache, and we can therefore test the cache behavior. |
| max_batch_size=min(batch_list)) |
| return conversion_params |
| |
| def GetTrtRewriterConfig(self, |
| run_params, |
| conversion_params, |
| disable_non_trt_optimizers=False, |
| use_implicit_batch=True): |
| rewriter_config = trt_convert.get_tensorrt_rewriter_config( |
| conversion_params=conversion_params, |
| is_v2=run_params.is_v2, |
| disable_non_trt_optimizers=disable_non_trt_optimizers) |
| for optimizer in rewriter_config.custom_optimizers: |
| if optimizer.name == "TensorRTOptimizer": |
| optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch |
| return rewriter_config |
| |
| def ShouldRunTest(self, run_params): |
| """Whether to run the test.""" |
| # INT8 tests are only run when calibration is enabled. |
| return (run_params.use_calibration or not IsQuantizationMode( |
| run_params.precision_mode)), "test either calibration or non-INT8" |
| |
| def ExpectedEnginesToBuild(self, run_params): |
| """Return the expected engines to build, implemented by subclass.""" |
| raise NotImplementedError() |
| |
| def ExpectedAbsoluteTolerance(self, run_params): |
| """The absolute tolerance to compare floating point results.""" |
| return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02 |
| |
| def ExpectedRelativeTolerance(self, run_params): |
| """The relative tolerance to compare floating point results.""" |
| return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02 |
| |
| def _GetParamsCached(self): |
| if self._trt_test_params is None: |
| self._trt_test_params = self.GetParams() |
| return self._trt_test_params |
| |
| def _GetGPUOptions(self): |
| gpu_options = config_pb2.GPUOptions() |
| gpu_options.allow_growth = True |
| return gpu_options |
| |
| def _GetConfigProto(self, run_params, graph_state): |
| """Get config proto based on specific settings.""" |
| conversion_params = self.GetConversionParams(run_params) |
| if graph_state == GraphState.INFERENCE and run_params.convert_online: |
| rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params) |
| graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) |
| else: |
| graph_options = config_pb2.GraphOptions() |
| if conversion_params.rewriter_config_template is not None: |
| graph_options.rewrite_options.CopyFrom( |
| conversion_params.rewriter_config_template) |
| |
| config = config_pb2.ConfigProto( |
| gpu_options=self._GetGPUOptions(), graph_options=graph_options) |
| return config |
| |
| def _GetFeedNames(self): |
| params = self._GetParamsCached() |
| # Construct the feed tensor names by appending :0 to the node names. |
| return [spec.name + ":0" for spec in params.input_specs] |
| |
| def _GetFetchNames(self): |
| params = self._GetParamsCached() |
| # Construct the fetch tensor names by appending :0 to the node names. |
| return [spec.name + ":0" for spec in params.output_specs] |
| |
| def _GetFeedDict(self, inputs_data): |
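| # Builds e.g. {"input_0:0": data_0, "input_1:0": data_1}. |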
| return {name: data for name, data in zip(self._GetFeedNames(), inputs_data)} |
| |
| def _RunGraphV1(self, saved_model_dir, inputs_data, config, num_runs=2): |
| """Run given graphdef multiple times using TF 1.x runtime.""" |
| params = self._GetParamsCached() |
| fetches = self._GetFetchNames() |
| g = ops.Graph() |
| with g.as_default(): |
| with self.session(graph=g, config=config, use_gpu=True) as sess: |
| loader.load(sess, [tag_constants.SERVING], saved_model_dir) |
| vals = [] |
| # Run once for each set of input shapes. |
| for expected_shapes, current_input_data in zip( |
| params.expected_output_dims, inputs_data): |
| val = None |
| for _ in range(num_runs): |
| new_val = sess.run(fetches, self._GetFeedDict(current_input_data)) |
| self.assertEqual(len(expected_shapes), len(new_val)) |
| for expected_shape, actual_val in zip(expected_shapes, new_val): |
| self.assertEqual(list(expected_shape), list(actual_val.shape)) |
| if val is not None: |
| # Some ops may have nondeterministic output, e.g. Conv2D may use the |
| # Winograd algorithm, so we set atol/rtol larger than 1.e-06. |
| self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05) |
| val = new_val |
| vals.append(val) |
| return vals |
| |
| def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2): |
| """Run given graphdef multiple times using TF 2.0 runtime.""" |
| params = self._GetParamsCached() |
| root = load.load(saved_model_dir) |
| func = root.signatures[ |
| signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] |
| results = [] |
| for expected_shapes, current_input_data in zip(params.expected_output_dims, |
| inputs_data): |
| val = None |
| for _ in range(num_runs): |
| feed_dict = { |
| params.input_specs[i].name: current_input_data[i] |
| for i in range(len(params.input_specs)) |
| } |
| new_val = func(**feed_dict) |
| assert isinstance(new_val, dict) |
| # The keys of the output map are of the form output_i. |
| new_val = [new_val[key] for key in sorted(new_val)] |
| # Each element is an eager Tensor, and accessing individual elements is |
| # very expensive, so we convert them to a numpy array first. |
| new_val = [v.numpy() for v in new_val] |
| self.assertEqual(len(expected_shapes), len(new_val)) |
| for expected_shape, actual_val in zip(expected_shapes, new_val): |
| self.assertEqual(list(expected_shape), list(actual_val.shape)) |
| if val is not None: |
| # Some ops may have nondeterministic output, e.g. Conv2D may use the |
| # Winograd algorithm, so we set atol/rtol larger than 1.e-06. |
| self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05) |
| val = new_val |
| results.append(val) |
| |
| return results |
| |
| def _RunGraph(self, |
| run_params, |
| saved_model_dir, |
| inputs_data, |
| config, |
| graph_state, |
| num_runs=2): |
| params = self._GetParamsCached() |
| for data in inputs_data: |
| assert len(params.input_specs) == len(data) |
| |
| if run_params.is_v2: |
| results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state, |
| num_runs) |
| gc.collect() # Force GC to destroy the TRT engine cache. |
| return results |
| return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs) |
| |
| def _CreateConverter(self, run_params, saved_model_dir, session_config, |
| conversion_params): |
| """Return a TrtGraphConverter.""" |
| if run_params.is_v2: |
| return trt_convert.TrtGraphConverterV2( |
| input_saved_model_dir=saved_model_dir, |
| conversion_params=conversion_params) |
| return trt_convert.TrtGraphConverter( |
| input_saved_model_dir=saved_model_dir, |
| session_config=session_config, |
| max_batch_size=conversion_params.max_batch_size, |
| max_workspace_size_bytes=conversion_params.max_workspace_size_bytes, |
| precision_mode=conversion_params.precision_mode, |
| minimum_segment_size=conversion_params.minimum_segment_size, |
| is_dynamic_op=conversion_params.is_dynamic_op, |
| maximum_cached_engines=conversion_params.maximum_cached_engines, |
| use_calibration=conversion_params.use_calibration) |
| |
| def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data): |
| """Return trt converted graphdef in INT8 mode.""" |
| conversion_params = self.GetConversionParams(run_params) |
| logging.info(conversion_params) |
| assert conversion_params.precision_mode == "INT8" |
| assert conversion_params.is_dynamic_op |
| assert conversion_params.maximum_cached_engines == 1 |
| assert conversion_params.use_calibration |
| |
| # We only support calibrating a single engine. |
| # TODO(aaroey): fix this. |
| assert len(inputs_data) == 1 |
| |
| session_config = self._GetConfigProto(run_params, GraphState.CALIBRATE) |
| logging.info("Running calibration graph, config:\n%s", str(session_config)) |
| |
| converter = self._CreateConverter(run_params, saved_model_dir, |
| session_config, conversion_params) |
| int8_gdef = converter.convert() |
| self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef, |
| GraphState.CALIBRATE) |
| |
| converter.calibrate( |
| fetch_names=self._GetFetchNames(), |
| num_runs=5, |
| feed_dict_fn=lambda: self._GetFeedDict(inputs_data[0])) |
| trt_saved_model_dir = self._GetSavedModelDir(run_params, |
| GraphState.CALIBRATE) |
| converter.save(trt_saved_model_dir) |
| return trt_saved_model_dir |
| |
| def _GetInferGraph(self, run_params, saved_model_dir): |
| """Return trt converted graphdef.""" |
| conversion_params = self.GetConversionParams(run_params) |
| logging.info(conversion_params) |
| |
| session_config = self._GetConfigProto(run_params, GraphState.INFERENCE) |
| logging.info("Creating TRT graph for inference, config\n%s", |
| str(session_config)) |
| converter = self._CreateConverter(run_params, saved_model_dir, |
| session_config, conversion_params) |
| converter.convert() |
| |
| if trt_convert.is_explicit_batch_mode_enabled( |
| conversion_params.rewriter_config_template): |
| logging.info("Using build mode") |
| |
| def _BuildInputFn(): |
| for shapes in self._GetParamsCached().input_dims: |
| yield [np.zeros(x).astype(np.float32) for x in shapes] |
| |
| converter.build(input_fn=_BuildInputFn) |
| |
| trt_saved_model_dir = self._GetSavedModelDir(run_params, |
| GraphState.INFERENCE) |
| converter.save(trt_saved_model_dir) |
| return trt_saved_model_dir |
| |
| def _GetGraphStateLabel(self, graph_state): |
| if graph_state == GraphState.ORIGINAL: |
| return "Original" |
| elif graph_state == GraphState.CALIBRATE: |
| return "CalibEngine" |
| elif graph_state == GraphState.INFERENCE: |
| return "InferEngine" |
| else: |
| return "UnknownState" |
| |
| def _WriteGraph(self, run_params, gdef, graph_state): |
| temp_dir = os.getenv("TRT_TEST_TMPDIR") |
| if not temp_dir: |
| return |
| |
| graph_name = ( |
| self.__class__.__name__ + "_" + run_params.test_name + "_" + |
| self._GetGraphStateLabel(graph_state) + ".pbtxt") |
| logging.info("Writing graph to %s/%s", temp_dir, graph_name) |
| graph_io.write_graph(gdef, temp_dir, graph_name) |
| |
| # Remove the graph sequence number prefix from the name only if the name has |
| # a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a |
| # prefix exists. |
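| # For example, "TRTEngineOp_5_segment" becomes "TRTEngineOp_segment". |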
| def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix): |
| match = re.search(r"TRTEngineOp_\d+_", name) |
| has_prefix = match and name.startswith(match.group(0)) |
| assert (not expecting_prefix) or has_prefix |
| if has_prefix: |
| parts = name.split("_", maxsplit=2) |
| assert len(parts) == 3 |
| return parts[0] + "_" + parts[2] |
| return name |
| |
| def _RemoveGraphSequenceNumber(self, name): |
| return self._RemoveGraphSequenceNumberImpl(name, True) |
| |
| def _MayRemoveGraphSequenceNumber(self, name): |
| return self._RemoveGraphSequenceNumberImpl(name, False) |
| |
| def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef): |
| old_to_new_node_map = { |
| self._ToString(node.name): self._ToString(node.name) |
| for node in original_gdef.node |
| } |
| for engine_name, node_names in expected_engines.items(): |
| for node_name in node_names: |
| old_to_new_node_map[node_name] = engine_name |
| name_to_node_map = { |
| self._ToString(node.name): node for node in original_gdef.node |
| } |
| |
| def _InputName(inp): |
| inp = self._ToString(inp) |
| prefix = "" |
| if inp[0] == "^": |
| prefix = "^" |
| inp = inp[1:] |
| parts = inp.split(":") |
| if len(parts) > 1 and parts[-1].isdigit(): |
| inp = inp[:-len(parts[-1]) - 1] |
| return (prefix, inp) |
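| # For example: "^ctrl" -> ("^", "ctrl"), "node:1" -> ("", "node"), and |
| # "node" -> ("", "node"). |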
| |
| # Compute the expected mapping from each node to its input nodes. |
| expected_input_map = {} |
| removed_const_nodes = set([ |
| self._ToString(node.name) |
| for node in original_gdef.node |
| if node.op == "Const" |
| ]) |
| for node in original_gdef.node: |
| name_str = self._ToString(node.name) |
| target_node_name = old_to_new_node_map[name_str] |
| is_engine_op = (target_node_name != name_str) |
| if target_node_name not in expected_input_map: |
| expected_input_map[target_node_name] = set() |
| input_set = expected_input_map[target_node_name] |
| for inp in node.input: |
| (prefix, inp_name) = _InputName(inp) |
| mapped_input = old_to_new_node_map[inp_name] |
| # Add the input only if it's outside the segment (note that it could be |
| # in a different engine). |
| if not is_engine_op or (mapped_input != target_node_name and |
| name_to_node_map[inp_name].op != "Const"): |
| input_set.add(prefix + mapped_input) |
| if mapped_input in removed_const_nodes: |
| removed_const_nodes.remove(mapped_input) |
| # Remove const nodes that have no outputs. |
| expected_input_map = { |
| k: v |
| for k, v in expected_input_map.items() |
| if k not in removed_const_nodes |
| } |
| |
| # Compute the actual mapping from each node to its input nodes. If a cast |
| # op doesn't exist in the original graph, we replace the use of the cast op |
| # with the input of the op. This allows the verification to handle the case |
| # where the TF-TRT bridge splits a cast op into a chain of two cast ops. |
| new_cast_op_name_to_node_map = { |
| node.name: node |
| for node in converted_gdef.node |
| if (node.name not in old_to_new_node_map and node.op == "Cast") |
| } |
| actual_input_map = {} |
| for node in converted_gdef.node: |
| name_str = node.name |
| # Only nodes from the original graph or TRTEngineOp nodes are added as |
| # keys to the map. |
| if node.op == "TRTEngineOp": |
| name_str = self._RemoveGraphSequenceNumber(name_str) |
| elif name_str not in old_to_new_node_map: |
| continue |
| actual_input_map[name_str] = set() |
| input_set = actual_input_map[name_str] |
| for inp in node.input: |
| (prefix, node_name) = _InputName(inp) |
| node_name = self._MayRemoveGraphSequenceNumber(node_name) |
| if node_name in new_cast_op_name_to_node_map: |
| (prefix, node_name) = _InputName( |
| new_cast_op_name_to_node_map[node_name].input[0]) |
| input_set.add(prefix + node_name) |
| |
| self.assertEqual( |
| expected_input_map, |
| actual_input_map, |
| msg="\nexpected:\n%s\nvs actual:\n%s" % |
| (sorted(expected_input_map.items()), sorted(actual_input_map.items()))) |
| |
| def _GetGraphDef(self, run_params, gdef_or_saved_model_dir): |
| if isinstance(gdef_or_saved_model_dir, str): |
| if run_params.is_v2: |
| root = load.load(gdef_or_saved_model_dir) |
| func = root.signatures[ |
| signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] |
| gdef = func.graph.as_graph_def() |
| # Manually unref the loaded saved model and force GC to destroy the TRT |
| # engine cache after load(). There is currently a reference cycle in 2.0 |
| # which prevents auto deletion of the resource. |
| # TODO(laigd): fix this. |
| del func |
| del root |
| gc.collect() |
| return gdef |
| return saved_model_utils.get_meta_graph_def( |
| gdef_or_saved_model_dir, tag_constants.SERVING).graph_def |
| assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef) |
| return gdef_or_saved_model_dir |
| |
| def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify, |
| graph_state): |
| expected_engines = self.ExpectedEnginesToBuild(run_params) |
| num_engines = 0 |
| functions = [f.signature.name for f in gdef_to_verify.library.function] |
| for node in gdef_to_verify.node: |
| if node.op == "TRTEngineOp": |
| logging.info("Found TRTEngineOp: " + node.name) |
| num_engines += 1 |
| segment_funcdef_name = node.attr["segment_func"].func.name |
| function_name = node.name + "_native_segment" |
| is_dynamic_engine = not node.attr["static_engine"].b |
| self.assertNotEmpty(segment_funcdef_name, node.name) |
| self.assertIn(function_name, functions) |
| if not IsQuantizationWithCalibration(run_params) and not is_dynamic_engine: |
| self.assertTrue(len(node.attr["serialized_segment"].s), node.name) |
| self.assertIn( |
| self._RemoveGraphSequenceNumber(node.name), expected_engines) |
| self.assertEqual( |
| self._ToBytes(run_params.precision_mode), |
| node.attr["precision_mode"].s, node.name) |
| |
| self.assertEqual(run_params.dynamic_engine, is_dynamic_engine, |
| node.name) |
| self.assertEqual(node.attr["use_calibration"].b, |
| run_params.use_calibration, node.name) |
| |
| has_calibration_data = len(node.attr["calibration_data"].s) |
| if (IsQuantizationWithCalibration(run_params) and |
| graph_state == GraphState.INFERENCE): |
| self.assertTrue(has_calibration_data, node.name) |
| else: |
| self.assertFalse(has_calibration_data, node.name) |
| if graph_state == GraphState.ORIGINAL: |
| self.assertEqual(0, num_engines) |
| else: |
| self.assertEqual(num_engines, len(expected_engines)) |
| if isinstance(expected_engines, dict): |
| self._VerifyConnections(expected_engines, original_gdef, gdef_to_verify) |
| # TODO(aaroey): consider verifying the corresponding TF function. |
| |
| def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify, |
| graph_state): |
| if graph_state == GraphState.ORIGINAL: |
| return |
| expected_engines = self.ExpectedEnginesToBuild(run_params) |
| all_op_names = [node.name for node in gdef_to_verify.node] |
| trt_op_names = [ |
| node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp" |
| ] |
| for func in gdef_to_verify.library.function: |
| if not re.search(r"TRTEngineOp_\d+_\d+_native_segment", |
| func.signature.name): |
| for node in func.node_def: |
| all_op_names.append(node.name) |
| if node.op == "TRTEngineOp": |
| trt_op_names.append(node.name) |
| |
| # Remove the function name prefix. |
| def _Canonicalize(names): |
| return set(self._ToString(name.split("/")[-1]) for name in names) |
| |
| # Remove the graph sequence number prefix from all the names. |
| def _RemoveGraphSequenceNumber(names): |
| return set(self._RemoveGraphSequenceNumber(name) for name in names) |
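| # For example, "some_func/TRTEngineOp_0_0" canonicalizes to |
| # "TRTEngineOp_0_0", and stripping the sequence number gives "TRTEngineOp_0". |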
| |
| all_op_names = _Canonicalize(all_op_names) |
| trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names)) |
| |
| if isinstance(expected_engines, dict): |
| # For simplicity we don't verify the connections inside the engine in |
| # 2.0, but we still make sure that the converted ops are gone from the |
| # graph. |
| unexpected_names = set(nest.flatten(expected_engines.values())) |
| self.assertEmpty( |
| [name for name in unexpected_names if name in all_op_names]) |
| expected_engines = set(expected_engines.keys()) |
| |
| self.assertEqual(set(expected_engines), trt_op_names) |
| |
| def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir, |
| gdef_or_saved_model_dir_to_verify, graph_state): |
| original_gdef = self._GetGraphDef(run_params, |
| original_gdef_or_saved_model_dir) |
| gdef_to_verify = self._GetGraphDef(run_params, |
| gdef_or_saved_model_dir_to_verify) |
| self._WriteGraph(run_params, gdef_to_verify, graph_state) |
| if run_params.is_v2: |
| self._VerifyGraphDefV2(run_params, original_gdef, gdef_to_verify, |
| graph_state) |
| else: |
| self._VerifyGraphDefV1(run_params, original_gdef, gdef_to_verify, |
| graph_state) |
| |
| def _GetSavedModelDir(self, run_params, graph_state): |
| test_tmpdir = os.getenv("TRT_TEST_TMPDIR") |
| if test_tmpdir: |
| saved_model_dir = os.path.join( |
| test_tmpdir, self.__class__.__name__ + "_" + run_params.test_name + |
| "_" + self._GetGraphStateLabel(graph_state)) |
| try: |
| # For TF 1.x we need to make sure the output directory doesn't exist |
| # before exporting the saved model. |
| shutil.rmtree(saved_model_dir) |
| except OSError as e: |
| if e.errno != errno.ENOENT: |
| raise |
| return saved_model_dir |
| return tempfile.mkdtemp(dir=self.get_temp_dir()) |
| |
| def _MakeSavedModelV1(self, run_params): |
| """Write the saved model as an input for testing.""" |
| params = self._GetParamsCached() |
| g = ops.Graph() |
| with g.as_default(): |
| inputs = [] |
| for spec in params.input_specs: |
| inp = array_ops.placeholder( |
| dtype=spec.dtype, shape=spec.shape, name=spec.name) |
| inputs.append(inp) |
| outputs = params.graph_fn(*inputs) |
| if not isinstance(outputs, list) and not isinstance(outputs, tuple): |
| outputs = [outputs] |
| |
| signature_def = signature_def_utils.build_signature_def( |
| inputs={inp.op.name: utils.build_tensor_info(inp) for inp in inputs}, |
| outputs={out.op.name: utils.build_tensor_info(out) for out in outputs}, |
| method_name=signature_constants.PREDICT_METHOD_NAME) |
| |
| saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL) |
| saved_model_builder = builder.SavedModelBuilder(saved_model_dir) |
| with self.session( |
| graph=g, config=self._GetConfigProto(run_params, |
| GraphState.ORIGINAL)) as sess: |
| saved_model_builder.add_meta_graph_and_variables( |
| sess, [tag_constants.SERVING], |
| signature_def_map={ |
| signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: |
| signature_def |
| }) |
| saved_model_builder.save() |
| return saved_model_dir |
| |
| def _MakeSavedModelV2(self, run_params): |
| params = self._GetParamsCached() |
| root = tracking.AutoTrackable() |
| root.run = def_function.function( |
| params.graph_fn, input_signature=params.input_specs) |
| saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL) |
| logging.info("Saving input SavedModel to %s", saved_model_dir) |
| save.save(root, saved_model_dir, |
| {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run}) |
| return saved_model_dir |
| |
| def _MakeSavedModel(self, run_params): |
| if run_params.is_v2: |
| return self._MakeSavedModelV2(run_params) |
| return self._MakeSavedModelV1(run_params) |
| |
| def RunTest(self, run_params): |
| should_run, reason_for_skipping = self.ShouldRunTest(run_params) |
| if not should_run: |
| return self.skipTest(reason_for_skipping) |
| |
| saved_model_dir = self._MakeSavedModel(run_params) |
| |
| np.random.seed(12345) # Fix the seed so the test is deterministic. |
| inputs_data = [] |
| input_specs = self._GetParamsCached().input_specs |
| for dim_list in self._GetParamsCached().input_dims: |
| assert len(input_specs) == len(dim_list) |
| current_input_data = [] |
| for spec, np_shape in zip(input_specs, dim_list): |
| np_dtype = spec.dtype.as_numpy_dtype() |
| # Multiply the input by some constant to avoid all zeros input for |
| # integer types. |
| scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0 |
| # TODO(laigd): add debug options. E.g. we can set the input data to be |
| # continuous natural numbers: |
| # seq = np.arange(np.prod(np_shape)) |
| # seq.resize(np_shape) |
| # current_input_data.append(scale * seq.astype(np_dtype)) |
| data = (scale * np.random.random_sample(np_shape)).astype(np_dtype) |
| if run_params.is_v2: |
| with ops.device("/GPU:0"): |
| data = ops.convert_to_tensor(data) |
| current_input_data.append(data) |
| inputs_data.append(current_input_data) |
| |
| # Verify original graph. |
| self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir, |
| GraphState.ORIGINAL) |
| |
| # Run original graph without trt to get reference result. |
| config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL) |
| logging.info("Running original graph w/o trt, config:\n%s", |
| str(config_no_trt)) |
| ref_result = self._RunGraph(run_params, saved_model_dir, inputs_data, |
| config_no_trt, GraphState.ORIGINAL) |
| |
| # Run calibration if necessary. |
| if IsQuantizationWithCalibration(run_params): |
| infer_saved_model_dir = self._GetCalibratedInferGraph( |
| run_params, saved_model_dir, inputs_data) |
| self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, |
| GraphState.INFERENCE) |
| elif not run_params.convert_online: |
| infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir) |
| self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, |
| GraphState.INFERENCE) |
| else: |
| infer_saved_model_dir = saved_model_dir |
| |
| # Run inference. |
| infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE) |
| logging.info("Running final inference graph, config:\n%s", |
| str(infer_config)) |
| result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data, |
| infer_config, GraphState.INFERENCE) |
| self.assertAllClose( |
| ref_result, |
| result, |
| atol=self.ExpectedAbsoluteTolerance(run_params), |
| rtol=self.ExpectedRelativeTolerance(run_params)) |
| |
| def testIdempotence(self): |
| # Test that applying the TensorRT optimizer or the offline conversion |
| # tools multiple times to the same graph results in the same graph. |
| # |
| # TODO(aaroey): implement this. |
| pass |
| |
| |
| def _GetTestConfigsV1(): |
| """Returns the config combinations to run the test.""" |
| convert_online, convert_offline = True, False |
| dynamic_engine, static_engine = True, False |
| use_calibration, no_calibration = True, False |
| |
| # Add all possible test cases and let the derived test class decide |
| # whether to run specific ones with ShouldRunTest(). |
| # |
| # Note: INT8 without calibration behaves like FP32/FP16. |
| opts = list( |
| itertools.product([FP32, FP16, INT8], [convert_online, convert_offline], |
| [dynamic_engine, static_engine], [no_calibration])) |
| # We always run calibration with the offline tool. |
| # TODO(aaroey): static calibration engine is not supported yet. |
| opts.append((INT8, convert_offline, dynamic_engine, use_calibration)) |
| return opts |
| |
| |
| def _GetTestConfigsV2(): |
| """Returns the config combinations to run the test.""" |
| convert_offline = False |
| # TODO(laigd): add support for static_engine. |
| dynamic_engine = True |
| # TODO(laigd): add support for calibration. |
| no_calibration = False |
| |
| # Add all possible test cases and let the derived test class decide |
| # whether to run specific ones with ShouldRunTest(). |
| # |
| # Note: |
| # - In TF2.0 the conversion always produces dynamic engines, and we don't |
| # test the static engine mode here. |
| # - For simplicity we don't test online conversion which requires setting the |
| # Grappler config in default eager context. |
| # - INT8 without calibration behaves like FP32/FP16. |
| opts = list( |
| itertools.product([FP32, FP16, INT8], [convert_offline], [dynamic_engine], |
| [no_calibration])) |
| # We always run calibration with the offline tool. |
| # TODO(aaroey): INT8+calibration is not supported yet in V2. |
| # opts.append((INT8, convert_offline, dynamic_engine, use_calibration)) |
| return opts |
| |
| |
| def _GetTest(run_params): |
| """Gets a single test method based on the parameters.""" |
| |
| def _Test(self): |
| logging.info( |
| "Running test %s with parameters: convert_online=%s, " |
| "precision_mode=%s, dynamic_engine=%s", run_params.test_name, |
| run_params.convert_online, run_params.precision_mode, |
| run_params.dynamic_engine) |
| self.RunTest(run_params) |
| |
| return _Test |
| |
| |
| def _AddTestsFor(test_class, is_v2): |
| """Adds test methods to TfTrtIntegrationTestBase for specific TF version.""" |
| opts = _GetTestConfigsV2() if is_v2 else _GetTestConfigsV1() |
| for (precision_mode, convert_online, dynamic_engine, use_calibration) in opts: |
| conversion = "OnlineConversion" if convert_online else "OfflineConversion" |
| engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine" |
| calibration_type = "UseCalibration" if use_calibration else "NoCalibration" |
| test_name = "%s_%s_%s_%s_%s" % ("testTfTrtV2" if is_v2 else "testTfTrt", |
| conversion, engine_type, precision_mode, |
| calibration_type) |
| run_params = RunParams( |
| convert_online=convert_online, |
| precision_mode=precision_mode, |
| dynamic_engine=dynamic_engine, |
| test_name=test_name, |
| use_calibration=use_calibration, |
| is_v2=is_v2) |
| if is_v2: |
| setattr(test_class, test_name, |
| test_util.run_v2_only(_GetTest(run_params))) |
| else: |
| setattr(test_class, test_name, |
| test_util.run_v1_only("", _GetTest(run_params))) |
| |
| |
| def _AddTests(test_class): |
| """Adds test methods to TfTrtIntegrationTestBase.""" |
| _AddTestsFor(test_class, is_v2=False) |
| _AddTestsFor(test_class, is_v2=True) |
| |
| |
| if is_tensorrt_enabled(): |
| os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False" |
| _AddTests(TfTrtIntegrationTestBase) |