[export oncall] add some examples during oncall (#112445)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112445
Approved by: https://github.com/ydwu4
diff --git a/docs/source/scripts/exportdb/generate_example_rst.py b/docs/source/scripts/exportdb/generate_example_rst.py
index 1c574f4..b9f68ad 100644
--- a/docs/source/scripts/exportdb/generate_example_rst.py
+++ b/docs/source/scripts/exportdb/generate_example_rst.py
@@ -81,6 +81,10 @@
         output = f"    {graph_output}"
     except torchdynamo.exc.Unsupported as e:
         output = "    Unsupported: " + str(e).split("\n")[0]
+    except AssertionError as e:
+        output = "    AssertionError: " + str(e).split("\n")[0]
+    except RuntimeError as e:
+        output = "    RuntimeError: " + str(e).split("\n")[0]
 
     doc_contents += output + "\n"
 
diff --git a/test/export/test_db.py b/test/export/test_db.py
index ee8f6d2..e321d8a 100644
--- a/test/export/test_db.py
+++ b/test/export/test_db.py
@@ -57,7 +57,7 @@
     def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
         model = case.model
         # pyre-ignore
-        with self.assertRaises(torchdynamo.exc.Unsupported):
+        with self.assertRaises((torchdynamo.exc.Unsupported, AssertionError, RuntimeError)):
             inputs = normalize_inputs(case.example_inputs)
             exported_model = export(
                 model,
diff --git a/torch/_export/db/case.py b/torch/_export/db/case.py
index 1cc6eb2..caf7ca2 100644
--- a/torch/_export/db/case.py
+++ b/torch/_export/db/case.py
@@ -14,6 +14,8 @@
         "escape-hatch": {},
         "map": {},
         "dynamic-value": {},
+        "operator": {},
+        "mutation": {},
     },
     "python": {
         "assert": {},
@@ -23,6 +25,7 @@
         "control-flow": {},
         "data-structure": {},
         "standard-library": {},
+        "object-model": {},
     },
 }
 
diff --git a/torch/_export/db/examples/model_attr_mutation.py b/torch/_export/db/examples/model_attr_mutation.py
new file mode 100644
index 0000000..b4d76cc
--- /dev/null
+++ b/torch/_export/db/examples/model_attr_mutation.py
@@ -0,0 +1,25 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+    example_inputs=(torch.ones(3, 2),),
+    tags={"python.object-model"},
+    support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class ModelAttrMutation(torch.nn.Module):
+    """
+    Attribute mutation is not supported.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)]
+
+    def recreate_list(self):
+        return [torch.zeros(3, 2), torch.zeros(3, 2)]
+
+    def forward(self, x):
+        self.attr_list = self.recreate_list()
+        return x.sum() + self.attr_list[0].sum()
diff --git a/torch/_export/db/examples/optional_input.py b/torch/_export/db/examples/optional_input.py
new file mode 100644
index 0000000..4a06207
--- /dev/null
+++ b/torch/_export/db/examples/optional_input.py
@@ -0,0 +1,19 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+    example_inputs=(torch.randn(2, 3),),
+    tags={"python.object-model"},
+    support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class OptionalInput(torch.nn.Module):
+    """
+    Tracing through optional input is not supported yet
+    """
+
+    def forward(self, x, y=torch.ones(2, 3)):
+        if y is not None:
+            return x + y
+        return x
diff --git a/torch/_export/db/examples/torch_sym_min.py b/torch/_export/db/examples/torch_sym_min.py
new file mode 100644
index 0000000..b9f4dd8
--- /dev/null
+++ b/torch/_export/db/examples/torch_sym_min.py
@@ -0,0 +1,17 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+    example_inputs=(torch.ones(3, 2),),
+    tags={"torch.operator"},
+    support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class TorchSymMin(torch.nn.Module):
+    """
+    torch.sym_min operator is not supported in export.
+    """
+
+    def forward(self, x):
+        return x.sum() + torch.sym_min(x.size(0), 100)
diff --git a/torch/_export/db/examples/user_input_mutation.py b/torch/_export/db/examples/user_input_mutation.py
new file mode 100644
index 0000000..56af08d
--- /dev/null
+++ b/torch/_export/db/examples/user_input_mutation.py
@@ -0,0 +1,18 @@
+import torch
+
+from torch._export.db.case import export_case, SupportLevel
+
+
+@export_case(
+    example_inputs=(torch.ones(3, 2),),
+    tags={"torch.mutation"},
+    support_level=SupportLevel.NOT_SUPPORTED_YET,
+)
+class UserInputMutation(torch.nn.Module):
+    """
+    Can't directly mutate user input in forward
+    """
+
+    def forward(self, x):
+        x.mul_(2)
+        return x.cos()