| # Owner(s): ["module: functorch"] |
| |
| import functools |
| from torch.testing._internal.common_utils import run_tests, skipIfRocm |
| import test_aotdispatch |
| |
| |
def make_functionalize_fn(fn):
    """Return a thin wrapper around ``fn`` that forwards every argument.

    The wrapper copies ``fn``'s metadata (``__name__``, ``__qualname__``,
    ``__doc__``, ``__dict__``, ...) and exposes the original callable as
    ``__wrapped__``. It adds no behavior of its own and returns whatever
    ``fn`` returns.

    NOTE(review): presumably kept as the hook where functionalization-specific
    setup would be applied around each test — confirm before removing.
    """

    def forward(*args, **kwargs):
        return fn(*args, **kwargs)

    functools.update_wrapper(forward, fn)
    return forward
| |
| |
def make_functionalize_test(cls):
    """Build a ``TestCase`` subclass that re-runs ``cls``'s tests under new names.

    Every callable ``test_*`` attribute of ``cls`` (including inherited ones)
    is wrapped with ``make_functionalize_fn`` and registered on the subclass
    as ``<name>_functionalize``. The original name is set to ``None`` on the
    subclass so the test loader does not run the unmodified test a second time.

    Args:
        cls: the ``TestCase`` class whose tests should be mirrored.

    Returns:
        A new subclass of ``cls`` named ``Functionalize<cls.__name__>``.
    """

    class FunctionalizeTest(cls):
        pass

    class_name = f"Functionalize{cls.__name__}"
    FunctionalizeTest.__name__ = class_name
    # Keep __qualname__ in sync with __name__: unittest builds test ids and
    # failure reports from __module__ + __qualname__, which would otherwise
    # read "make_functionalize_test.<locals>.FunctionalizeTest".
    FunctionalizeTest.__qualname__ = class_name

    for name in dir(cls):
        if not name.startswith("test_"):
            continue
        fn = getattr(cls, name)
        if not callable(fn):
            continue

        new_name = f"{name}_functionalize"
        # https://github.com/pytorch/pytorch/issues/96560
        # Skip on ROCm per generated method instead of decorating the class:
        # skipIfRocm wraps its argument in a plain function, so applying it to
        # the class would bind the module-level name to a function rather than
        # a TestCase subclass, and the test loader would collect nothing.
        fn = skipIfRocm(make_functionalize_fn(fn))
        fn.__name__ = new_name
        fn.__qualname__ = f"{class_name}.{new_name}"
        # Hide the inherited, non-functionalized variant; the unittest loader
        # ignores non-callable attributes.
        setattr(FunctionalizeTest, name, None)
        setattr(FunctionalizeTest, new_name, fn)

    return FunctionalizeTest
| |
| |
# Mirror the AOT Autograd and partitioning suites from test_aotdispatch; each
# generated class re-registers the source tests under "<name>_functionalize".
FunctionalizeTestPythonKeyAOT = make_functionalize_test(
    test_aotdispatch.TestAOTAutograd
)
FunctionalizeTestPythonKeyPartitioning = make_functionalize_test(
    test_aotdispatch.TestPartitioning
)

if __name__ == "__main__":
    run_tests()