| from torch.testing._internal.common_utils import ( |
| TestCase as TorchTestCase, |
| from . import config, reset, utils |
| log = logging.getLogger(__name__) |
| from torch.testing._internal.common_utils import run_tests |
| or sys.version_info >= (3, 12) |
| if isinstance(needs, str): |
| if need == "cuda" and not torch.cuda.is_available(): |
| importlib.import_module(need) |
| class TestCase(TorchTestCase): |
| cls._exit_stack = contextlib.ExitStack() |
| cls._exit_stack.enter_context( |
| raise_on_ctx_manager_usage=True, |
| log_compilation_metrics=False, |
| self._prior_is_grad_enabled = torch.is_grad_enabled() |
| for k, v in utils.counters.items(): |
| print(k, v.most_common()) |
| if self._prior_is_grad_enabled is not torch.is_grad_enabled(): |
| log.warning("Running test changed grad mode") |
| torch.set_grad_enabled(self._prior_is_grad_enabled) |