Add ONNXProgram.__call__ API to run model with ONNX Runtime (#113495)
Currently the user can use `torch.onnx.dynamo_export` to export the model to ONNX.
```python
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

onnx_program = torch.onnx.dynamo_export(
    Model(),
    torch.randn(1, 1, 2, dtype=torch.float),
)
```
The next step would be instantiating an ONNX Runtime session to execute it, which currently requires boilerplate like the following:
```python
import onnxruntime  # type: ignore[import]

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(
    torch.randn(1, 1, 2, dtype=torch.float)
)
providers = onnxruntime.get_available_providers()
onnx_model = onnx_program.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

def to_numpy(tensor):
    # Detach from the autograd graph and move to CPU before converting.
    return (
        tensor.detach().cpu().numpy()
        if tensor.requires_grad
        else tensor.cpu().numpy()
    )

# Map ONNX graph input names to the adapted PyTorch inputs.
onnxruntime_input = {
    k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
}

ort_outputs = ort_session.run(None, onnxruntime_input)
```
This PR provides the `ONNXProgram.__call__` method as a facilitator that runs the model with ONNX Runtime under the hood, similar to how `torch.export.ExportedProgram.__call__` allows the underlying `torch.fx.GraphModule` to be executed.
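With this change, the boilerplate above collapses to a direct call on the exported program. A minimal sketch of the new API based on the diff below; the `CPUExecutionProvider` value is only an illustrative choice of provider:

```python
# Default: ONNX Runtime picks from the available execution providers.
ort_outputs = onnx_program(torch.randn(1, 1, 2, dtype=torch.float))

# Execution can optionally be customized via ONNXRuntimeOptions,
# e.g. pinning the execution providers (keyword-only arguments).
options = torch.onnx.ONNXRuntimeOptions(
    execution_providers=["CPUExecutionProvider"]
)
ort_outputs = onnx_program(
    torch.randn(1, 1, 2, dtype=torch.float), options=options
)
```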
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113495
Approved by: https://github.com/titaiwangms
diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst
index a156c51..09a09bc 100644
--- a/docs/source/onnx_dynamo.rst
+++ b/docs/source/onnx_dynamo.rst
@@ -146,6 +146,9 @@
.. autoclass:: torch.onnx.ONNXProgramSerializer
:members:
+.. autoclass:: torch.onnx.ONNXRuntimeOptions
+ :members:
+
.. autoclass:: torch.onnx.InvalidExportOptionsError
:members:
diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index bcc8aa7..2892a23 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -439,15 +439,11 @@
ref_input_args = input_args
ref_input_kwargs = input_kwargs
- # Format original model inputs into the format expected by exported ONNX model.
- onnx_format_args = onnx_program.adapt_torch_inputs_to_onnx(
- *input_args, **input_kwargs
- )
-
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
ref_model(*ref_input_args, **ref_input_kwargs)
)
- ort_outputs = run_ort(onnx_program, onnx_format_args)
+
+ ort_outputs = onnx_program(*input_args, **input_kwargs)
if len(ref_outputs) != len(ort_outputs):
raise AssertionError(
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 728430c..26fa6f2 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -896,6 +896,21 @@
create_pytorch_only_extra_kwargs,
)
+ def test_execute_model_with___call__(self):
+ class Model(torch.nn.Module):
+ def forward(self, x):
+ return x + 1.0
+
+ input_x = torch.randn(1, 1, 2, dtype=torch.float)
+ onnx_program = torch.onnx.dynamo_export(
+ Model(),
+ input_x,
+ )
+
+ # The other tests use ONNXProgram.__call__ indirectly and check for output equality
+ # This test aims to ensure ONNXProgram.__call__ API runs successfully despite internal test infra code
+ _ = onnx_program(input_x)
+
def test_exported_program_as_input(self):
class Model(torch.nn.Module):
def forward(self, x):
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index e50dfb3..ad3af09 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -48,6 +48,7 @@
ExportOptions,
ONNXProgram,
ONNXProgramSerializer,
+ ONNXRuntimeOptions,
InvalidExportOptionsError,
OnnxExporterError,
OnnxRegistry,
@@ -103,6 +104,7 @@
"ExportOptions",
"ONNXProgram",
"ONNXProgramSerializer",
+ "ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",
"OnnxRegistry",
@@ -118,6 +120,7 @@
ExportOptions.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXProgramSerializer.__module__ = "torch.onnx"
+ONNXRuntimeOptions.__module__ = "torch.onnx"
dynamo_export.__module__ = "torch.onnx"
InvalidExportOptionsError.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py
index fbda341..807ef52 100644
--- a/torch/onnx/_internal/exporter.py
+++ b/torch/onnx/_internal/exporter.py
@@ -1,5 +1,6 @@
-# necessary to surface onnx.ModelProto through ONNXProgram:
-from __future__ import annotations
+from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions)
+ annotations,
+)
import abc
@@ -52,6 +53,7 @@
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
import onnx
+ import onnxruntime # type: ignore[import]
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
registration as torchlib_registry,
@@ -602,6 +604,41 @@
)
+class ONNXRuntimeOptions:
+ """Options to influence the execution of the ONNX model through ONNX Runtime.
+
+ Attributes:
+ session_options: ONNX Runtime session options.
+ execution_providers: ONNX Runtime execution providers to use during model execution.
+ execution_provider_options: ONNX Runtime execution provider options.
+ """
+
+ session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None
+ """ONNX Runtime session options."""
+
+ execution_providers: Optional[
+ Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
+ ] = None
+ """ONNX Runtime execution providers to use during model execution."""
+
+ execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None
+ """ONNX Runtime execution provider options."""
+
+ @_beartype.beartype
+ def __init__(
+ self,
+ *,
+ session_options: Optional[Sequence["onnxruntime.SessionOptions"]] = None,
+ execution_providers: Optional[
+ Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]
+ ] = None,
+ execution_provider_options: Optional[Sequence[Dict[Any, Any]]] = None,
+ ):
+ self.session_options = session_options
+ self.execution_providers = execution_providers
+ self.execution_provider_options = execution_provider_options
+
+
class ONNXProgram:
"""An in-memory representation of a PyTorch model that has been exported to ONNX.
@@ -643,6 +680,34 @@
self._fake_context = fake_context
self._export_exception = export_exception
+ def __call__(
+ self, *args: Any, options: Optional[ONNXRuntimeOptions] = None, **kwargs: Any
+ ) -> Any:
+ """Runs the ONNX model using ONNX Runtime
+
+ Args:
+ args: The positional inputs to the model.
+ kwargs: The keyword inputs to the model.
+ options: The options to use for running the model with ONNX Runtime.
+
+ Returns:
+ The model output as computed by ONNX Runtime
+ """
+ import onnxruntime # type: ignore[import]
+
+ onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
+ options = options or ONNXRuntimeOptions()
+ providers = options.execution_providers or onnxruntime.get_available_providers()
+ onnx_model = self.model_proto.SerializeToString()
+ ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
+
+ onnxruntime_input = {
+ k.name: v.numpy(force=True)
+ for k, v in zip(ort_session.get_inputs(), onnx_input)
+ }
+
+ return ort_session.run(None, onnxruntime_input)
+
@property
def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
"""The exported ONNX model as an :py:obj:`onnx.ModelProto`."""
@@ -1416,6 +1481,7 @@
"ExportOptions",
"ONNXProgram",
"ONNXProgramSerializer",
+ "ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",
"OnnxRegistry",