[fx2trt] example for lowering model to trt with FX based tooling (#57298)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57298

Some of the code is borrowed from NVIDIA-AI-IOT/torch2trt (https://github.com/NVIDIA-AI-IOT/torch2trt/tree/master/torch2trt).

Move the fx2trt code to fx/experimental/fx2trt.

Add an example in fx/experimental/fx2trt/example/fx2trt_example.py that shows how we lower resnet18 to TensorRT using FX.
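
The core lowering step in the example is roughly the following (a minimal sketch of what the example's lower_mod_to_trt helper does; see fx2trt_example.py for the full flow):

    # Build a TensorRT engine from an FX GraphModule and wrap it in a TRTModule.
    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
    engine, input_names, output_names = interp.run(*inputs)
    trt_mod = TRTModule(engine, input_names, output_names)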

TODO: Include license from NVIDIA-AI-IOT/torch2trt

Test Plan: CI

Reviewed By: jackm321

Differential Revision: D28102144

fbshipit-source-id: 1a7b03e45b8ab3fcc355d097d73afeec2efc3328
diff --git a/torch/fx/experimental/fx2trt/__init__.py b/torch/fx/experimental/fx2trt/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/__init__.py
diff --git a/torch/fx/experimental/fx2trt/converter/__init__.py b/torch/fx/experimental/fx2trt/converter/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/converter/__init__.py
diff --git a/torch/fx/experimental/fx2trt/converter/vanilla_converter.py b/torch/fx/experimental/fx2trt/converter/vanilla_converter.py
new file mode 100644
index 0000000..af11070
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/converter/vanilla_converter.py
@@ -0,0 +1,237 @@
+import operator
+import torch
+import tensorrt as trt
+import numpy as np
+
+from torch.fx.experimental.fx2trt.fx2trt import tensorrt_converter, torch_dtype_to_trt
+
+
+def process_attr(submod: torch.nn.Module, name: str, size: int):
+    val = getattr(submod, name)
+    if not isinstance(val, tuple):
+        val = (val,) * size
+    return val
+
+
+def to_numpy(tensor):
+    if tensor.is_quantized:
+        tensor = tensor.dequantize()
+    return tensor.detach().contiguous().numpy()
+
+
+@tensorrt_converter(torch.nn.modules.conv.Conv2d)
+def torch_nn_modules_conv_Conv2d(network, submod, args, kwargs, name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"Conv2d received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    kernel_size = process_attr(submod, "kernel_size", 2)
+    stride = process_attr(submod, "stride", 2)
+    padding = process_attr(submod, "padding", 2)
+    dilation = process_attr(submod, "dilation", 2)
+
+    kernel = to_numpy(submod.weight)
+
+    bias = trt.Weights(torch_dtype_to_trt(submod.weight.dtype))
+    if submod.bias is not None:
+        bias = to_numpy(submod.bias)
+
+    layer = network.add_convolution(
+        input=input_val,
+        num_output_maps=submod.out_channels,
+        kernel_shape=kernel_size,
+        kernel=kernel,
+        bias=bias,
+    )
+    layer.name = name
+    layer.stride = stride
+    layer.padding = padding
+    layer.dilation = dilation
+
+    if submod.groups is not None:
+        layer.num_groups = submod.groups
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.batchnorm.BatchNorm2d)
+def torch_nn_modules_batchnorm_BatchNorm2d(network, submod, args, kwargs, name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"BatchNorm2d received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
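+    # Fold batch norm into a single TensorRT scale layer:
+    #   y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
+    #     = x * scale + folded_bias, with power kept at 1.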
+    scale = to_numpy(submod.weight) / np.sqrt(
+        to_numpy(submod.running_var) + submod.eps
+    )
+    bias = (
+        to_numpy(submod.bias)
+        - to_numpy(submod.running_mean) * scale
+    )
+    power = np.ones_like(scale)
+
+    layer = network.add_scale(input_val, trt.ScaleMode.CHANNEL, bias, scale, power)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.functional.relu)
+@tensorrt_converter(torch.nn.modules.activation.ReLU)
+def torch_nn_modules_activation_ReLU(network, submod, args, kwargs, name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"ReLU received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    layer = network.add_activation(
+        input=input_val, type=trt.ActivationType.RELU)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.pooling.MaxPool2d)
+def torch_nn_modules_pooling_MaxPool2d(network, submod, args, kwargs, name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"MaxPool2d received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    kernel_size = process_attr(submod, "kernel_size", 2)
+    stride = process_attr(submod, "stride", 2)
+    padding = process_attr(submod, "padding", 2)
+    ceil_mode = submod.ceil_mode
+
+    layer = network.add_pooling(
+        input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size)
+
+    layer.stride = stride
+    layer.padding = padding
+    layer.name = name
+
+    if ceil_mode:
+        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
+
+    return layer.get_output(0)
+
+
+@tensorrt_converter(operator.add)
+@tensorrt_converter(torch.add)
+def torch_add(network, target, args, kwargs, name):
+    if len(kwargs) != 0:
+        raise RuntimeError("`out` parameter on torch.add not supported!")
+
+    assert len(args) == 2
+    if not all(isinstance(arg, trt.tensorrt.ITensor) for arg in args):
+        raise RuntimeError("add() received an input that is not part of the TensorRT region!")
+
+    lhs_val, rhs_val = args
+
+    # TODO: broadcast
+    # https://github.com/NVIDIA-AI-IOT/torch2trt/blob/44977a94cb087fe521421802e9df12a5ac3ceb3f/torch2trt/torch2trt.py#L168
+
+    layer = network.add_elementwise(lhs_val, rhs_val, trt.ElementWiseOperation.SUM)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.pooling.AdaptiveAvgPool2d)
+def torch_nn_modules_pooling_AdaptiveAvgPool2d(network, submod, args, kwargs, name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"AdaptiveAvgPool2d received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
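+    # Map adaptive average pooling to a fixed pooling layer: kernel size and stride are
+    # input_size // output_size, which matches adaptive pooling exactly when the input
+    # size is divisible by the output size.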
+    output_size = process_attr(submod, "output_size", 2)
+    stride = (input_val.shape[-2] // output_size[-2], input_val.shape[-1] // output_size[-1])
+    kernel_size = stride
+    layer = network.add_pooling(
+        input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
+    layer.stride = stride
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(torch.flatten)
+def torch_flatten(network, target, args, kwargs, name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"Flatten received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    # The TRT shape doesn't include the batch dim, so shift the dims by one.
+    start_dim = kwargs["start_dim"] - 1
+    end_dim = len(input_val.shape) if kwargs["end_dim"] == -1 else kwargs["end_dim"] - 1
+
+    assert start_dim >= 0, "Expected a non-negative start_dim; this is probably due to flattening the batch dim."
+
+    new_shape = []
+    flatten_dim = 1
+    for i, dim in enumerate(input_val.shape):
+        if i < start_dim:
+            new_shape.append(dim)
+        elif i > end_dim:
+            new_shape.append(dim)
+        else:
+            flatten_dim *= dim
+            # Append the flattened dim once the last dim in [start_dim, end_dim] is consumed.
+            if i == end_dim:
+                new_shape.append(flatten_dim)
+
+    # end_dim == -1 maps to len(input_val.shape), so the flattened dim hasn't been appended yet.
+    if end_dim == len(input_val.shape):
+        new_shape.append(flatten_dim)
+
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(new_shape)
+    layer.name = name
+    return layer.get_output(0)
+
+
+@tensorrt_converter(torch.nn.modules.linear.Linear)
+def torch_nn_modules_linear_Linear(network, submod, args, kwargs, name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"Linear received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
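+    # TensorRT's fully connected layer operates on the last three dims of its input,
+    # so reshape to (..., in_features, 1, 1), apply the FC layer, then reshape back.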
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(input_val.shape) + (1, 1)
+    layer.name = f"{name}_pre_shuffle"
+
+    bias = trt.Weights(torch_dtype_to_trt(submod.weight.dtype))
+    if submod.bias is not None:
+        bias = to_numpy(submod.bias)
+
+    # add fully connected
+    layer = network.add_fully_connected(
+        input=layer.get_output(0),
+        num_outputs=submod.out_features,
+        kernel=to_numpy(submod.weight),
+        bias=bias
+    )
+    layer.name = f"{name}_linear"
+
+    # reshape back
+    layer = network.add_shuffle(layer.get_output(0))
+    layer.reshape_dims = tuple(input_val.shape[:-1]) + (submod.out_features,)
+    layer.name = f"{name}_post_shuffle"
+    return layer.get_output(0)
diff --git a/torch/fx/experimental/fx2trt/example/fx2trt_example.py b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
new file mode 100644
index 0000000..f9e306f
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
@@ -0,0 +1,264 @@
+from typing import Tuple, Dict, Callable, Any
+
+import torch
+import torch.fx
+import torchvision.models as models
+import torch.fx.experimental.fx2trt.converter.vanilla_converter
+import torch.fx.passes.splitter_base as splitter_base
+import torch.fx.passes.operator_support as op_support
+import torch.fx.passes.net_min_base as net_min_base
+from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule
+
+
+# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
+# model to TensorRT via FX with the existing FX based tooling. The general lowering flow
+# is:
+#
+# 1. Use the splitter to split the model if there are ops in the model that we don't want
+#    to lower to TensorRT, e.g. because they are not supported by TensorRT or because
+#    running them on another backend gives better performance.
+# 2. Lower the model (or the parts of it selected by the splitter) to TensorRT via fx2trt.
+#
+# In this example, we use ResNet18 as the model and split out the linear layer so it does
+# not run on TensorRT, just to demonstrate how the splitter works. At the end of the
+# example we benchmark a model (named `split_mod`) with all ops running on TensorRT
+# except the linear layer, which runs on PyTorch CUDA, against a model (named `rn18`)
+# running fully on PyTorch CUDA.
+
+
+# Create ResNet18 `rn18` and inputs `x`
+rn18 = models.resnet18().eval().cuda()
+x = torch.randn(5, 3, 224, 224, device="cuda")
+
+# Trace the model with FX.
+traced_rn18 = torch.fx.symbolic_trace(rn18)
+
+
+def lower_mod_to_trt(mod: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
+    """
+    Helper function that, given a GraphModule `mod` and its `inputs`, builds a
+    TRTModule that runs the original `mod` on TensorRT.
+    """
+    interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
+    engine, input_names, output_names = interp.run(*inputs)
+    return TRTModule(engine, input_names, output_names)
+
+
+class OpSupport(op_support.OperatorSupport):
+    """
+    This class is used by the splitter to determine which nodes are supported, i.e.
+    which nodes should go to the accelerator (TensorRT) part.
+    """
+    def is_node_supported(
+        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
+    ):
+        """
+        Here we don't want the linear layer to run on TensorRT. Thus, we return
+        False for the linear layer and True for all other ops.
+        """
+        target = op_support.get_node_target(submodules, node)
+
+        if target == "torch.nn.modules.linear.Linear":
+            return False
+
+        return True
+
+
+class TensorRTMinimizer(net_min_base._MinimizerBase):
+    """
+    We need to define a Minimizer class for TensorRT because it's used by the Splitter.
+    """
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tuple[torch.Tensor],
+        compare_fn: Callable[[Any, Any], Tuple[float, bool]],
+        settings: net_min_base._MinimizerSettingBase = None,
+    ):
+        if settings is None:
+            settings = net_min_base._MinimizerSettingBase()
+
+        super().__init__(module, sample_input, compare_fn, settings)
+
+    def run_a(self, mod, inputs):
+        """
+        Run `mod` on PyTorch; its output serves as the reference.
+        """
+        mod.eval()
+        with torch.no_grad():
+            return mod(*inputs)
+
+    def run_b(self, mod, inputs):
+        """
+        Here we lower `mod` to TensorRT, run it, and return the TensorRT result.
+        """
+        mod.eval()
+        try:
+            mod = lower_mod_to_trt(mod, inputs)
+            output = mod(*inputs)
+        except RuntimeError as e:
+            raise net_min_base.FxNetMinimizerRunFuncError(
+                f"Encounter an error when processing \n{mod.graph}\n {e}"
+            )
+        else:
+            return output
+
+
+# In the future this will be a global TensorRTSplitter so we won't need to create
+# one per example.
+class TensorRTSplitter(splitter_base._SplitterBase):
+    """
+    Splitter for TensorRT.
+    """
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tuple[torch.Tensor],
+        operator_support: op_support.OperatorSupport = None,
+        settings: splitter_base._SplitterSettingBase = None
+    ):
+        if not operator_support:
+            operator_support = op_support.OperatorSupport()
+
+        if not settings:
+            settings = splitter_base._SplitterSettingBase()
+            settings.allow_non_tensor = True
+            settings.skip_fusion = True
+
+        super().__init__(module, sample_input, operator_support, settings)
+
+    def _lower_model_to_backend(self, mod, inputs):
+        """
+        Lower a GraphModule `mod` to TensorRT with `inputs`.
+        """
+        mod = lower_mod_to_trt(mod, inputs)
+        return mod
+
+    def _find_culprit(self, mod, inputs):
+        """
+        This function serves the preview functionality in the Splitter. When previewing
+        the split result, if something goes wrong while lowering the model to TensorRT
+        or while running the TensorRT model, this function is called to find the culprit
+        nodes responsible for the error.
+        """
+        # Since we don't care about accuracy here, we pass in a dummy compare function.
+        minimizer = TensorRTMinimizer(mod, inputs, lambda a, b: (1, True))
+        minimizer.settings.traverse_method = "sequential"
+        minimizer.settings.find_all = True
+        culprits = minimizer.minimize()
+
+        if len(culprits) == 0:
+            reports = "Unable to find a culprit!\n"
+        else:
+            reports = "Found some problematic nodes:\n"
+            for node in culprits:
+                reports += f"{node.format_node()}\n"
+
+        return reports
+
+# Create a splitter which takes in traced ResNet18.
+splitter = TensorRTSplitter(traced_rn18, (x,), OpSupport())
+
+# node_support_preview() shows the details of the node support information based
+# on the OpSupport we created.
+#
+# The output lists supported node types and unsupported node types. Nodes in the
+# model with supported types will be split into accelerator submodules while nodes
+# with unsupported types will be split into cpu submodules.
+splitter.node_support_preview()
+"""
+output:
+
+Supported node types in the model:
+torch.nn.modules.conv.Conv2d: ((torch.float32,), {})
+torch.nn.modules.batchnorm.BatchNorm2d: ((torch.float32,), {})
+torch.nn.modules.activation.ReLU: ((torch.float32,), {})
+torch.nn.modules.pooling.MaxPool2d: ((torch.float32,), {})
+_operator.add: ((torch.float32, torch.float32), {})
+torch.nn.modules.pooling.AdaptiveAvgPool2d: ((torch.float32,), {})
+torch.flatten: ((torch.float32,), {})
+
+Unsupported node types in the model:
+torch.nn.modules.linear.Linear: ((torch.float32,), {})
+"""
+
+# split_preview() shows how the model will look after the split. For every
+# accelerator submodule in the split model, it runs a check by lowering and
+# running the submodule. If any error is caught during this check, it uses the
+# minimizer to find which nodes are causing the trouble.
+#
+# Notice that after the split, the model will have submodules called either
+# `_run_on_acc_{}` or `_run_on_cpu_{}`. All supported nodes end up in
+# `_run_on_acc_{}` modules and all other nodes in `_run_on_cpu_{}` modules.
+#
+# In the output, we can also see an estimate of the max qps based on PCIe
+# bandwidth. This is something to consider when lowering to accelerator chips,
+# because the data flows between cpu and accelerator; it might not matter as
+# much in the GPU case.
+splitter.split_preview()
+"""
+output:
+
+Before removing small acc subgraphs, total 2 subgraphs are created: 1 acc subgraphs and 1 cpu subgraphs.
+After removing small acc subgraphs, total 2 subgraphs are created: 1 acc subgraphs and 1 cpu subgraphs.
+_run_on_acc_0: 68 node(s)
+_run_on_cpu_1: 1 node(s)
+
+Processing acc submodule _run_on_acc_0
+Checking inputs...
+Checking outputs...
+Total input size in bytes is 3010560, total output size in bytes is 10240, theoretical max qps (bounds by PCIe bandwidth)
+for this submodule is 35665.85034013606.
+Lowering and running succeed!
+
+Theoretical max qps (bounds by PCIe bandwidth) for this model is 35665.85034013606, bottleneck is submodule _run_on_acc_0.
+"""
+
+# After the split we have two submodules: `_run_on_acc_0` and `_run_on_cpu_1`.
+# The only op in `_run_on_cpu_1` is the linear layer; all other ops are in
+# `_run_on_acc_0`.
+split_mod = splitter()
+print(split_mod.graph)
+"""
+output:
+
+graph():
+    %x : torch.Tensor [#users=1] = placeholder[target=x]
+    %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
+    %_run_on_cpu_1 : [#users=1] = call_module[target=_run_on_cpu_1](args = (%_run_on_acc_0,), kwargs = {})
+    return _run_on_cpu_1
+"""
+
+# We want to lower _run_on_acc_0 to TensorRT.
+split_mod._run_on_acc_0 = lower_mod_to_trt(split_mod._run_on_acc_0, (x,))  # type: ignore[arg-type]
+
+# Check that the results match those of the original model.
+rn18 = rn18.cuda()
+torch.testing.assert_allclose(split_mod(x), rn18(x))
+
+import time
+NITER = 100
+
+s = time.time()
+for _ in range(NITER):
+    split_mod(x)
+    torch.cuda.synchronize()
+print('trt time (ms/iter)', (time.time() - s) / NITER * 1000)
+"""
+output:
+
+trt time (ms/iter) 1.978142261505127
+"""
+
+s = time.time()
+for _ in range(NITER):
+    rn18(x)
+    torch.cuda.synchronize()
+print('stock PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000)
+"""
+output:
+
+stock PyTorch time (ms/iter) 3.8208484649658203
+"""
diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py
new file mode 100644
index 0000000..d8ce5ab
--- /dev/null
+++ b/torch/fx/experimental/fx2trt/fx2trt.py
@@ -0,0 +1,239 @@
+from typing import List, NamedTuple, Iterable, Any, Optional
+
+import torch
+import torch.fx
+import tensorrt as trt
+import copy
+from torch.fx.experimental.normalize import NormalizeArgs
+
+
+# Borrowed from torch2trt
+def torch_dtype_to_trt(dtype):
+    if trt.__version__ >= '7.0' and dtype == torch.bool:
+        return trt.bool
+    elif dtype == torch.int8:
+        return trt.int8
+    elif dtype == torch.int32:
+        return trt.int32
+    elif dtype == torch.float16:
+        return trt.float16
+    elif dtype == torch.float32:
+        return trt.float32
+    else:
+        raise TypeError("%s is not supported by tensorrt" % dtype)
+
+
+def torch_dtype_from_trt(dtype):
+    if dtype == trt.int8:
+        return torch.int8
+    elif trt.__version__ >= '7.0' and dtype == trt.bool:
+        return torch.bool
+    elif dtype == trt.int32:
+        return torch.int32
+    elif dtype == trt.float16:
+        return torch.float16
+    elif dtype == trt.float32:
+        return torch.float32
+    else:
+        raise TypeError("%s is not supported by torch" % dtype)
+
+def torch_device_to_trt(device):
+    if device.type == torch.device("cuda").type:
+        return trt.TensorLocation.DEVICE
+    elif device.type == torch.device("cpu").type:
+        return trt.TensorLocation.HOST
+    else:
+        raise TypeError("%s is not supported by tensorrt" % device)
+
+
+def torch_device_from_trt(device):
+    if device == trt.TensorLocation.DEVICE:
+        return torch.device("cuda")
+    elif device == trt.TensorLocation.HOST:
+        return torch.device("cpu")
+    else:
+        raise TypeError("%s is not supported by torch" % device)
+
+
+class TRTModule(torch.nn.Module):
+    def __init__(self, engine=None, input_names=None, output_names=None):
+        super(TRTModule, self).__init__()
+        self._register_state_dict_hook(TRTModule._on_state_dict)
+        self.engine = engine
+        if self.engine is not None:
+            self.context = self.engine.create_execution_context()
+        self.input_names = input_names
+        self.output_names = output_names
+
+    def _on_state_dict(self, state_dict, prefix, local_metadata):
+        state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
+        state_dict[prefix + "input_names"] = self.input_names
+        state_dict[prefix + "output_names"] = self.output_names
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        engine_bytes = state_dict[prefix + "engine"]
+
+        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+            self.engine = runtime.deserialize_cuda_engine(engine_bytes)
+            self.context = self.engine.create_execution_context()
+
+        self.input_names = state_dict[prefix + "input_names"]
+        self.output_names = state_dict[prefix + "output_names"]
+
+    def forward(self, *inputs):
+        batch_size = inputs[0].shape[0]
+        bindings: List[Any] = [None] * (len(self.input_names) + len(self.output_names))
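+        # `bindings` holds a device pointer for each engine binding (inputs and outputs),
+        # indexed by the engine's binding index. Output tensors are pre-allocated below so
+        # TensorRT writes results directly into torch-owned memory.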
+
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+        for i, output_name in enumerate(self.output_names):
+            idx: int = self.engine.get_binding_index(output_name)
+            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
+            shape = (batch_size,) + tuple(self.engine.get_binding_shape(idx))
+            device = torch_device_from_trt(self.engine.get_location(idx))
+            output = torch.empty(size=shape, dtype=dtype, device=device)
+            outputs.append(output)
+            bindings[idx] = output.data_ptr()
+
+        for i, input_name in enumerate(self.input_names):
+            idx = self.engine.get_binding_index(input_name)
+            bindings[idx] = inputs[i].contiguous().data_ptr()
+
+        self.context.execute_async(
+            batch_size, bindings, torch.cuda.current_stream().cuda_stream
+        )
+
+        if len(outputs) == 1:
+            return outputs[0]
+
+        return tuple(outputs)
+
+    def enable_profiling(self):
+        if not self.context.profiler:
+            self.context.profiler = trt.Profiler()
+
+
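+# Registry mapping a module type (call_module), a function (call_function), or a method
+# name (call_method) to its converter function.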
+CONVERTERS = {}
+
+
+def tensorrt_converter(key):
+    def register_converter(converter):
+        CONVERTERS[key] = converter
+        return converter
+    return register_converter
+
+
+class InputTensorSpec(NamedTuple):
+    shape : torch.Size
+    dtype : torch.dtype
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor):
+        return cls(tensor.shape, tensor.dtype)
+
+    @classmethod
+    def from_tensors(cls, tensors: Iterable[torch.Tensor]):
+        return [cls.from_tensor(t) for t in tensors]
+
+
+class TRTInterpreter(torch.fx.Interpreter):
+    def __init__(self, module : torch.fx.GraphModule, input_shapes : List[InputTensorSpec], logger_level=trt.Logger.WARNING):
+        # Preprocess the model
+        module = copy.copy(module)
+        module = module.cpu()
+        module = NormalizeArgs(module).transform()
+        super().__init__(module)
+
+        self.logger = trt.Logger(logger_level)
+        self.builder = trt.Builder(self.logger)
+        self.network = self.builder.create_network()
+
+        self.input_shape_itr = iter(input_shapes)
+
+        self._cur_node_name: Optional[str] = None
+
+        self._input_names: List[str] = []
+        self._output_names: List[str] = []
+
+    def run(
+        self,
+        *args,
+        max_batch_size=10,
+        max_workspace_size=1 << 25,
+        fp16_mode=False,
+        int8_mode=False,
+        strict_type_constraints=False
+    ):
+        super().run(*args)
+
+        self.builder.max_batch_size = max_batch_size
+        self.builder.max_workspace_size = max_workspace_size
+        self.builder.strict_type_constraints = strict_type_constraints
+        self.builder.fp16_mode = fp16_mode
+        self.builder.int8_mode = int8_mode
+
+        return self.builder.build_cuda_engine(self.network), self._input_names, self._output_names
+
+    def run_node(self, n):
+        self._cur_node_name = str(n)
+
+        try:
+            return super().run_node(n)
+        finally:
+            self._cur_node_name = None
+
+    def placeholder(self, target, args, kwargs):
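+        # The network is built in implicit batch mode (batch size is provided via
+        # max_batch_size at build time), so the batch dim is dropped from the input shape.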
+        shape, dtype = next(self.input_shape_itr)
+        self._input_names.append(target)
+        return self.network.add_input(name=target, shape=tuple(shape[1:]), dtype=torch_dtype_to_trt(dtype))
+
+    def call_module(self, target, args, kwargs):
+        assert isinstance(target, str)
+        submod = self.fetch_attr(target)
+
+        converter = CONVERTERS.get(type(submod))
+
+        if not converter:
+            raise RuntimeError(f'Conversion of module of type {type(submod)} not currently supported!')
+
+        return converter(self.network, submod, args, kwargs, self._cur_node_name)
+
+    def call_function(self, target, args, kwargs):
+        converter = CONVERTERS.get(target)
+
+        if not converter:
+            raise RuntimeError(f'Conversion of function {torch.typename(target)} not currently supported!')
+
+        return converter(self.network, target, args, kwargs, self._cur_node_name)
+
+    def call_method(self, target, args, kwargs):
+        assert isinstance(target, str)
+
+        converter = CONVERTERS.get(target)
+
+        if not converter:
+            raise RuntimeError(f'Conversion of method {target} not currently supported!')
+
+        return converter(self.network, target, args, kwargs, self._cur_node_name)
+
+    def output(self, target, args, kwargs):
+        assert len(args) == 1
+        outputs = args[0] if isinstance(args[0], tuple) else (args[0],)
+        if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
+            raise RuntimeError('TensorRT requires all outputs to be Tensor!')
+
+        for i, output in enumerate(outputs):
+            # TODO: set location and dtype?
+            name = f'output{i}'
+            output.name = name
+            self.network.mark_output(output)
+            self._output_names.append(name)