| # Owner(s): ["module: nvfuser"] | |
| import torch | |
| from torch.testing._internal.common_utils import set_default_dtype | |
# Pull in the real nvfuser TorchScript test-suite when it is importable;
# otherwise fall back to a stub so this file can still be executed.
try:
    from _nvfuser.test_torchscript import *  # noqa: F403,F401
except ImportError:
    def run_tests():
        """No-op stand-in used when the nvfuser test module cannot be imported."""
if __name__ == "__main__":
    # TODO: Update nvfuser to work with float default dtype
    # Until that is done, the suite is launched with double as the default dtype.
    double_default = set_default_dtype(torch.double)
    with double_default:
        run_tests()