blob: bb04e2dde3aa3a88ad737064f466e5f766583451 [file] [log] [blame]
from typing import List, Optional
import torch.distributed.rpc as rpc
import torch.optim as optim
import torch.jit as jit
from torch import Tensor
from torch.distributed.rpc import RRef
from .functional_adagrad import _FunctionalAdagrad
import torch.distributed.autograd as dist_autograd
from collections import defaultdict
from threading import Lock
# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# in a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for Optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we have added TorchScript
# class reference semantics
# TorchScript interface describing the single method a scripted local
# optimizer exposes. It is used as the type hint when a scripted optimizer
# is wrapped in an RRef and as the RRef element type accepted by
# _script_local_optimizer_step.
@jit.interface
class _ScriptLocalOptimizerInterface(object):
    # Signature-only declaration: implementations take the id of a
    # distributed autograd context and return nothing. The body is
    # deliberately just ``pass``.
    def step(self, autograd_ctx_id: int) -> None:
        pass
class _ScriptLocalOptimizer(jit.ScriptModule):
    """ScriptModule that drives a functional optimizer over local parameters.

    Args:
        optim_cls: callable used to build the wrapped optimizer; it is
            invoked as ``optim_cls(local_params, *args, **kwargs)``.
        local_params_rref: iterable of RRefs; ``local_value()`` of each one
            is taken as a parameter to optimize.
        args: extra positional arguments forwarded to ``optim_cls``.
        kwargs: extra keyword arguments forwarded to ``optim_cls``.
    """

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        # Resolve every RRef to the value it holds on this worker.
        self._local_params = [
            param_rref.local_value() for param_rref in local_params_rref
        ]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    # Runs one optimizer step with the gradients recorded in the distributed
    # autograd context ``autograd_ctx_id``. Documented with comments rather
    # than a docstring to keep the TorchScript-compiled body minimal.
    @jit.script_method
    def step(self, autograd_ctx_id: int):
        grads_by_param = dist_autograd.get_gradients(autograd_ctx_id)
        # The functional optimizer receives one entry per local parameter,
        # in parameter order; parameters with no recorded gradient get None.
        grads: List[Optional[Tensor]] = [
            grads_by_param[param] if param in grads_by_param else None
            for param in self._local_params
        ]
        self.optim.step(grads)
class _LocalOptimizer(object):
# Ideally we would only need to share a lock for instances of
# _LocalOptimizer that deal with the same parameters. We are
# making a simplifying assumption here that if there is more
# than one instance of _LocalOptimizer per worker, they will
# be optimizing the same parameters (e.g. each data parallel
# trainer will create its own instance of _LocalOptimizer but
# they will all optimize the same parameters on each worker)
global_lock = Lock()
def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
self._local_params = [rref.local_value() for rref in local_params_rref]
self.optim = optim_cls(
self._local_params,
*args,
**kwargs)
def step(self, autograd_ctx_id):
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
with _LocalOptimizer.global_lock:
for param, grad in all_local_grads.items():
param.grad = grad
self.optim.step()
def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    """Build a ``_LocalOptimizer`` and return it wrapped in an RRef.

    All arguments are forwarded unchanged to the ``_LocalOptimizer``
    constructor.
    """
    local_optimizer = _LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
    return rpc.RRef(local_optimizer)
