fx2trt example: run all submodules (#66590)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66590
Updated fx2trt example to run all submodules
Added an assertion to make sure the outputs from the lowered and regular models match
Test Plan: buck run mode/dev-nosan caffe2:fx2trt_example
Reviewed By: 842974287
Differential Revision: D31592985
fbshipit-source-id: 45ce0b33e957f16b3729d3ecde706331c29d7214
diff --git a/torch/fx/experimental/fx2trt/example/fx2trt_example.py b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
index 38fa621..e0b8675 100644
--- a/torch/fx/experimental/fx2trt/example/fx2trt_example.py
+++ b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
@@ -95,6 +95,13 @@
interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
engine, input_names, output_names = interp.run()
trt_mod = TRTModule(engine, input_names, output_names)
+split_mod._run_on_acc_0 = trt_mod
cuda_inputs = [input.cuda() for input in inputs]
-trt_mod(*cuda_inputs)
+split_mod.cuda()
+lowered_model_output = split_mod(*cuda_inputs)
+
+# we make sure the results match
+model.cuda()
+regular_model_output = model(*cuda_inputs)
+torch.testing.assert_close(lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2)