| from collections import defaultdict, abc as container_abcs |
| |
| import torch |
| from copy import deepcopy |
| from itertools import chain |
| import warnings |
| import functools |
| |
| |
class _RequiredParameter:
    """Sentinel type marking an optimizer option that has no default value.

    A single module-level instance, ``required``, is created right below and
    is compared by identity wherever a mandatory option must be detected.
    """

    def __repr__(self):
        # Fixed, human-readable marker shown wherever the sentinel is printed.
        return "<required parameter>"


# The one shared sentinel instance; use ``value is required`` to test for it.
required = _RequiredParameter()
| |
| |
class Optimizer(object):
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).

    Raises:
        TypeError: if ``params`` is a single :class:`torch.Tensor` rather than
            an iterable.
        ValueError: if ``params`` is empty.
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults

        self._hook_for_profile()

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        # Per-parameter state, created lazily on first access.
        self.state = defaultdict(dict)
        self.param_groups = []

        # Materialize once so that generators can be inspected and iterated.
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            # A flat iterable of tensors becomes one implicit group.
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

    def __getstate__(self):
        """Return the picklable attributes: defaults, state and param_groups."""
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        """Restore attributes from ``state`` and re-install the profiling hooks."""
        self.__dict__.update(state)
        self._hook_for_profile()  # To support multiprocessing pickle/unpickle.

    def __repr__(self):
        """Return the class name followed by every group's non-'params' options."""
        pieces = [self.__class__.__name__ + ' (']
        for i, group in enumerate(self.param_groups):
            pieces.append('\n')
            pieces.append('Parameter Group {0}\n'.format(i))
            for key in sorted(group.keys()):
                if key != 'params':
                    pieces.append(' {0}: {1}\n'.format(key, group[key]))
        pieces.append(')')
        return ''.join(pieces)

    def _hook_for_profile(self):
        """Set the zero_grad profile label and wrap the class's ``step`` once.

        The wrapper runs ``step`` inside a
        ``torch.autograd.profiler.record_function`` context. A ``hooked``
        attribute on the wrapped function prevents wrapping it again.
        """
        self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)

        def profile_hook_step(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # First positional argument is the optimizer instance (self).
                obj, *_ = args
                profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
                with torch.autograd.profiler.record_function(profile_name):
                    return func(*args, **kwargs)
            return wrapper

        hooked = getattr(self.__class__.step, "hooked", None)
        if not hooked:
            self.__class__.step = profile_hook_step(self.__class__.step)
            self.__class__.step.hooked = True

    def state_dict(self):
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a list containing all parameter groups where each
            parameter group is a dict

        Returns:
            dict: ``{'state': ..., 'param_groups': ...}`` in which Tensors used
            as parameters are replaced by integer order indices.
        """
        # Save order indices instead of Tensors
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != 'params'}
            # Only parameters not seen before receive a new index.
            param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                   if id(p) not in param_mappings})
            packed['params'] = [param_mappings[id(p)] for p in group['params']]
            start_index += len(packed['params'])
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: if the number of parameter groups, or the number of
                parameters in any group, differs from this optimizer's.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        # Pair each saved parameter index with the live parameter that sits
        # at the same position in this optimizer's groups.
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        def cast(param, value):
            r"""Recursively move tensors in value to param's device (and dtype
            when param is floating point), rebuilding dicts and iterables."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
                return value
            elif isinstance(value, (str, bytes)):
                # Strings are Iterable, but rebuilding one with
                # ``str(<generator>)`` yields the generator's repr instead of
                # the text, so they must be treated as leaves.
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            # Keep the loaded options but point them at the live tensors.
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def zero_grad(self, set_to_none: bool = False):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        # The label is not part of __getstate__, so it may be missing here.
        if not hasattr(self, "_zero_grad_profile_name"):
            self._hook_for_profile()
        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            # Cut the gradient loose from any graph before
                            # zeroing it in place.
                            if p.grad.grad_fn is not None:
                                p.grad.detach_()
                            else:
                                p.grad.requires_grad_(False)
                            p.grad.zero_()

    def step(self, closure):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        Raises:
            NotImplementedError: always; subclasses provide the update rule.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        raise NotImplementedError

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        The given dict is updated in place (its ``'params'`` entry becomes a
        list and missing options are filled from the defaults) and is then
        appended to ``self.param_groups``.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.

        Raises:
            TypeError: if ``'params'`` is a set, or contains a non-Tensor.
            ValueError: if a parameter is not a leaf Tensor, a required option
                is missing, or a parameter already belongs to another group.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        # Fill in options the group did not specify; a default equal to the
        # ``required`` sentinel means the group must provide the value itself.
        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)

        params = param_group['params']
        if len(params) != len(set(params)):
            warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                          "in future, this will cause an error; "
                          "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

        # A parameter may belong to at most one group.
        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)