| # Owner(s): ["module: dynamo"] |
| |
| import torch |
| from functorch import make_fx |
| from torch._dynamo import debug_utils |
| from torch._dynamo.test_case import TestCase |
| |
| |
| class TestDebugUtils(TestCase): |
| def test_cast_model_to_fp64_dtype_args(self): |
| # Test that dtype arguments are converted to fp64 |
| |
| def fn(x): |
| return ( |
| torch.ops.prims.convert_element_type(x, torch.float16), |
| x.to(torch.float16), |
| torch.full(x.shape, 2, dtype=torch.float32, device=x.device), |
| x.new_empty(x.shape), |
| ) |
| |
| x = torch.randn(32, device="cpu") |
| decomps = torch._decomp.core_aten_decompositions() |
| fx = make_fx(fn, decomposition_table=decomps)(x) |
| |
| self.assertExpectedInline( |
| fx.code.lstrip(), |
| """\ |
| def forward(self, x_1): |
| convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16) |
| _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None |
| full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) |
| empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) |
| return (convert_element_type, _to_copy, full, empty) |
| """, # NOQA: B950 |
| ) |
| |
| fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,)) |
| self.assertEqual(fp64_examples, (x.to(torch.float64),)) |
| |
| self.assertExpectedInline( |
| fx.code.lstrip(), |
| """\ |
| def forward(self, x_1): |
| convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64) |
| _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None |
| full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False) |
| empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) |
| return (convert_element_type, _to_copy, full, empty) |
| """, # NOQA: B950 |
| ) |
| |
| |
if __name__ == "__main__":
    # Hand control to the dynamo test runner when this file is executed
    # directly as a script.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()