def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
local_optim = local_optim_rref.local_value()
local_optim.step(autograd_ctx_id)
# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    """Build a ``_ScriptLocalOptimizer`` and return it wrapped in an RRef.

    All arguments are forwarded unchanged to the ``_ScriptLocalOptimizer``
    constructor. The RRef is created with ``_ScriptLocalOptimizerInterface``
    as its type hint.
    """
    script_optimizer = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
    return rpc.RRef(script_optimizer, _ScriptLocalOptimizerInterface)
# TorchScript counterpart of _local_optimizer_step: fetches the optimizer held
# by the RRef and calls its ``step`` with the given autograd context id.
# Documented with comments rather than a docstring to keep the scripted body
# minimal.
@jit.script
def _script_local_optimizer_step(
    local_optim_rref: RRef[_ScriptLocalOptimizerInterface],
    autograd_ctx_id: int
) -> None:
    script_optim = local_optim_rref.local_value()
    script_optim.step(autograd_ctx_id)
def _wait_for_all(rpc_futs):
# TODO: improve error propagation
exception = None
results = []
for fut in rpc_futs:
try:
results.append(fut.wait())
except Exception as e:
results.append(e)
exception = e
if exception is not None:
raise exception
return results
class DistributedOptimizer:
    """
    DistributedOptimizer takes remote references to parameters scattered
    across workers and applies the given optimizer locally for each parameter.

    Gradients for the parameters are retrieved through
    :meth:`~torch.distributed.autograd.get_gradients`.

    Concurrent calls to
    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
    whether issued by the same client or by different clients, are
    serialized on each worker, because a worker's optimizer can only work
    on one set of gradients at a time. There is, however, no guarantee that
    the full forward-backward-optimizer sequence executes for one client
    at a time, so the gradients being applied may not correspond to the
    latest forward pass executed on a given worker. There is also no
    guaranteed ordering across workers.

    Args:
        optimizer_class (optim.Optimizer): the class of optimizer to
            instantiate on each worker.
        params_rref (list[RRef]): list of RRefs to local or remote parameters
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: arguments to pass to the optimizer constructor on each worker.

    Example::

        >>> import torch.distributed.autograd as dist_autograd
        >>> import torch.distributed.rpc as rpc
        >>> from torch import optim
        >>> from torch.distributed.optim import DistributedOptimizer
        >>>
        >>> with dist_autograd.context() as context_id:
        >>>   # Forward pass.
        >>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>>   loss = rref1.to_here() + rref2.to_here()
        >>>
        >>>   # Backward pass.
        >>>   dist_autograd.backward(context_id, [loss.sum()])
        >>>
        >>>   # Optimizer.
        >>>   dist_optim = DistributedOptimizer(
        >>>      optim.SGD,
        >>>      [rref1, rref2],
        >>>      lr=0.05,
        >>>   )
        >>>   dist_optim.step(context_id)
    """

    # Maps a user-supplied optimizer class to the functional optimizer
    # class defined inside the distributed.optim package, when one exists.
    # This keeps the functional optimizer hidden from the user while
    # still providing the same API.
    functional_optim_map = {
        optim.Adagrad: _FunctionalAdagrad,
    }

    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        # Group the parameter RRefs by the worker that owns them so that one
        # local optimizer is created per worker.
        params_by_worker = defaultdict(list)
        for param_rref in params_rref:
            params_by_worker[param_rref.owner()].append(param_rref)

        # Substitute the functional implementation when one is registered;
        # otherwise the user's class is used as-is.
        optim_ctor = DistributedOptimizer.functional_optim_map.get(
            optimizer_class, optimizer_class
        )
        self.is_functional_optim = optim_ctor != optimizer_class
        new_optimizer_func = (
            _new_script_local_optimizer
            if self.is_functional_optim
            else _new_local_optimizer
        )

        # Start the optimizer construction on every owning worker, then block
        # until all of them have returned the RRef to their optimizer.
        pending_optimizers = [
            rpc.rpc_async(
                worker,
                new_optimizer_func,
                args=(optim_ctor, worker_params) + args,
                kwargs=kwargs,
            )
            for worker, worker_params in params_by_worker.items()
        ]
        self.remote_optimizers = _wait_for_all(pending_optimizers)

    def step(self, context_id):
        """
        Performs a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on each worker
        containing parameters to be optimized, and will block until all workers
        return. The provided ``context_id`` will be used to retrieve the
        corresponding :class:`~torch.distributed.autograd.context` that
        contains the gradients that should be applied to the parameters.

        Args:
            context_id: the autograd context id for which we should run the
                optimizer step.
        """
        dist_autograd._is_valid_context(context_id)

        step_func = (
            _script_local_optimizer_step
            if self.is_functional_optim
            else _local_optimizer_step
        )

        # Fan the step out to the owner of each remote optimizer and wait for
        # every one of them to finish.
        pending_steps = [
            rpc.rpc_async(
                remote_optim.owner(),
                step_func,
                args=(remote_optim, context_id),
            )
            for remote_optim in self.remote_optimizers
        ]
        _wait_for_all(pending_steps)