pt2: make aot_eager backend handle basic float8 operations (#107783)

Summary:

Reland of https://github.com/pytorch/pytorch/pull/107642 with a fix for tests on Windows.

Makes the aot_eager backend of torch.compile handle basic float8 operations.

This is useful for float8 training UX.
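
As a minimal sketch of what this enables (mirroring the test added in the diff below, and assuming a recent PyTorch build with the float8 dtypes and torch.compile available):

```
import torch

@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
    x = x.to(torch.float8_e5m2)  # cast down to float8
    x = x.float()                # cast back up to float32
    return x

x = torch.randn(1).requires_grad_()
f(x).sum().backward()  # exercises both forward and backward through aot_eager
```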

Test Plan:

```
python test/test_quantization.py -k test_pt2_traceable_aot_eager
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107783
Approved by: https://github.com/albanD
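
The diff below adds the Windows-guarded torch.compile test, registers 1-byte element sizes for the float8 dtypes in the per-dtype size table in torch/_functorch/partitioners.py, and adds short float8 abbreviations to the dtype table in torch/fx/graph.py. As a small illustrative check (not part of the change itself), float8 tensors do report one byte per element:

```
import torch

# Illustrative check only: float8 tensors are 1 byte per element,
# matching the new entries in the partitioner's per-dtype size table.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    t = torch.empty(4, dtype=dtype)
    print(dtype, t.element_size())  # prints 1 for both dtypes
```
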
diff --git a/test/quantization/core/experimental/test_float8.py b/test/quantization/core/experimental/test_float8.py
index a33adb2..91e1a59 100644
--- a/test/quantization/core/experimental/test_float8.py
+++ b/test/quantization/core/experimental/test_float8.py
@@ -1,8 +1,15 @@
 # Owner(s): ["oncall: quantization"]
 
+import unittest
+
 import torch
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+    parametrize,
+    run_tests,
+    TestCase,
+)
 
 # Masks for float8 simulation
 
@@ -157,6 +164,18 @@
         mul8_simulated = (a8_simulated * b8_simulated).to(dtype)
         self.assertEqual(mul8, mul8_simulated)
 
+    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on Windows yet")
+    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
+    def test_pt2_traceable_aot_eager(self, dtype):
+        @torch.compile(backend="aot_eager", fullgraph=True)
+        def f(x):
+            x = x.to(dtype)
+            x = x.float()
+            return x
+
+        x = torch.randn(1).requires_grad_()
+        f(x).sum().backward()
+
 
 instantiate_device_type_tests(TestFloat8DtypeCPUOnly, globals(), only_for="cpu")
 
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index f8dbae5..5989224 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -307,6 +307,8 @@
     sizes = {
         torch.complex64: 8,
         torch.complex128: 16,
+        torch.float8_e4m3fn: 1,
+        torch.float8_e5m2: 1,
         torch.float16: 2,
         torch.bfloat16: 2,
         torch.float32: 4,
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 33663ca..d0f2607 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -200,6 +200,8 @@
     torch.float64: 'f64',
     torch.float32: 'f32',
     torch.float16: 'f16',
+    torch.float8_e4m3fn: 'f8e4m3fn',
+    torch.float8_e5m2: 'f8e5m2',
     torch.complex32: 'c32',
     torch.complex64: 'c64',
     torch.complex128: 'c128',
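
Assuming the table extended in torch/fx/graph.py above is the module-level `dtype_abbrs` mapping (the dict's name sits outside the hunk), a quick sanity check of the new entries would look like:

```
import torch
from torch.fx.graph import dtype_abbrs  # assumed name; not shown in the hunk above

# With this change, float8 tensors should render with these short tags
# in FX graph printouts (e.g. "f8e5m2[1]").
assert dtype_abbrs[torch.float8_e5m2] == "f8e5m2"
assert dtype_abbrs[torch.float8_e4m3fn] == "f8e4m3fn"
```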