import torch
import functools
import warnings

from typing import Any, Optional
from torch.types import _dtype

def autocast_decorator(autocast_instance, func):
    @functools.wraps(func)
    def decorate_autocast(*args, **kwargs):
        with autocast_instance:
            return func(*args, **kwargs)
    decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in script mode'  # type: ignore[attr-defined]
    return decorate_autocast
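
# Note: as a rough sketch of what `autocast_decorator` produces, decorating a function
# with an autocast instance behaves like entering that instance as a context manager
# around every call (the names below are illustrative only):
#
#     @torch.autocast(device_type="cuda")
#     def run_forward(model, x):
#         return model(x)
#
#     # ...behaves like:
#     def run_forward(model, x):
#         with torch.autocast(device_type="cuda"):
#             return model(x)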

class autocast(object):
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
    computation(s).  Backward passes under autocast are not recommended.
    Backward ops run in the same type that autocast used for corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with autocast(device_type="cuda"):
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

        class AutocastModel(nn.Module):
            ...
            @autocast(device_type="cuda")
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
    After returning to an autocast-disabled region, using them with floating-point
    Tensors of different dtypes may cause type mismatch errors.  If so, cast the Tensor(s)
    produced in the autocast region back to ``float32`` (or other dtype if desired).
    If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
    and incurs no additional overhead.

    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with autocast(device_type="cuda"):
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()

    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super(TestModel, self).__init__()
                self.fc1 = nn.Linear(input_size, num_classes)
            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest disabling the JIT autocast pass
        # because of https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.cpu.amp.autocast(cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)
        # Run the traced and frozen model
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
    please file an issue.

    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
    Locally disabling autocast can be useful, for example, if you want to force a subregion
    to run in a particular ``dtype``.  Disabling autocast gives you explicit control over
    the execution type.  In the subregion, inputs from the surrounding region
    should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local.  If you want it enabled in a new thread, the context manager or decorator
    must be invoked in that thread.  This affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).

    Args:
        device_type(str, required):  Device type to use. Possible values are 'cuda', 'cpu', and 'xpu'.
        enabled(bool, optional, default=True):  Whether autocasting should be enabled in the region.
        dtype(torch_dtype, optional):  Data type for ops to run in, e.g. ``torch.float16`` or ``torch.bfloat16``.
            Defaults to the current autocast dtype for ``device_type``.
        cache_enabled(bool, optional, default=True):  Whether the weight cache inside autocast should be enabled.
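
    For example, all arguments can be passed explicitly (a sketch; ``model`` and ``input``
    are assumed to be defined, and ``torch.float16`` is the usual choice for CUDA)::

        with torch.autocast(device_type="cuda", dtype=torch.float16, cache_enabled=True):
            output = model(input)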
|  | """ | 
    def __init__(self, device_type : str,
                 dtype : Optional[_dtype] = None,
                 enabled : bool = True,
                 cache_enabled : Optional[bool] = None):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            # TODO: support get_autocast_gpu/cpu_dtype
            assert dtype is not None
            return
        self.device = device_type
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
            self.fast_dtype = torch.get_autocast_cpu_dtype()
        elif self.device == 'xpu':
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        else:
            raise RuntimeError('User specified autocast device_type must be \'cuda\', \'cpu\', or \'xpu\'')
        self._cache_enabled = torch.is_autocast_cache_enabled()
        if enabled and torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
            warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        if self.device == 'cpu':
            supported_dtype = [torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
                warnings.warn(error_message)
                enabled = False
        if self.device == 'xpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        if self.device == 'cuda':
            if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
        self._enabled = enabled

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()    # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        elif self.device == 'xpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.xpu.set_autocast_xpu_enabled(self.prev)            # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)    # type: ignore[attr-defined]
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)