[fx2trt] example for lowering model to trt with FX based tooling (#57298)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57298
Some of the code is borrowed from NVIDIA-AI-IOT/torch2trt https://github.com/NVIDIA-AI-IOT/torch2trt/tree/master/torch2trt.
Move the fx2trt code to fx/experimental/fx2trt.
Add an example in fx/experimental/fx2trt/example/fx2trt_example.py that shows how we lower resnet18 to TensorRT using FX.
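The core lowering flow, as a condensed sketch of the new example (illustrative only, not part of the diff; see fx2trt_example.py below for the full version that also uses the splitter):

    import torch
    import torch.fx
    import torchvision.models as models
    # Importing vanilla_converter registers the converters used by TRTInterpreter.
    import torch.fx.experimental.fx2trt.converter.vanilla_converter
    from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule

    rn18 = models.resnet18().eval().cuda()
    x = torch.randn(5, 3, 224, 224, device="cuda")

    traced = torch.fx.symbolic_trace(rn18)
    interp = TRTInterpreter(traced, InputTensorSpec.from_tensors((x,)))
    engine, input_names, output_names = interp.run(x)
    trt_mod = TRTModule(engine, input_names, output_names)
    out = trt_mod(x)  # runs ResNet18 on TensorRT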
TODO: Include license from NVIDIA-AI-IOT/torch2trt
Test Plan: CI
Reviewed By: jackm321
Differential Revision: D28102144
fbshipit-source-id: 1a7b03e45b8ab3fcc355d097d73afeec2efc3328
diff --git a/torch/fx/experimental/fx2trt/__init__.py b/torch/fx/experimental/fx2trt/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/__init__.py
diff --git a/torch/fx/experimental/fx2trt/converter/__init__.py b/torch/fx/experimental/fx2trt/converter/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/converter/__init__.py
diff --git a/torch/fx/experimental/fx2trt/converter/vanilla_converter.py b/torch/fx/experimental/fx2trt/converter/vanilla_converter.py
new file mode 100644
index 0000000..af11070
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/converter/vanilla_converter.py
@@ -0,0 +1,237 @@
+import operator
+import torch
+import tensorrt as trt
+import numpy as np
+
+from torch.fx.experimental.fx2trt.fx2trt import tensorrt_converter, torch_dtype_to_trt
+
+
+def process_attr(submod: torch.nn.Module, name: str, size: int):
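+    # If the attribute is a scalar (e.g. kernel_size=3), expand it into a tuple of
+    # length `size` (e.g. (3, 3)) to match what the TensorRT layer APIs expect.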
+ val = getattr(submod, name)
+ if not isinstance(val, tuple):
+ val = (val,) * size
+ return val
+
+
+def to_numpy(tensor):
+ if tensor.is_quantized:
+ tensor = tensor.dequantize()
+ return tensor.detach().contiguous().numpy()
+
+
+@tensorrt_converter(torch.nn.modules.conv.Conv2d)
+def torch_nn_modules_conv_Conv2d(network, submod, args, kwargs, name):
+ # args/kwargs should have already been normalized to kwargs
+ assert len(args) == 0
+ input_val = kwargs["input"]
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f"Conv2d received input {input_val} that is not part "
+ "of the TensorRT region!")
+
+ kernel_size = process_attr(submod, "kernel_size", 2)
+ stride = process_attr(submod, "stride", 2)
+ padding = process_attr(submod, "padding", 2)
+ dilation = process_attr(submod, "dilation", 2)
+
+ kernel = to_numpy(submod.weight)
+
+ bias = trt.Weights(torch_dtype_to_trt(submod.weight.dtype))
+ if submod.bias is not None:
+ bias = to_numpy(submod.bias)
+
+ layer = network.add_convolution(
+ input=input_val,
+ num_output_maps=submod.out_channels,
+ kernel_shape=kernel_size,
+ kernel=kernel,
+ bias=bias,
+ )
+ layer.name = name
+ layer.stride = stride
+ layer.padding = padding
+ layer.dilation = dilation
+
+ if submod.groups is not None:
+ layer.num_groups = submod.groups
+
+ return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.batchnorm.BatchNorm2d)
+def torch_nn_modules_batchnorm_BatchNorm2d(network, submod, args, kwargs, name):
+ # args/kwargs should have already been normalized to kwargs
+ assert len(args) == 0
+ input_val = kwargs["input"]
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f"BatchNorm2d received input {input_val} that is not part "
+ "of the TensorRT region!")
+
+ scale = to_numpy(submod.weight) / np.sqrt(
+ to_numpy(submod.running_var) + submod.eps
+ )
+ bias = (
+ to_numpy(submod.bias)
+ - to_numpy(submod.running_mean) * scale
+ )
+ power = np.ones_like(scale)
+
+ layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power)
+ layer.name = name
+ return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.functional.relu)
+@tensorrt_converter(torch.nn.modules.activation.ReLU)
+def torch_nn_modules_activation_ReLU(network, submod, args, kwargs, name):
+ # args/kwargs should have already been normalized to kwargs
+ assert len(args) == 0
+ input_val = kwargs["input"]
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f"ReLU received input {input_val} that is not part "
+ "of the TensorRT region!")
+
+ layer = network.add_activation(
+ input=input_val, type=trt.ActivationType.RELU)
+ layer.name = name
+ return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.pooling.MaxPool2d)
+def torch_nn_modules_pooling_MaxPool2d(network, submod, args, kwargs, name):
+ # args/kwargs should have already been normalized to kwargs
+ assert len(args) == 0
+ input_val = kwargs["input"]
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f"MaxPool2d received input {input_val} that is not part "
+ "of the TensorRT region!")
+
+ kernel_size = process_attr(submod, "kernel_size", 2)
+ stride = process_attr(submod, "stride", 2)
+ padding = process_attr(submod, "padding", 2)
+ ceil_mode = submod.ceil_mode
+
+ layer = network.add_pooling(
+ input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size)
+
+ layer.stride = stride
+ layer.padding = padding
+ layer.name = name
+
+ if ceil_mode:
+ layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
+
+ return layer.get_output(0)
+
+
+@tensorrt_converter(operator.add)
+@tensorrt_converter(torch.add)
+def torch_add(network, target, args, kwargs, name):
+ if len(kwargs) != 0:
+ raise RuntimeError("`out` parameter on torch.add not supported!")
+
+ assert len(args) == 2
+ if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args):
+ raise RuntimeError("add() received an input that is not part of the TensorRT region!")
+
+ lhs_val, rhs_val = args
+
+ # TODO: broadcast
+ # https://github.com/NVIDIA-AI-IOT/torch2trt/blob/44977a94cb087fe521421802e9df12a5ac3ceb3f/torch2trt/torch2trt.py#L168
+
+ layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.SUM)
+ layer.name = name
+ return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.pooling.AdaptiveAvgPool2d)
+def torch_nn_modules_pooling_AdaptiveAvgPool2d(network, submod, args, kwargs, name):
+ # args/kwargs should have already been normalized to kwargs
+ assert len(args) == 0
+ input_val = kwargs["input"]
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f"AdaptiveAvgPool2d received input {input_val} that is not part "
+ "of the TensorRT region!")
+
+ output_size = process_attr(submod, "output_size", 2)
+ stride = (input_val.shape[-2] // output_size[-2], input_val.shape[-1] // output_size[-1])
+ kernel_size = stride
+ layer = network.add_pooling(
+ input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
+ layer.stride = stride
+ layer.name = name
+ return layer.get_output(0)
+
+
+@tensorrt_converter(torch.flatten)
+def torch_flatten(network, target, args, kwargs, name):
+ # args/kwargs should have already been normalized to kwargs
+ assert len(args) == 0
+ input_val = kwargs["input"]
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f"Flatten received input {input_val} that is not part "
+ "of the TensorRT region!")
+
+    # TensorRT shapes don't include the batch dim.
+ start_dim = kwargs["start_dim"] - 1
+ end_dim = len(input_val.shape) if kwargs["end_dim"] == -1 else kwargs["end_dim"] - 1
+
+    assert start_dim >= 0, "Expected a non-negative start_dim; a negative value probably means the batch dim is being flattened."
+
+ new_shape = []
+ flatten_dim = 1
+ for i, dim in enumerate(input_val.shape):
+ if i < start_dim:
+ new_shape.append(dim)
+ elif i > end_dim:
+ new_shape.append(flatten_dim)
+ new_shape.append(dim)
+ else:
+ flatten_dim *= dim
+
+ if end_dim == len(input_val.shape):
+ new_shape.append(flatten_dim)
+
+ layer = network.add_shuffle(input_val)
+ layer.reshape_dims = tuple(new_shape)
+ layer.name = name
+ return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.linear.Linear)
+def torch_nn_modules_linear_Linear(network, submod, args, kwargs, name):
+ # args/kwargs should have already been normalized to kwargs
+ assert len(args) == 0
+ input_val = kwargs["input"]
+
+ if not isinstance(input_val, trt.tensorrt.ITensor):
+ raise RuntimeError(f"Linear received input {input_val} that is not part "
+ "of the TensorRT region!")
+
+ layer = network.add_shuffle(input_val)
+ layer.reshape_dims = tuple(input_val.shape) + (1, 1)
+ layer.name = f"{name}_pre_shuffle"
+
+ bias = trt.Weights(torch_dtype_to_trt(submod.weight.dtype))
+ if submod.bias is not None:
+ bias = to_numpy(submod.bias)
+
+ # add fully connected
+ layer = network.add_fully_connected(
+ input=layer.get_output(0),
+ num_outputs=submod.out_features,
+ kernel=to_numpy(submod.weight),
+ bias=bias
+ )
+ layer.name = f"{name}_linear"
+
+ # reshape back
+ layer = network.add_shuffle(layer.get_output(0))
+ layer.reshape_dims = tuple(input_val.shape[:-1]) + (submod.out_features,)
+ layer.name = f"{name}_post_shuffle"
+ return layer.get_output(0)
diff --git a/torch/fx/experimental/fx2trt/example/fx2trt_example.py b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
new file mode 100644
index 0000000..f9e306f
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
@@ -0,0 +1,264 @@
+from typing import Tuple, Dict, Callable, Any
+
+import torch
+import torch.fx
+import torchvision.models as models
+import torch.fx.experimental.fx2trt.converter.vanilla_converter
+import torch.fx.passes.splitter_base as splitter_base
+import torch.fx.passes.operator_support as op_support
+import torch.fx.passes.net_min_base as net_min_base
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+
+
+# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
+# model to TensorRT via FX with the existing FX based tooling. The general lowering
+# flow is:
+#
+# 1. Use the splitter to split the model if it contains ops that we don't want to lower
+#    to TensorRT, e.g. because they are not supported by TensorRT or because running
+#    them on another backend provides better performance.
+# 2. Lower the model (or the part of the model selected by the splitter) to TensorRT
+#    via fx2trt.
+#
+# For this example, we use ResNet18 as the example model and split out the linear layer
+# so that it does not run on TensorRT, purely to demonstrate how the splitter works. At
+# the end of this example we benchmark a model (named `split_mod`) that runs everything
+# except the linear layer on TensorRT (the linear layer runs on PyTorch CUDA) against a
+# model (named `rn18`) that runs fully on PyTorch CUDA.
+
+
+# Create ResNet18 `rn18` and inputs `x`
+rn18 = models.resnet18().eval().cuda()
+x = torch.randn(5, 3, 224, 224, device="cuda")
+
+# Trace the model with FX.
+traced_rn18 = torch.fx.symbolic_trace(rn18)
+
+
+def lower_mod_to_trt(mod: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
+ """
+    Helper function that, given a GraphModule `mod` and its `inputs`, builds a
+    TRTModule that runs the original `mod` on TensorRT.
+ """
+ interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+ engine, input_names, output_names = interp.run(*inputs)
+ return TRTModule(engine, input_names, output_names)
+
+
+class OpSupport(op_support.OperatorSupport):
+ """
+    This class is used by the splitter to determine which nodes are supported,
+    i.e. which nodes should go into the accelerator part (TensorRT).
+ """
+ def is_node_supported(
+ self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
+ ):
+ """
+        Here we don't want the linear layer to run on TensorRT. Thus, we return
+        False for the linear layer and True for all other ops.
+ """
+ target = op_support.get_node_target(submodules, node)
+
+ if target == "torch.nn.modules.linear.Linear":
+ return False
+
+ return True
+
+
+class TensorRTMinimizer(net_min_base._MinimizerBase):
+ """
+    We need to define a Minimizer class for TensorRT because it is used by the Splitter.
+ """
+ def __init__(
+ self,
+ module: torch.fx.GraphModule,
+ sample_input: Tuple[torch.Tensor],
+ compare_fn: Callable[[Any, Any], Tuple[float, bool]],
+ settings: net_min_base._MinimizerSettingBase = None,
+ ):
+ if settings is None:
+ settings = net_min_base._MinimizerSettingBase()
+
+ super().__init__(module, sample_input, compare_fn, settings)
+
+ def run_a(self, mod, inputs):
+ """
+        The output of this function serves as the reference.
+ """
+ mod.eval()
+ with torch.no_grad():
+ return mod(*inputs)
+
+ def run_b(self, mod, inputs):
+ """
+        Here we actually run `mod` on TensorRT and return the TensorRT result.
+ """
+ mod.eval()
+ try:
+ mod = lower_mod_to_trt(mod, inputs)
+ output = mod(*inputs)
+ except RuntimeError as e:
+ raise net_min_base.FxNetMinimizerRunFuncError(
+ f"Encounter an error when processing \n{mod.graph}\n {e}"
+ )
+ else:
+ return output
+
+
+# In the future this will be a global TensorRTSplitter, so we won't need to
+# create one per example.
+class TensorRTSplitter(splitter_base._SplitterBase):
+ """
+ Splitter for TensorRT.
+ """
+ def __init__(
+ self,
+ module: torch.fx.GraphModule,
+ sample_input: Tuple[torch.Tensor],
+ operator_support: op_support.OperatorSupport = None,
+ settings: splitter_base._SplitterSettingBase = None
+ ):
+ if not operator_support:
+ operator_support = op_support.OperatorSupport()
+
+ if not settings:
+ settings = splitter_base._SplitterSettingBase()
+ settings.allow_non_tensor = True
+ settings.skip_fusion = True
+
+ super().__init__(module, sample_input, operator_support, settings)
+
+ def _lower_model_to_backend(self, mod, inputs):
+ """
+ Lower a GraphModule `mod` to TensorRT with `inputs`.
+ """
+ mod = lower_mod_to_trt(mod, inputs)
+ return mod
+
+ def _find_culprit(self, mod, inputs):
+ """
+        This function supports the preview functionality in the Splitter. When
+        previewing the split result, if something goes wrong while lowering the
+        model to TensorRT or while running the TensorRT model, this function is
+        called to find the culprit nodes responsible for the error.
+ """
+ # Since we don't care about accuracy here, we pass in a dummy compare function.
+ minimizer = TensorRTMinimizer(mod, inputs, lambda a, b: (1, True))
+ minimizer.settings.traverse_method = "sequential"
+ minimizer.settings.find_all = True
+ culprits = minimizer.minimize()
+
+ if len(culprits) == 0:
+ reports = "Unable to find a culprit!\n"
+ else:
+ reports = "Found some problematic nodes:\n"
+ for node in culprits:
+ reports += f"{node.format_node()}\n"
+
+ return reports
+
+# Create a splitter which takes in traced ResNet18.
+splitter = TensorRTSplitter(traced_rn18, (x,), OpSupport())
+
+# node_support_preview() shows the details of which nodes are supported, based
+# on the OpSupport instance we created.
+#
+# In the output, we have supported node types
+# and unsupported node types. Nodes in the model with supported types will be
+# split into accelerator submodules while nodes with unsupported types will be
+# split into cpu submodules.
+splitter.node_support_preview()
+"""
+output:
+
+Supported node types in the model:
+torch.nn.modules.conv.Conv2d: ((torch.float32,), {})
+torch.nn.modules.batchnorm.BatchNorm2d: ((torch.float32,), {})
+torch.nn.modules.activation.ReLU: ((torch.float32,), {})
+torch.nn.modules.pooling.MaxPool2d: ((torch.float32,), {})
+_operator.add: ((torch.float32, torch.float32), {})
+torch.nn.modules.pooling.AdaptiveAvgPool2d: ((torch.float32,), {})
+torch.flatten: ((torch.float32,), {})
+
+Unsupported node types in the model:
+torch.nn.modules.linear.Linear: ((torch.float32,), {})
+"""
+
+# split_preview() shows what the model will look like after the split.
+# For every accelerator submodule in the split model, it runs a check by
+# lowering and running the submodule. If any error is caught during this
+# check, it uses the minimizer to find which nodes are causing the trouble.
+#
+# Notice that after the split, the model will have some submodules called either
+# `_run_on_acc_{}` or `_run_on_cpu_{}`. All the supported nodes end up in
+# `_run_on_acc_{}` modules and all other nodes in `_run_on_cpu_{}` modules.
+#
+# In the output, we can see that it estimates the max qps based on PCIe
+# bandwidth. This is something to consider when lowering to accelerator chips,
+# because data has to flow between the CPU and the accelerator; it matters
+# less in the GPU case.
+splitter.split_preview()
+"""
+output:
+
+Before removing small acc subgraphs, total 2 subgraphs are created: 1 acc subgraphs and 1 cpu subgraphs.
+After removing small acc subgraphs, total 2 subgraphs are created: 1 acc subgraphs and 1 cpu subgraphs.
+_run_on_acc_0: 68 node(s)
+_run_on_cpu_1: 1 node(s)
+
+Processing acc submodule _run_on_acc_0
+Checking inputs...
+Checking outputs...
+Total input size in bytes is 3010560, total output size in bytes is 10240, theoretical max qps (bounds by PCIe bandwidth)
+for this submodule is 35665.85034013606.
+Lowering and running succeed!
+
+Theoretical max qps (bounds by PCIe bandwidth) for this model is 35665.85034013606, bottleneck is submodule _run_on_acc_0.
+"""
+
+# After the split we have two submodules: `_run_on_acc_0` and `_run_on_cpu_1`.
+# `_run_on_cpu_1` contains only one op, the linear layer, while all other ops
+# are in `_run_on_acc_0`.
+split_mod = splitter()
+print(split_mod.graph)
+"""
+output:
+
+graph():
+ %x : torch.Tensor [#users=1] = placeholder[target=x]
+ %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
+ %_run_on_cpu_1 : [#users=1] = call_module[target=_run_on_cpu_1](args = (%_run_on_acc_0,), kwargs = {})
+ return _run_on_cpu_1
+"""
+
+# We want to lower _run_on_acc_0 to TensorRT.
+split_mod._run_on_acc_0 = lower_mod_to_trt(split_mod._run_on_acc_0, (x,)) # type: ignore[arg-type]
+
+# Check that the results match those of the original model.
+rn18 = rn18.cuda()
+torch.testing.assert_allclose(split_mod(x), rn18(x))
+
+import time
+NITER = 100
+
+s = time.time()
+for _ in range(NITER):
+ split_mod(x)
+ torch.cuda.synchronize()
+print('trt time (ms/iter)', (time.time() - s) / NITER * 1000)
+"""
+output:
+
+trt time (ms/iter) 1.978142261505127
+"""
+
+s = time.time()
+for _ in range(NITER):
+ rn18(x)
+ torch.cuda.synchronize()
+print('stock PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000)
+"""
+output:
+
+stock PyTorch time (ms/iter) 3.8208484649658203
+"""
diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py
new file mode 100644
index 0000000..d8ce5ab
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/fx2trt.py
@@ -0,0 +1,239 @@
+from typing import List, NamedTuple, Iterable, Any, Optional
+
+import torch
+import torch.fx
+import tensorrt as trt
+import copy
+from torch.fx.experimental.normalize import NormalizeArgs
+
+
+# Borrowed from torch2trt
+def torch_dtype_to_trt(dtype):
+ if trt.__version__ >= '7.0' and dtype == torch.bool:
+ return trt.bool
+ elif dtype == torch.int8:
+ return trt.int8
+ elif dtype == torch.int32:
+ return trt.int32
+ elif dtype == torch.float16:
+ return trt.float16
+ elif dtype == torch.float32:
+ return trt.float32
+ else:
+ raise TypeError("%s is not supported by tensorrt" % dtype)
+
+
+def torch_dtype_from_trt(dtype):
+ if dtype == trt.int8:
+ return torch.int8
+ elif trt.__version__ >= '7.0' and dtype == trt.bool:
+ return torch.bool
+ elif dtype == trt.int32:
+ return torch.int32
+ elif dtype == trt.float16:
+ return torch.float16
+ elif dtype == trt.float32:
+ return torch.float32
+ else:
+ raise TypeError("%s is not supported by torch" % dtype)
+
+def torch_device_to_trt(device):
+ if device.type == torch.device("cuda").type:
+ return trt.TensorLocation.DEVICE
+ elif device.type == torch.device("cpu").type:
+ return trt.TensorLocation.HOST
+ else:
+ return TypeError("%s is not supported by tensorrt" % device)
+
+
+def torch_device_from_trt(device):
+ if device == trt.TensorLocation.DEVICE:
+ return torch.device("cuda")
+ elif device == trt.TensorLocation.HOST:
+ return torch.device("cpu")
+ else:
+        raise TypeError("%s is not supported by torch" % device)
+
+
+class TRTModule(torch.nn.Module):
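+    """
+    Wraps a built TensorRT engine as a torch.nn.Module. forward() binds the
+    input tensors, allocates output tensors and executes the engine
+    asynchronously on the current CUDA stream.
+    """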
+ def __init__(self, engine=None, input_names=None, output_names=None):
+ super(TRTModule, self).__init__()
+ self._register_state_dict_hook(TRTModule._on_state_dict)
+ self.engine = engine
+ if self.engine is not None:
+ self.context = self.engine.create_execution_context()
+ self.input_names = input_names
+ self.output_names = output_names
+
+ def _on_state_dict(self, state_dict, prefix, local_metadata):
+ state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
+ state_dict[prefix + "input_names"] = self.input_names
+ state_dict[prefix + "output_names"] = self.output_names
+
+ def _load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ engine_bytes = state_dict[prefix + "engine"]
+
+ with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+ self.engine = runtime.deserialize_cuda_engine(engine_bytes)
+ self.context = self.engine.create_execution_context()
+
+ self.input_names = state_dict[prefix + "input_names"]
+ self.output_names = state_dict[prefix + "output_names"]
+
+ def forward(self, *inputs):
+ batch_size = inputs[0].shape[0]
+ bindings: List[Any] = [None] * (len(self.input_names) + len(self.output_names))
+
+ # create output tensors
+ outputs: List[torch.Tensor] = []
+ for i, output_name in enumerate(self.output_names):
+ idx: int = self.engine.get_binding_index(output_name)
+ dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
+ shape = (batch_size,) + tuple(self.engine.get_binding_shape(idx))
+ device = torch_device_from_trt(self.engine.get_location(idx))
+ output = torch.empty(size=shape, dtype=dtype, device=device)
+ outputs.append(output)
+ bindings[idx] = output.data_ptr()
+
+ for i, input_name in enumerate(self.input_names):
+ idx = self.engine.get_binding_index(input_name)
+ bindings[idx] = inputs[i].contiguous().data_ptr()
+
+ self.context.execute_async(
+ batch_size, bindings, torch.cuda.current_stream().cuda_stream
+ )
+
+ if len(outputs) == 1:
+ return outputs[0]
+
+ return tuple(outputs)
+
+ def enable_profiling(self):
+ if not self.context.profiler:
+ self.context.profiler = trt.Profiler()
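+
+    # Note: because _on_state_dict/_load_from_state_dict above (de)serialize the
+    # engine, a lowered TRTModule can be checkpointed roughly like this
+    # (illustrative sketch only; the file name is made up):
+    #
+    #     torch.save(trt_mod.state_dict(), "trt_mod.pt")
+    #     reloaded = TRTModule()
+    #     reloaded.load_state_dict(torch.load("trt_mod.pt"))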
+
+
+CONVERTERS = {}
+
+
+def tensorrt_converter(key):
+ def register_converter(converter):
+ CONVERTERS[key] = converter
+ return converter
+ return register_converter
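+
+
+# Example of how a new converter would be registered via the decorator above
+# (illustrative sketch only; torch.relu is not actually registered in this diff,
+# the real converters live in converter/vanilla_converter.py):
+#
+#     @tensorrt_converter(torch.relu)
+#     def torch_relu(network, target, args, kwargs, name):
+#         layer = network.add_activation(input=kwargs["input"],
+#                                        type=trt.ActivationType.RELU)
+#         layer.name = name
+#         return layer.get_output(0)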
+
+
+class InputTensorSpec(NamedTuple):
+ shape : torch.Size
+ dtype : torch.dtype
+
+ @classmethod
+ def from_tensor(cls, tensor: torch.Tensor):
+ return cls(tensor.shape, tensor.dtype)
+
+ @classmethod
+ def from_tensors(cls, tensors: Iterable[torch.Tensor]):
+ return [cls.from_tensor(t) for t in tensors]
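+
+    # e.g. InputTensorSpec.from_tensors((torch.randn(5, 3, 224, 224),)) produces
+    # [InputTensorSpec(shape=torch.Size([5, 3, 224, 224]), dtype=torch.float32)].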
+
+
+class TRTInterpreter(torch.fx.Interpreter):
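+    """
+    Interprets an FX GraphModule node by node and, for each node, dispatches to
+    the converter registered in CONVERTERS to build up an equivalent TensorRT
+    network. run() then builds the CUDA engine and returns it together with the
+    input and output binding names.
+    """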
+ def __init__(self, module : torch.fx.GraphModule, input_shapes : List[InputTensorSpec], logger_level=trt.Logger.WARNING):
+ # Preprocess the model
+ module = copy.copy(module)
+ module = module.cpu()
+ module = NormalizeArgs(module).transform()
+ super().__init__(module)
+
+ self.logger = trt.Logger(logger_level)
+ self.builder = trt.Builder(self.logger)
+ self.network = self.builder.create_network()
+
+ self.input_shape_itr = iter(input_shapes)
+
+ self._cur_node_name: Optional[str] = None
+
+ self._input_names: List[str] = []
+ self._output_names: List[str] = []
+
+ def run(
+ self,
+ *args,
+ max_batch_size=10,
+ max_workspace_size=1 << 25,
+ fp16_mode=False,
+ int8_mode=False,
+ strict_type_constraints=False
+ ):
+ super().run(*args)
+
+ self.builder.max_batch_size = max_batch_size
+ self.builder.max_workspace_size = max_workspace_size
+ self.builder.strict_type_constraints = strict_type_constraints
+ self.builder.fp16_mode = fp16_mode
+ self.builder.int8_mode = int8_mode
+
+ return self.builder.build_cuda_engine(self.network), self._input_names, self._output_names
+
+ def run_node(self, n):
+ self._cur_node_name = str(n)
+
+ try:
+ return super().run_node(n)
+ finally:
+            self._cur_node_name = None
+
+ def placeholder(self, target, args, kwargs):
+ shape, dtype = next(self.input_shape_itr)
+ self._input_names.append(target)
+ return self.network.add_input(name=target, shape=tuple(shape[1:]), dtype=torch_dtype_to_trt(dtype))
+
+ def call_module(self, target, args, kwargs):
+ assert isinstance(target, str)
+ submod = self.fetch_attr(target)
+
+ converter = CONVERTERS.get(type(submod))
+
+ if not converter:
+ raise RuntimeError(f'Conversion of module of type {type(submod)} not currently supported!')
+
+ return converter(self.network, submod, args, kwargs, self._cur_node_name)
+
+ def call_function(self, target, args, kwargs):
+ converter = CONVERTERS.get(target)
+
+ if not converter:
+ raise RuntimeError(f'Conversion of function {torch.typename(target)} not currently supported!')
+
+ return converter(self.network, target, args, kwargs, self._cur_node_name)
+
+ def call_method(self, target, args, kwargs):
+ assert isinstance(target, str)
+
+ converter = CONVERTERS.get(target)
+
+ if not converter:
+ raise RuntimeError(f'Conversion of method {target} not currently supported!')
+
+ return converter(self.network, target, args, kwargs, self._cur_node_name)
+
+ def output(self, target, args, kwargs):
+ assert len(args) == 1
+ outputs = args[0] if isinstance(args[0], tuple) else (args[0],)
+ if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
+ raise RuntimeError('TensorRT requires all outputs to be Tensor!')
+
+ for i, output in enumerate(outputs):
+ # TODO: set location and dtype?
+ name = f'output{i}'
+ output.name = name
+ self.network.mark_output(output)
+ self._output_names.append(name)