pt2: make aot_eager backend handle basic float8 operations (#107783)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/107642 with a fix for tests on Windows.
Makes the aot_eager backend of torch.compile handle basic float8 operations.
This is useful for the float8 training UX.
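A minimal sketch of the supported pattern, mirroring the new test added below: cast to a float8 dtype and back inside a function compiled with the aot_eager backend (the `roundtrip` and `inp` names here are illustrative):
```python
import torch

# Sketch: cast to a float8 dtype and back under torch.compile with the
# aot_eager backend; this is the pattern the new test exercises.
@torch.compile(backend="aot_eager", fullgraph=True)
def roundtrip(x):
    x = x.to(torch.float8_e5m2)
    return x.float()

inp = torch.randn(8, requires_grad=True)
roundtrip(inp).sum().backward()
```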
Test Plan:
```
python test/test_quantization.py -k test_pt2_traceable_aot_eager
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107783
Approved by: https://github.com/albanD
diff --git a/test/quantization/core/experimental/test_float8.py b/test/quantization/core/experimental/test_float8.py
index a33adb2..91e1a59 100644
--- a/test/quantization/core/experimental/test_float8.py
+++ b/test/quantization/core/experimental/test_float8.py
@@ -1,8 +1,15 @@
# Owner(s): ["oncall: quantization"]
+import unittest
+
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
+from torch.testing._internal.common_utils import (
+ IS_WINDOWS,
+ parametrize,
+ run_tests,
+ TestCase,
+)
# Masks for float8 simulation
@@ -157,6 +164,18 @@
mul8_simulated = (a8_simulated * b8_simulated).to(dtype)
self.assertEqual(mul8, mul8_simulated)
+ @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on Windows yet")
+ @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
+ def test_pt2_traceable_aot_eager(self, dtype):
+ @torch.compile(backend="aot_eager", fullgraph=True)
+ def f(x):
+ x = x.to(dtype)
+ x = x.float()
+ return x
+
+ x = torch.randn(1).requires_grad_()
+ f(x).sum().backward()
+
instantiate_device_type_tests(TestFloat8DtypeCPUOnly, globals(), only_for="cpu")
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index f8dbae5..5989224 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -307,6 +307,8 @@
sizes = {
torch.complex64: 8,
torch.complex128: 16,
+ torch.float8_e4m3fn: 1,
+ torch.float8_e5m2: 1,
torch.float16: 2,
torch.bfloat16: 2,
torch.float32: 4,
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 33663ca..d0f2607 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -200,6 +200,8 @@
torch.float64: 'f64',
torch.float32: 'f32',
torch.float16: 'f16',
+ torch.float8_e4m3fn: 'f8e4m3fn',
+ torch.float8_e5m2: 'f8e5m2',
torch.complex32: 'c32',
torch.complex64: 'c64',
torch.complex128: 'c128',
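As an illustrative aside (not the partitioner's actual code path), the per-element byte-size table extended in torch/_functorch/partitioners.py above can be thought of as feeding a memory estimate like the following; the `DTYPE_BYTES` dict and `estimated_bytes` helper are hypothetical names for this sketch:
```python
import torch

# Hypothetical sketch: a per-element size table (like the one extended above)
# turned into a tensor memory estimate, with float8 dtypes counted as 1 byte.
DTYPE_BYTES = {
    torch.float8_e4m3fn: 1,
    torch.float8_e5m2: 1,
    torch.float16: 2,
    torch.bfloat16: 2,
    torch.float32: 4,
    torch.float64: 8,
}

def estimated_bytes(t: torch.Tensor) -> int:
    # Fall back to element_size() for dtypes not listed in the table.
    return t.numel() * DTYPE_BYTES.get(t.dtype, t.element_size())

x = torch.randn(4, 4).to(torch.float8_e5m2)
print(estimated_bytes(x))  # 16 (one byte per element)
```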