[export oncall] add some examples during oncall (#112445)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112445
Approved by: https://github.com/ydwu4
diff --git a/docs/source/scripts/exportdb/generate_example_rst.py b/docs/source/scripts/exportdb/generate_example_rst.py
index 1c574f4..b9f68ad 100644
--- a/docs/source/scripts/exportdb/generate_example_rst.py
+++ b/docs/source/scripts/exportdb/generate_example_rst.py
@@ -81,6 +81,10 @@
output = f" {graph_output}"
except torchdynamo.exc.Unsupported as e:
output = " Unsupported: " + str(e).split("\n")[0]
+ except AssertionError as e:
+ output = " AssertionError: " + str(e).split("\n")[0]
+ except RuntimeError as e:
+ output = " RuntimeError: " + str(e).split("\n")[0]
doc_contents += output + "\n"
diff --git a/test/export/test_db.py b/test/export/test_db.py
index ee8f6d2..e321d8a 100644
--- a/test/export/test_db.py
+++ b/test/export/test_db.py
@@ -57,7 +57,7 @@
def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
model = case.model
# pyre-ignore
- with self.assertRaises(torchdynamo.exc.Unsupported):
+ with self.assertRaises((torchdynamo.exc.Unsupported, AssertionError, RuntimeError)):
inputs = normalize_inputs(case.example_inputs)
exported_model = export(
model,
diff --git a/torch/_export/db/case.py b/torch/_export/db/case.py
index 1cc6eb2..caf7ca2 100644
--- a/torch/_export/db/case.py
+++ b/torch/_export/db/case.py
@@ -14,6 +14,8 @@
"escape-hatch": {},
"map": {},
"dynamic-value": {},
+ "operator": {},
+ "mutation": {},
},
"python": {
"assert": {},
@@ -23,6 +25,7 @@
"control-flow": {},
"data-structure": {},
"standard-library": {},
+ "object-model": {},
},
}
diff --git a/torch/_export/db/examples/model_attr_mutation.py b/torch/_export/db/examples/model_attr_mutation.py
new file mode 100644
index 0000000..b4d76cc
--- /dev/null
+++ b/torch/_export/db/examples/model_attr_mutation.py
@@ -0,0 +1,25 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+ example_inputs=(torch.ones(3, 2),),
+ tags={"python.object-model"},
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class ModelAttrMutation(torch.nn.Module):
+ """
+ Attribute mutation is not supported.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)]
+
+ def recreate_list(self):
+ return [torch.zeros(3, 2), torch.zeros(3, 2)]
+
+ def forward(self, x):
+ self.attr_list = self.recreate_list()
+ return x.sum() + self.attr_list[0].sum()
diff --git a/torch/_export/db/examples/optional_input.py b/torch/_export/db/examples/optional_input.py
new file mode 100644
index 0000000..4a06207
--- /dev/null
+++ b/torch/_export/db/examples/optional_input.py
@@ -0,0 +1,19 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+ example_inputs=(torch.randn(2, 3),),
+ tags={"python.object-model"},
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class OptionalInput(torch.nn.Module):
+ """
+ Tracing through optional input is not supported yet
+ """
+
+ def forward(self, x, y=torch.ones(2, 3)):
+ if y is not None:
+ return x + y
+ return x
diff --git a/torch/_export/db/examples/torch_sym_min.py b/torch/_export/db/examples/torch_sym_min.py
new file mode 100644
index 0000000..b9f4dd8
--- /dev/null
+++ b/torch/_export/db/examples/torch_sym_min.py
@@ -0,0 +1,17 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+ example_inputs=(torch.ones(3, 2),),
+ tags={"torch.operator"},
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class TorchSymMin(torch.nn.Module):
+ """
+ torch.sym_min operator is not supported in export.
+ """
+
+ def forward(self, x):
+ return x.sum() + torch.sym_min(x.size(0), 100)
diff --git a/torch/_export/db/examples/user_input_mutation.py b/torch/_export/db/examples/user_input_mutation.py
new file mode 100644
index 0000000..56af08d
--- /dev/null
+++ b/torch/_export/db/examples/user_input_mutation.py
@@ -0,0 +1,18 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+ example_inputs=(torch.ones(3, 2),),
+ tags={"torch.mutation"},
+ support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class UserInputMutation(torch.nn.Module):
+ """
+ Can't directly mutate user input in forward
+ """
+
+ def forward(self, x):
+ x.mul_(2)
+ return x.cos()