| from torch.testing._internal.common_device_type import instantiate_device_type_tests |
| from torch.testing._internal.common_modules import module_db, modules |
| from torch.testing._internal.common_utils import TestCase, run_tests, freeze_rng_state |
| |
| |
class TestModule(TestCase):
    """Generic tests run over every module entry listed in ``module_db``."""

    # Harness flags, presumably read by the common_utils TestCase base class
    # (CUDA memory-leak check / non-default-stream run) -- confirm there.
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    # Comparison tolerances. They are not passed explicitly to assertEqual below;
    # NOTE(review): presumably picked up by TestCase.assertEqual as defaults.
    precision = 1e-5
    rel_tol = 1e-5

    @modules(module_db)
    def test_forward(self, device, dtype, module_info):
        """Instantiate the module for each sample input, run forward, and compare
        the result against the sample's reference function when one is given.

        Args:
            device: device the module is moved to (passed through to the
                sample-input generator as well).
            dtype: dtype the module is cast to (also passed to the generator).
            module_info: ``module_db`` entry providing ``module_cls`` and
                ``module_inputs_func`` for generating sample inputs.
        """
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(
            module_info, device=device, dtype=dtype, requires_grad=False)
        for module_input in module_inputs:
            # Skip samples that provide no forward-pass input.
            if module_input.forward_input is None:
                continue

            # Outer freeze: construction and the passes below must not leak RNG
            # consumption into later samples/tests.
            with freeze_rng_state():
                # === Instantiate the module. ===
                ctor = module_input.constructor_input
                m = module_cls(*ctor.args, **ctor.kwargs)
                m.to(device).to(dtype)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                # Inner freeze: freeze_rng_state restores the saved RNG state when
                # this block exits, so the reference below starts from exactly the
                # same RNG state the forward pass saw (matters for stochastic modules).
                with freeze_rng_state():
                    outputs = m(*args, **kwargs)

                # === Compare outputs to a reference if one is specified. ===
                # TODO: Handle precision
                reference_fn = module_input.reference_fn
                if reference_fn is not None:
                    ref_outputs = reference_fn(m, *args, **kwargs)
                    self.assertEqual(outputs, ref_outputs)
| |
| |
# Expand the generic TestModule into device-specific test classes, using this
# module's globals() as the scope that receives them.
instantiate_device_type_tests(TestModule, scope=globals())

if __name__ == '__main__':
    # Direct execution: hand control to the PyTorch test runner.
    run_tests()