# blob: 32154d28a184d66fffe651833bfa28ee98b1a536 [file] [log] [blame]
# Owner(s): ["module: functorch"]
import functools
from torch.testing._internal.common_utils import run_tests, skipIfRocm
import test_aotdispatch
def make_functionalize_fn(fn):
    """Return a pass-through wrapper around ``fn``.

    The wrapper forwards every positional and keyword argument to ``fn``
    unchanged and returns whatever ``fn`` returns. ``functools.wraps`` copies
    ``fn``'s metadata (name, docstring, attribute dict) onto the wrapper and
    records ``fn`` as ``__wrapped__``.

    Because a new function object is returned, callers can rename or
    re-attach it without touching ``fn`` itself.
    """

    def _fn(*call_args, **call_kwargs):
        # Plain delegation: the wrapper adds no behavior of its own.
        return fn(*call_args, **call_kwargs)

    # Equivalent to decorating ``_fn`` with ``@functools.wraps(fn)``.
    return functools.wraps(fn)(_fn)
def make_functionalize_test(cls):
    """Build a ``Functionalize``-prefixed variant of the test class ``cls``.

    A subclass of ``cls`` is created and its ``__name__`` is set to
    ``"Functionalize" + cls.__name__``. Then, for every attribute of ``cls``
    (as listed by ``dir``) whose name starts with ``"test_"`` and whose value
    is callable:

    * the attribute is wrapped with ``make_functionalize_fn`` and installed
      on the subclass as ``<name>_functionalize``;
    * the original ``<name>`` is set to ``None`` on the subclass --
      presumably so the inherited test is not also collected under its
      original name (confirm against the test runner in use).

    Returns:
        The result of applying ``skipIfRocm`` to the new subclass.
        NOTE(review): confirm that ``skipIfRocm`` supports decorating a
        class rather than a single test function.
    """

    class FunctionalizeTest(cls):
        pass

    FunctionalizeTest.__name__ = f"Functionalize{cls.__name__}"

    for attr_name in dir(cls):
        # Only attributes that follow the "test_" naming convention matter.
        if not attr_name.startswith("test_"):
            continue

        original = getattr(cls, attr_name)
        if not callable(original):
            continue

        renamed = f"{attr_name}_functionalize"
        wrapped = make_functionalize_fn(original)
        wrapped.__name__ = renamed

        # Hide the inherited attribute, then expose the wrapped copy under
        # the suffixed name.
        setattr(FunctionalizeTest, attr_name, None)
        setattr(FunctionalizeTest, renamed, wrapped)

    # ROCm skip tracked in https://github.com/pytorch/pytorch/issues/96560
    return skipIfRocm(FunctionalizeTest)
# Module-level bindings for the functionalize variants of the AOT dispatch
# test classes, generated by make_functionalize_test.
FunctionalizeTestPythonKeyAOT = make_functionalize_test(
    test_aotdispatch.TestAOTAutograd
)
FunctionalizeTestPythonKeyPartitioning = make_functionalize_test(
    test_aotdispatch.TestPartitioning
)

if __name__ == "__main__":
    run_tests()