[export] Don't save example_inputs for now. (#107978)

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107978
Approved by: https://github.com/angelayi
diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index 5c87005..5a4d9b9 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -207,13 +207,6 @@
             else:
                 self.assertEqual(orig, loaded)
 
-        self.assertEqual(len(ep.original_traced_arguments), len(deserialized_ep.original_traced_arguments))
-        for arg1, arg2 in zip(ep.original_traced_arguments, deserialized_ep.original_traced_arguments):
-            if isinstance(arg1, torch.Tensor) and isinstance(arg2, torch.Tensor):
-                self.assertTrue(torch.allclose(arg1, arg2))
-            else:
-                self.assertEqual(type(arg1), type(arg2))
-
         def _check_graph_nodes(gm1, gm2, _check_meta=True):
             # TODO: The _check_meta flag bypasses checking for
             # source_fn/nn_module_stack as there is an issue with
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 0392d70..bc15dd2 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -433,7 +433,7 @@
         range_constraints,
         equality_constraints,
         [ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()],
-        args,
+        (args, {}),
     )
 
     exported_program = exported_program._transform(
diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py
index f516cfb..871ecc1 100644
--- a/torch/_export/serde/schema.py
+++ b/torch/_export/serde/schema.py
@@ -260,4 +260,4 @@
     range_constraints: Dict[str, RangeConstraint]
     equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
     schema_version: int
-    original_traced_arguments: str
+    example_inputs: Optional[Tuple[List[bytes], Dict[str, bytes]]]
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index c8a1adc..903ea45 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -809,9 +809,6 @@
         )
         serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints)
         serialized_equality_constraints = serialize_equality_constraints(exported_program.equality_constraints)
-        serialized_original_arguments = base64.b64encode(
-            serialize_torch_artifact(exported_program.original_traced_arguments)
-        ).decode('utf-8')
 
         return (
             ExportedProgram(
@@ -820,7 +817,7 @@
                 range_constraints=serialized_range_constraints,
                 equality_constraints=serialized_equality_constraints,
                 schema_version=SCHEMA_VERSION,
-                original_traced_arguments=serialized_original_arguments,
+                example_inputs=None,
             ),
             serialize_torch_artifact(exported_program.state_dict),
         )
@@ -1367,9 +1364,6 @@
 
         state_dict = deserialize_torch_artifact(serialized_state_dict)
         equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints)
-        original_traced_arguments = deserialize_torch_artifact(
-            base64.b64decode(serialized_exported_program.original_traced_arguments)
-        )
 
         exported_program = ep.ExportedProgram(
             graph_module,
@@ -1380,7 +1374,7 @@
             range_constraints,
             equality_constraints,
             module_call_graph,
-            original_traced_arguments,  # type: ignore[arg-type]
+            None,  # type: ignore[arg-type]
         )
         return upgrader.upgrade(exported_program)
 
diff --git a/torch/export/__init__.py b/torch/export/__init__.py
index 641d5e9..ea1518e 100644
--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py
@@ -239,7 +239,7 @@
         range_constraints: Dict[sympy.Symbol, Any],
         equality_constraints: List[Tuple[Any, Any]],
         module_call_graph: List[ModuleCallEntry],
-        original_traced_arguments: Tuple[Any, ...] = (),
+        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
     ):
         from torch._export.exported_program import (
             _create_graph_module_for_export,
@@ -264,7 +264,7 @@
             Tuple[InputDim, InputDim]
         ] = equality_constraints
         self._module_call_graph: List[ModuleCallEntry] = module_call_graph
-        self._original_traced_arguments = original_traced_arguments
+        self._example_inputs = example_inputs
 
     @property
     @compatibility(is_backward_compatible=False)
@@ -308,8 +308,8 @@
 
     @property
     @compatibility(is_backward_compatible=False)
-    def original_traced_arguments(self):
-        return self._original_traced_arguments
+    def example_inputs(self):
+        return self._example_inputs
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         import torch._export.error as error
@@ -531,7 +531,7 @@
             _get_updated_range_constraints(transformed_gm),
             copy.deepcopy(self.equality_constraints),
             copy.deepcopy(self._module_call_graph),
-            self.original_traced_arguments,
+            self.example_inputs,
         )
         transformed_ep.graph_module.meta.update(self.graph_module.meta)
         transformed_ep.graph_module.meta.update(res.graph_module.meta)