|  | from collections import defaultdict, abc as container_abcs | 
|  | import torch | 
|  | from copy import deepcopy | 
|  | from itertools import chain | 
|  | import warnings | 
|  | import functools | 
|  |  | 
|  | __all__ = ['Optimizer'] | 
|  |  | 
class _RequiredParameter:
    """Marker type for optimizer options that have no usable default.

    Its only instance is the module-level ``required`` object below; code
    compares against it by identity (``default is required``).
    """

    def __repr__(self):
        # Fixed, human-readable text so the marker is recognisable when a
        # defaults dict is printed.
        return "<required parameter>"


# Shared sentinel: put this in an optimizer's ``defaults`` for every option
# that each parameter group must supply explicitly.
required = _RequiredParameter()
|  |  | 
|  |  | 
|  | def _use_grad_for_differentiable(func): | 
|  | def _use_grad(self, *args, **kwargs): | 
|  | prev_grad = torch.is_grad_enabled() | 
|  | try: | 
|  | torch.set_grad_enabled(self.defaults['differentiable']) | 
|  | ret = func(self, *args, **kwargs) | 
|  | finally: | 
|  | torch.set_grad_enabled(prev_grad) | 
|  | return ret | 
|  | return _use_grad | 
|  |  | 
|  |  | 
|  | class Optimizer(object): | 
|  | r"""Base class for all optimizers. | 
|  |  | 
|  | .. warning:: | 
|  | Parameters need to be specified as collections that have a deterministic | 
|  | ordering that is consistent between runs. Examples of objects that don't | 
|  | satisfy those properties are sets and iterators over values of dictionaries. | 
|  |  | 
|  | Args: | 
|  | params (iterable): an iterable of :class:`torch.Tensor` s or | 
|  | :class:`dict` s. Specifies what Tensors should be optimized. | 
|  | defaults: (dict): a dict containing default values of optimization | 
|  | options (used when a parameter group doesn't specify them). | 
|  | """ | 
|  |  | 
|  | def __init__(self, params, defaults): | 
|  | torch._C._log_api_usage_once("python.optimizer") | 
|  | self.defaults = defaults | 
|  |  | 
|  | self._hook_for_profile() | 
|  |  | 
|  | if isinstance(params, torch.Tensor): | 
|  | raise TypeError("params argument given to the optimizer should be " | 
|  | "an iterable of Tensors or dicts, but got " + | 
|  | torch.typename(params)) | 
|  |  | 
|  | self.state = defaultdict(dict) | 
|  | self.param_groups = [] | 
|  |  | 
|  | param_groups = list(params) | 
|  | if len(param_groups) == 0: | 
|  | raise ValueError("optimizer got an empty parameter list") | 
|  | if not isinstance(param_groups[0], dict): | 
|  | param_groups = [{'params': param_groups}] | 
|  |  | 
|  | for param_group in param_groups: | 
|  | self.add_param_group(param_group) | 
|  |  | 
|  | # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python, | 
|  | # which I don't think exists | 
|  | # https://github.com/pytorch/pytorch/issues/72948 | 
|  | self._warned_capturable_if_run_uncaptured = True | 
|  |  | 
|  |  | 
|  | def __getstate__(self): | 
|  | return { | 
|  | 'defaults': self.defaults, | 
|  | 'state': self.state, | 
|  | 'param_groups': self.param_groups, | 
|  | } | 
|  |  | 
|  | def __setstate__(self, state): | 
|  | self.__dict__.update(state) | 
|  | self._hook_for_profile()  # To support multiprocessing pickle/unpickle. | 
|  | self.defaults.setdefault('differentiable', False) | 
|  |  | 
|  | def __repr__(self): | 
|  | format_string = self.__class__.__name__ + ' (' | 
|  | for i, group in enumerate(self.param_groups): | 
|  | format_string += '\n' | 
|  | format_string += 'Parameter Group {0}\n'.format(i) | 
|  | for key in sorted(group.keys()): | 
|  | if key != 'params': | 
|  | format_string += '    {0}: {1}\n'.format(key, group[key]) | 
|  | format_string += ')' | 
|  | return format_string | 
|  |  | 
|  | # Currently needed by Adam and AdamW | 
|  | def _cuda_graph_capture_health_check(self): | 
|  | if torch.has_cuda and torch.cuda.is_available(): | 
|  | capturing = torch.cuda.is_current_stream_capturing() | 
|  |  | 
|  | if capturing and not self.defaults['capturable']: | 
|  | raise RuntimeError("Attempting CUDA graph capture of step() for an instance of " + | 
|  | self.__class__.__name__ + | 
|  | " but this instance was constructed with capturable=False.") | 
|  |  | 
|  | if ( | 
|  | (not getattr(self, "_warned_capturable_if_run_uncaptured", False)) | 
|  | and self.defaults["capturable"] | 
|  | and (not capturing) | 
|  | ): | 
|  | print("Warning: This instance was constructed with capturable=True, but step() " + | 
|  | "is running without CUDA graph capture. If you never intend to graph-capture this " + | 
|  | "instance, capturable=True can impair performance, and you should set capturable=False.") | 
|  | self._warned_capturable_if_run_uncaptured = True | 
|  |  | 
|  | def _hook_for_profile(self): | 
|  | self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__) | 
|  |  | 
|  | def profile_hook_step(func): | 
|  |  | 
|  | @functools.wraps(func) | 
|  | def wrapper(*args, **kwargs): | 
|  | obj, *_ = args | 
|  | profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__) | 
|  | with torch.autograd.profiler.record_function(profile_name): | 
|  | return func(*args, **kwargs) | 
|  | return wrapper | 
|  |  | 
|  | hooked = getattr(self.__class__.step, "hooked", None) | 
|  | if not hooked: | 
|  | self.__class__.step = profile_hook_step(self.__class__.step) | 
|  | self.__class__.step.hooked = True | 
|  |  | 
|  | def state_dict(self): | 
|  | r"""Returns the state of the optimizer as a :class:`dict`. | 
|  |  | 
|  | It contains two entries: | 
|  |  | 
|  | * state - a dict holding current optimization state. Its content | 
|  | differs between optimizer classes. | 
|  | * param_groups - a list containing all parameter groups where each | 
|  | parameter group is a dict | 
|  | """ | 
|  | # Save order indices instead of Tensors | 
|  | param_mappings = {} | 
|  | start_index = 0 | 
|  |  | 
|  | def pack_group(group): | 
|  | nonlocal start_index | 
|  | packed = {k: v for k, v in group.items() if k != 'params'} | 
|  | param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index) | 
|  | if id(p) not in param_mappings}) | 
|  | packed['params'] = [param_mappings[id(p)] for p in group['params']] | 
|  | start_index += len(packed['params']) | 
|  | return packed | 
|  | param_groups = [pack_group(g) for g in self.param_groups] | 
|  | # Remap state to use order indices as keys | 
|  | packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v | 
|  | for k, v in self.state.items()} | 
|  | return { | 
|  | 'state': packed_state, | 
|  | 'param_groups': param_groups, | 
|  | } | 
|  |  | 
|  | def load_state_dict(self, state_dict): | 
|  | r"""Loads the optimizer state. | 
|  |  | 
|  | Args: | 
|  | state_dict (dict): optimizer state. Should be an object returned | 
|  | from a call to :meth:`state_dict`. | 
|  | """ | 
|  | # deepcopy, to be consistent with module API | 
|  | state_dict = deepcopy(state_dict) | 
|  | # Validate the state_dict | 
|  | groups = self.param_groups | 
|  | saved_groups = state_dict['param_groups'] | 
|  |  | 
|  | if len(groups) != len(saved_groups): | 
|  | raise ValueError("loaded state dict has a different number of " | 
|  | "parameter groups") | 
|  | param_lens = (len(g['params']) for g in groups) | 
|  | saved_lens = (len(g['params']) for g in saved_groups) | 
|  | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): | 
|  | raise ValueError("loaded state dict contains a parameter group " | 
|  | "that doesn't match the size of optimizer's group") | 
|  |  | 
|  | # Update the state | 
|  | id_map = {old_id: p for old_id, p in | 
|  | zip(chain.from_iterable((g['params'] for g in saved_groups)), | 
|  | chain.from_iterable((g['params'] for g in groups)))} | 
|  |  | 
|  | def cast(param, value, key=None): | 
|  | r"""Make a deep copy of value, casting all tensors to device of param.""" | 
|  | if isinstance(value, torch.Tensor): | 
|  | # Floating-point types are a bit special here. They are the only ones | 
|  | # that are assumed to always match the type of params. | 
|  | # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 | 
|  | if (key != "step"): | 
|  | if param.is_floating_point(): | 
|  | value = value.to(param.dtype) | 
|  | value = value.to(param.device) | 
|  | return value | 
|  | elif isinstance(value, dict): | 
|  | return {k: cast(param, v, key=k) for k, v in value.items()} | 
|  | elif isinstance(value, container_abcs.Iterable): | 
|  | return type(value)(cast(param, v) for v in value) | 
|  | else: | 
|  | return value | 
|  |  | 
|  | # Copy state assigned to params (and cast tensors to appropriate types). | 
|  | # State that is not assigned to params is copied as is (needed for | 
|  | # backward compatibility). | 
|  | state = defaultdict(dict) | 
|  | for k, v in state_dict['state'].items(): | 
|  | if k in id_map: | 
|  | param = id_map[k] | 
|  | state[param] = cast(param, v) | 
|  | else: | 
|  | state[k] = v | 
|  |  | 
|  | # Update parameter groups, setting their 'params' value | 
|  | def update_group(group, new_group): | 
|  | new_group['params'] = group['params'] | 
|  | return new_group | 
|  | param_groups = [ | 
|  | update_group(g, ng) for g, ng in zip(groups, saved_groups)] | 
|  | self.__setstate__({'state': state, 'param_groups': param_groups}) | 
|  |  | 
|  | def zero_grad(self, set_to_none: bool = False): | 
|  | r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero. | 
|  |  | 
|  | Args: | 
|  | set_to_none (bool): instead of setting to zero, set the grads to None. | 
|  | This will in general have lower memory footprint, and can modestly improve performance. | 
|  | However, it changes certain behaviors. For example: | 
|  | 1. When the user tries to access a gradient and perform manual ops on it, | 
|  | a None attribute or a Tensor full of 0s will behave differently. | 
|  | 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s | 
|  | are guaranteed to be None for params that did not receive a gradient. | 
|  | 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None | 
|  | (in one case it does the step with a gradient of 0 and in the other it skips | 
|  | the step altogether). | 
|  | """ | 
|  | foreach = self.defaults.get('foreach', False) | 
|  |  | 
|  | if not hasattr(self, "_zero_grad_profile_name"): | 
|  | self._hook_for_profile() | 
|  | if foreach: | 
|  | per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) | 
|  | with torch.autograd.profiler.record_function(self._zero_grad_profile_name): | 
|  | for group in self.param_groups: | 
|  | for p in group['params']: | 
|  | if p.grad is not None: | 
|  | if set_to_none: | 
|  | p.grad = None | 
|  | else: | 
|  | if p.grad.grad_fn is not None: | 
|  | p.grad.detach_() | 
|  | else: | 
|  | p.grad.requires_grad_(False) | 
|  | if (not foreach or p.grad.is_sparse): | 
|  | p.grad.zero_() | 
|  | else: | 
|  | per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) | 
|  | if foreach: | 
|  | for _, per_dtype_grads in per_device_and_dtype_grads.items(): | 
|  | for grads in per_dtype_grads.values(): | 
|  | torch._foreach_zero_(grads) | 
|  |  | 
|  | def step(self, closure): | 
|  | r"""Performs a single optimization step (parameter update). | 
|  |  | 
|  | Args: | 
|  | closure (Callable): A closure that reevaluates the model and | 
|  | returns the loss. Optional for most optimizers. | 
|  |  | 
|  | .. note:: | 
|  | Unless otherwise specified, this function should not modify the | 
|  | ``.grad`` field of the parameters. | 
|  | """ | 
|  | raise NotImplementedError | 
|  |  | 
|  | def add_param_group(self, param_group): | 
|  | r"""Add a param group to the :class:`Optimizer` s `param_groups`. | 
|  |  | 
|  | This can be useful when fine tuning a pre-trained network as frozen layers can be made | 
|  | trainable and added to the :class:`Optimizer` as training progresses. | 
|  |  | 
|  | Args: | 
|  | param_group (dict): Specifies what Tensors should be optimized along with group | 
|  | specific optimization options. | 
|  | """ | 
|  | assert isinstance(param_group, dict), "param group must be a dict" | 
|  |  | 
|  | params = param_group['params'] | 
|  | if isinstance(params, torch.Tensor): | 
|  | param_group['params'] = [params] | 
|  | elif isinstance(params, set): | 
|  | raise TypeError('optimizer parameters need to be organized in ordered collections, but ' | 
|  | 'the ordering of tensors in sets will change between runs. Please use a list instead.') | 
|  | else: | 
|  | param_group['params'] = list(params) | 
|  |  | 
|  | for param in param_group['params']: | 
|  | if not isinstance(param, torch.Tensor): | 
|  | raise TypeError("optimizer can only optimize Tensors, " | 
|  | "but one of the params is " + torch.typename(param)) | 
|  | if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad): | 
|  | raise ValueError("can't optimize a non-leaf Tensor") | 
|  |  | 
|  | for name, default in self.defaults.items(): | 
|  | if default is required and name not in param_group: | 
|  | raise ValueError("parameter group didn't specify a value of required optimization parameter " + | 
|  | name) | 
|  | else: | 
|  | param_group.setdefault(name, default) | 
|  |  | 
|  | params = param_group['params'] | 
|  | if len(params) != len(set(params)): | 
|  | warnings.warn("optimizer contains a parameter group with duplicate parameters; " | 
|  | "in future, this will cause an error; " | 
|  | "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3) | 
|  |  | 
|  | param_set = set() | 
|  | for group in self.param_groups: | 
|  | param_set.update(set(group['params'])) | 
|  |  | 
|  | if not param_set.isdisjoint(set(param_group['params'])): | 
|  | raise ValueError("some parameters appear in more than one parameter group") | 
|  |  | 
|  | self.param_groups.append(param_group) |