from functools import partial
from typing import Any, Dict, Type
import torch
from torch.distributed.utils import _apply_to_tensors
from torch.utils._mode_utils import no_dispatch


def _override_module_mixed_precision(
    root: torch.nn.Module,
    module_cls_to_override: Type[torch.nn.Module],
    wrap_override_dict: Dict[str, Any] = {"mixed_precision": None},  # noqa: B006
):
    """
    Sets ``_wrap_overrides = wrap_override_dict`` on every instance of
    ``module_cls_to_override`` under ``root`` (the default override sets
    ``mixed_precision=None`` for those modules when they are wrapped) and
    registers forward pre/post hooks that cast floating-point inputs up to
    fp32 and cast the outputs back to the original input dtype.
    """
for mod in root.modules():
if isinstance(mod, module_cls_to_override):
mod._wrap_overrides = wrap_override_dict # type: ignore[assignment]
            # TODO: We need to run this mixed-precision-ignored module in
            # fp32, but we must ensure that subsequent modules, which may be
            # running with mixed precision, still receive inputs in the
            # appropriate precision without the user having to adjust the
            # mixed precision config too much. As a result, we attach pre- and
            # post-forward hooks to upcast and downcast. We should revisit
            # this design.
old_dtype = None

            def cast_fn(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
if not torch.is_floating_point(x) or x.dtype == dtype:
return x
nonlocal old_dtype
old_dtype = x.dtype
return x.to(dtype)

            def forward_pre_hook(module, args):
return _apply_to_tensors(partial(cast_fn, torch.float32), args)

            def forward_post_hook(module, args, output):
nonlocal old_dtype
if old_dtype is not None:
return _apply_to_tensors(partial(cast_fn, old_dtype), output)
# We intentionally append both of these hooks so that they run after
# all other hooks.
mod.register_forward_pre_hook(forward_pre_hook, prepend=False)
mod.register_forward_hook(forward_post_hook, prepend=False)
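

# Hedged usage sketch (not part of the original module): shows how the helper
# above might be applied to keep a specific module class (BatchNorm1d here,
# chosen purely for illustration) out of mixed precision. The demo only checks
# the side effects that the function above actually performs; the surrounding
# FSDP wrapping is assumed and not shown.
def _demo_override_module_mixed_precision() -> None:
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8),
        torch.nn.BatchNorm1d(8),
        torch.nn.Linear(8, 8),
    )
    # After this call, every BatchNorm1d instance carries ``_wrap_overrides``
    # and has the up/down-cast forward hooks registered.
    _override_module_mixed_precision(model, torch.nn.BatchNorm1d)
    assert all(
        hasattr(m, "_wrap_overrides")
        for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm1d)
    )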


def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Returns if ``x`` and ``y`` share the same storage."""
    # NOTE: CPU and GPU tensors are guaranteed to have different data pointers.
    return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()


def _same_storage_as_data_ptr(x: torch.Tensor, data_ptr: int) -> bool:
    """Returns if ``x``'s storage has the data pointer ``data_ptr``."""
    return x._typed_storage()._data_ptr() == data_ptr
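

# Hedged sketch (illustrative only, not part of the original module): a slice
# view shares its base tensor's storage, so both helpers above report a match.
def _demo_same_storage_checks() -> None:
    base = torch.zeros(8)
    view = base[2:6]
    assert _same_storage(base, view)
    assert _same_storage_as_data_ptr(view, base._typed_storage()._data_ptr())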


def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
    """Calls ``tensor.record_stream(stream)`` without dispatching through tensor subclasses."""
    # FIXME: record_stream does not work with non-CUDA tensors, so this is a
    # no-op for devices without CUDA-like stream semantics.
    if tensor.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
        return
with no_dispatch():
tensor.record_stream(stream)
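

# Hedged sketch (illustrative only, not part of the original module): record a
# tensor against a side CUDA stream so the caching allocator will not reuse
# its memory while that stream still has pending work. Skips silently when
# CUDA is unavailable.
def _demo_no_dispatch_record_stream() -> None:
    if not torch.cuda.is_available():
        return
    side_stream = torch.cuda.Stream()
    tensor = torch.ones(4, device="cuda")
    with torch.cuda.stream(side_stream):
        tensor.mul_(2)
    _no_dispatch_record_stream(tensor, side_stream)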