[export] Don't save example_inputs for now. (#107978)
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107978
Approved by: https://github.com/angelayi
diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index 5c87005..5a4d9b9 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -207,13 +207,6 @@
else:
self.assertEqual(orig, loaded)
- self.assertEqual(len(ep.original_traced_arguments), len(deserialized_ep.original_traced_arguments))
- for arg1, arg2 in zip(ep.original_traced_arguments, deserialized_ep.original_traced_arguments):
- if isinstance(arg1, torch.Tensor) and isinstance(arg2, torch.Tensor):
- self.assertTrue(torch.allclose(arg1, arg2))
- else:
- self.assertEqual(type(arg1), type(arg2))
-
def _check_graph_nodes(gm1, gm2, _check_meta=True):
# TODO: The _check_meta flag bypasses checking for
# source_fn/nn_module_stack as there is an issue with
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 0392d70..bc15dd2 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -433,7 +433,7 @@
range_constraints,
equality_constraints,
[ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()],
- args,
+ (args, {}),
)
exported_program = exported_program._transform(
diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py
index f516cfb..871ecc1 100644
--- a/torch/_export/serde/schema.py
+++ b/torch/_export/serde/schema.py
@@ -260,4 +260,4 @@
range_constraints: Dict[str, RangeConstraint]
equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
schema_version: int
- original_traced_arguments: str
+ example_inputs: Optional[Tuple[List[bytes], Dict[str, bytes]]]
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index c8a1adc..903ea45 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -809,9 +809,6 @@
)
serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints)
serialized_equality_constraints = serialize_equality_constraints(exported_program.equality_constraints)
- serialized_original_arguments = base64.b64encode(
- serialize_torch_artifact(exported_program.original_traced_arguments)
- ).decode('utf-8')
return (
ExportedProgram(
@@ -820,7 +817,7 @@
range_constraints=serialized_range_constraints,
equality_constraints=serialized_equality_constraints,
schema_version=SCHEMA_VERSION,
- original_traced_arguments=serialized_original_arguments,
+ example_inputs=None,
),
serialize_torch_artifact(exported_program.state_dict),
)
@@ -1367,9 +1364,6 @@
state_dict = deserialize_torch_artifact(serialized_state_dict)
equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints)
- original_traced_arguments = deserialize_torch_artifact(
- base64.b64decode(serialized_exported_program.original_traced_arguments)
- )
exported_program = ep.ExportedProgram(
graph_module,
@@ -1380,7 +1374,7 @@
range_constraints,
equality_constraints,
module_call_graph,
- original_traced_arguments, # type: ignore[arg-type]
+ None, # type: ignore[arg-type]
)
return upgrader.upgrade(exported_program)
diff --git a/torch/export/__init__.py b/torch/export/__init__.py
index 641d5e9..ea1518e 100644
--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py
@@ -239,7 +239,7 @@
range_constraints: Dict[sympy.Symbol, Any],
equality_constraints: List[Tuple[Any, Any]],
module_call_graph: List[ModuleCallEntry],
- original_traced_arguments: Tuple[Any, ...] = (),
+ example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
):
from torch._export.exported_program import (
_create_graph_module_for_export,
@@ -264,7 +264,7 @@
Tuple[InputDim, InputDim]
] = equality_constraints
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
- self._original_traced_arguments = original_traced_arguments
+ self._example_inputs = example_inputs
@property
@compatibility(is_backward_compatible=False)
@@ -308,8 +308,8 @@
@property
@compatibility(is_backward_compatible=False)
- def original_traced_arguments(self):
- return self._original_traced_arguments
+ def example_inputs(self):
+ return self._example_inputs
def __call__(self, *args: Any, **kwargs: Any) -> Any:
import torch._export.error as error
@@ -531,7 +531,7 @@
_get_updated_range_constraints(transformed_gm),
copy.deepcopy(self.equality_constraints),
copy.deepcopy(self._module_call_graph),
- self.original_traced_arguments,
+ self.example_inputs,
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
transformed_ep.graph_module.meta.update(res.graph_module.meta)