| # Owner(s): ["oncall: export"] |
| # flake8: noqa |
| import copy |
| import io |
| import unittest |
| |
| import torch |
| import torch._dynamo as torchdynamo |
| import torch.utils._pytree as pytree |
| from torch._dynamo.test_case import TestCase |
| from torch.export import export, load, save |
| from torch.export._trace import _export |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| ops, |
| ) |
| from torch.testing._internal.common_utils import ( |
| IS_WINDOWS, |
| run_tests, |
| TestCase as TorchTestCase, |
| ) |
| from torch.testing._internal.hop_db import ( |
| hop_db, |
| hop_that_doesnt_have_opinfo_test_allowlist, |
| ) |
| |
| hop_tests = [] |
| |
| for op_info in hop_db: |
| op_info_hop_name = op_info.name |
| if op_info_hop_name in hop_that_doesnt_have_opinfo_test_allowlist: |
| continue |
| hop_tests.append(op_info) |
| |
| |
| class TestHOPGeneric(TestCase): |
| def test_all_hops_have_op_info(self): |
| from torch._ops import _higher_order_ops |
| |
| hops_that_have_op_info = set([k.name for k in hop_db]) |
| all_hops = _higher_order_ops.keys() |
| |
| missing_ops = [] |
| |
| for op in all_hops: |
| if ( |
| op not in hops_that_have_op_info |
| and op not in hop_that_doesnt_have_opinfo_test_allowlist |
| ): |
| missing_ops.append(op) |
| |
| self.assertTrue(len(missing_ops) == 0, f"Missing op info for {missing_ops}") |
| |
| |
| @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") |
| @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") |
| class TestHOP(TestCase): |
| def _compare(self, eager_model, export, args, kwargs): |
| eager_args = copy.deepcopy(args) |
| eager_kwargs = copy.deepcopy(kwargs) |
| export_args = copy.deepcopy(args) |
| export_kwargs = copy.deepcopy(kwargs) |
| |
| flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs)) |
| flat_loaded_outputs = pytree.tree_leaves( |
| export.module()(*export_args, **export_kwargs) |
| ) |
| |
| for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs): |
| self.assertEqual(type(orig), type(loaded)) |
| self.assertEqual(orig, loaded) |
| |
| @ops(hop_tests, allowed_dtypes=(torch.float,)) |
| def test_aot_export(self, device, dtype, op): |
| class Foo(torch.nn.Module): |
| def forward(self, *args): |
| return op.op(*args) |
| |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) |
| for inp in sample_inputs_itr: |
| model = Foo() |
| input = inp.input if isinstance(inp.input, tuple) else (inp.input,) |
| args = (*input, *inp.args) |
| kwargs = inp.kwargs |
| ep = export(model, args, kwargs) |
| self._compare(model, ep, args, kwargs) |
| |
| @ops(hop_tests, allowed_dtypes=(torch.float,)) |
| def test_pre_dispatch_export(self, device, dtype, op): |
| class Foo(torch.nn.Module): |
| def forward(self, *args): |
| return op.op(*args) |
| |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) |
| for inp in sample_inputs_itr: |
| model = Foo() |
| input = inp.input if isinstance(inp.input, tuple) else (inp.input,) |
| args = (*input, *inp.args) |
| kwargs = inp.kwargs |
| ep = _export(model, args, kwargs, pre_dispatch=True) |
| self._compare(model, ep, args, kwargs) |
| |
| @ops(hop_tests, allowed_dtypes=(torch.float,)) |
| def test_retrace_export(self, device, dtype, op): |
| class Foo(torch.nn.Module): |
| def forward(self, *args): |
| return op.op(*args) |
| |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) |
| for inp in sample_inputs_itr: |
| model = Foo() |
| input = inp.input if isinstance(inp.input, tuple) else (inp.input,) |
| args = (*input, *inp.args) |
| kwargs = inp.kwargs |
| ep = _export(model, args, kwargs, pre_dispatch=True) |
| ep = ep.run_decompositions() |
| self._compare(model, ep, args, kwargs) |
| |
| @ops(hop_tests, allowed_dtypes=(torch.float,)) |
| def test_serialize_export(self, device, dtype, op): |
| class Foo(torch.nn.Module): |
| def forward(self, *args): |
| return op.op(*args) |
| |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) |
| for inp in sample_inputs_itr: |
| model = Foo() |
| input = inp.input if isinstance(inp.input, tuple) else (inp.input,) |
| args = (*input, *inp.args) |
| kwargs = inp.kwargs |
| ep = _export(model, args, kwargs, pre_dispatch=True) |
| ep = ep.run_decompositions() |
| buffer = io.BytesIO() |
| save(ep, buffer) |
| buffer.seek(0) |
| ep = load(buffer) |
| self._compare(model, ep, args, kwargs) |
| |
| |
| instantiate_device_type_tests(TestHOP, globals()) |
| |
| if __name__ == "__main__": |
| run_tests() |