# Owner(s): ["module: dynamo"]

import torch
from functorch import make_fx

from torch._dynamo import debug_utils
from torch._dynamo.test_case import TestCase


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64
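        # via four call patterns: an explicit prims conversion, Tensor.to,
        # a factory op with a dtype kwarg, and new_empty (inherits x's dtype).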
        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)
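        # fx is now a GraphModule of aten/prims calls; the expected code below
        # pins down where the fp16/fp32 dtypes appear before the cast.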
        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))
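        # The second expected-code check asserts on fx.code again, showing the
        # cast rewrites the dtype arguments on the original traced module.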
        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()