# blob: b0c88326881c7ddea5bdec3373932adb2100d121 [file] [log] [blame]
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import TestCase, run_tests, freeze_rng_state
class TestModule(TestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
precision = 1e-5
rel_tol = 1e-5
@modules(module_db)
def test_forward(self, device, dtype, module_info):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
with freeze_rng_state():
# === Instantiate the module. ===
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
# === Do forward pass. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
outputs = m(*args, **kwargs)
# === Compare outputs to a reference if one is specified. ===
# TODO: Handle precision
reference_fn = module_input.reference_fn
if reference_fn is not None:
ref_outputs = reference_fn(m, *args, **kwargs)
self.assertEqual(outputs, ref_outputs)
# Hand TestModule and this module's globals to the device-type framework,
# which (going by its name) instantiates the device-specific test variants.
instantiate_device_type_tests(TestModule, globals())


# Run the tests only when this file is executed as a script.
if __name__ == '__main__':
    run_tests()