from typing import cast

import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils._mode_utils import no_dispatch


def _contains_batchnorm(module):
    """Returns whether ``module`` has any BatchNorm submodule (including itself)."""
    return any(isinstance(mod, _BatchNorm) for mod in module.modules())

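# A minimal usage sketch (assuming standard ``torch.nn`` layers): any
# BatchNorm anywhere in the module tree is detected.
#
#   >>> import torch.nn as nn
#   >>> _contains_batchnorm(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
#   True
#   >>> _contains_batchnorm(nn.Sequential(nn.Linear(4, 4)))
#   False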

def _override_batchnorm_mixed_precision(module):
    """Marks each BatchNorm submodule of ``module`` with a wrapping override
    that disables mixed precision for it, so that it is wrapped to run in
    full precision."""
    for mod in module.modules():
        if isinstance(mod, _BatchNorm):
            mod._wrap_overrides = {"mixed_precision": None}  # type: ignore[assignment]

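# A minimal sketch of the effect (assuming standard ``torch.nn`` layers): only
# the BatchNorm submodule gains the override attribute, which wrapping logic
# can consult to wrap it with ``mixed_precision=None`` (full precision).
#
#   >>> import torch.nn as nn
#   >>> model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
#   >>> _override_batchnorm_mixed_precision(model)
#   >>> model[1]._wrap_overrides
#   {'mixed_precision': None}
#   >>> hasattr(model[0], "_wrap_overrides")
#   False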

def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Returns whether ``x`` and ``y`` share the same storage."""
    # NOTE: CPU and GPU tensors are ensured to have different data pointers.
    return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()

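# A minimal sketch: a view shares its base tensor's storage, while a clone
# gets a fresh allocation.
#
#   >>> t = torch.randn(4)
#   >>> _same_storage(t, t[1:])
#   True
#   >>> _same_storage(t, t.clone())
#   False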

def _same_storage_as_data_ptr(x: torch.Tensor, data_ptr: int) -> bool:
    """Returns whether ``x``'s storage begins at the address ``data_ptr``."""
    return x._typed_storage()._data_ptr() == data_ptr

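# A minimal sketch: the raw pointer taken from a tensor's storage matches any
# view over that same storage but not an independent copy.
#
#   >>> t = torch.randn(4)
#   >>> ptr = t._typed_storage()._data_ptr()
#   >>> _same_storage_as_data_ptr(t[1:], ptr)
#   True
#   >>> _same_storage_as_data_ptr(t.clone(), ptr)
#   False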

def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
    """Calls ``Tensor.record_stream`` without dispatching through any tensor
    subclass; the ``cast`` only satisfies the type checker, since
    ``record_stream`` is annotated as taking a ``torch._C.Stream``."""
    with no_dispatch():
        tensor.record_stream(cast(torch._C.Stream, stream))
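# A minimal usage sketch (requires a CUDA device; the tensor names here are
# illustrative): after consuming ``buf`` on a side stream, record that use so
# the caching allocator does not hand ``buf``'s memory to another tensor
# before the side stream's work has finished.
#
#   >>> if torch.cuda.is_available():
#   ...     buf = torch.ones(4, device="cuda")
#   ...     side = torch.cuda.Stream()
#   ...     with torch.cuda.stream(side):
#   ...         out = buf * 2  # ``buf`` is consumed on ``side``
#   ...     _no_dispatch_record_stream(buf, side)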