|  | # Owner(s): ["module: cpp"] | 
|  |  | 
|  | import torch | 
|  | # NN tests use double as the default dtype | 
|  | torch.set_default_dtype(torch.double) | 
|  |  | 
|  | import os | 
|  |  | 
|  | import torch.testing._internal.common_utils as common | 
|  | import torch.testing._internal.common_nn as common_nn | 
|  | from cpp_api_parity.parity_table_parser import parse_parity_tracker_table | 
|  | from cpp_api_parity.utils import is_torch_nn_functional_test | 
|  | from cpp_api_parity import module_impl_check, functional_impl_check, sample_module, sample_functional | 
|  |  | 
# NOTE: set this to True to print the source code of every generated C++ test
# (e.g. for debugging purposes).
PRINT_CPP_SOURCE = False

# Devices handed to the test generators; the sanity checks further down expect
# one generated test per test dict per entry in this list.
devices = ["cpu", "cuda"]

# The parity tracker table is looked up relative to this file, under the
# `cpp_api_parity` directory.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__),
    "cpp_api_parity",
    "parity-tracker.md",
)
|  |  | 
# Parse the parity tracker once, at import time. The parsed table is passed
# unchanged to every `write_test_to_test_class` call in this file.
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
|  |  | 
|  | class TestCppApiParity(common.TestCase): | 
|  | module_test_params_map = {} | 
|  | functional_test_params_map = {} | 
|  |  | 
# Every test dict for which parity tests were requested. Its length drives the
# sanity check on the number of generated test methods below.
expected_test_params_dicts = []

# Test generation, the sanity checks and the C++ build are all skipped on ARM64.
if not common.IS_ARM64:
    # Each entry pairs a collection of test dicts with the test-instance class
    # that is handed to the generator for every dict in that collection.
    for test_params_dicts, test_instance_class in [
        (sample_module.module_tests, common_nn.NewModuleTest),
        (sample_functional.functional_tests, common_nn.NewModuleTest),
        (common_nn.module_tests, common_nn.NewModuleTest),
        (common_nn.new_module_tests, common_nn.NewModuleTest),
        (common_nn.criterion_tests, common_nn.CriterionTest),
    ]:
        for test_params_dict in test_params_dicts:
            # A test dict opts out of parity testing with
            # `test_cpp_api_parity=False`; a missing key means "test it".
            if not test_params_dict.get('test_cpp_api_parity', True):
                continue
            if is_torch_nn_functional_test(test_params_dict):
                functional_impl_check.write_test_to_test_class(
                    TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
            else:
                module_impl_check.write_test_to_test_class(
                    TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
            expected_test_params_dicts.append(test_params_dict)

    def _assert_generated_test_count(name_fragment, expected_count):
        """Check how many attributes of `TestCppApiParity` mention `name_fragment`.

        Counts the names in `TestCppApiParity.__dict__` that contain
        `name_fragment` and asserts that the count equals `expected_count`,
        reporting both numbers on failure.
        """
        actual_count = sum(name_fragment in name for name in TestCppApiParity.__dict__)
        assert actual_count == expected_count, (
            "expected {} attributes on TestCppApiParity with {!r} in their name, found {}".format(
                expected_count, name_fragment, actual_count))

    # The checks below stay inside the ARM64 guard: when generation is skipped
    # no test methods exist, so the non-zero sample counts could not hold.

    # Assert that all NN module/functional test dicts appear in the parity test
    # (one generated test per test dict per device).
    _assert_generated_test_count('test_torch_nn_', len(expected_test_params_dicts) * len(devices))

    # Assert that auto-generated tests exist for `SampleModule` and `sample_functional`.
    # For each: 2 (number of test dicts that are not skipped) * number of devices.
    # Derived from `devices` so the check stays in sync with the device list.
    _assert_generated_test_count('SampleModule', 2 * len(devices))
    _assert_generated_test_count('sample_functional', 2 * len(devices))

    # NOTE(review): the build step is assumed to belong under the same guard as
    # the generation it depends on -- confirm.
    module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
    functional_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
|  |  | 
# When executed as a script, hand control to the `run_tests` entry point of
# `torch.testing._internal.common_utils`.
if __name__ == "__main__":
    common.run_tests()