from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from types import FunctionType
from functools import wraps
import six
from caffe2.python import cnn, dyndep, scope, workspace, core
from caffe2.proto import caffe2_pb2
dyndep.InitOpsLibrary("@/caffe2/caffe2/contrib/nccl:nccl_ops")
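# CNNModelHelper operator methods that are patched by the metaclass below:
# a single call to any of these is replayed once per GPU, under the
# matching "gpu_<id>" name scope and CUDA device scope.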
DATAPARALLEL_OPS = [
"Conv",
"ConvTranspose",
"GroupConv",
"FC",
"FC_Decomp",
"FC_Prune",
"FC_Sparse",
"LRN",
"Dropout",
"MaxPool",
"AveragePool",
"Concat",
"DepthConcat",
"Relu",
"Transpose",
"SpatialBN",
"Accuracy",
"Adam",
"AveragedLoss",
"Cast",
"LabelCrossEntropy",
"LearningRate",
"Print",
"Scale",
"Snapshot",
"Softmax",
"StopGradient",
"Summarize",
"Sum",
"Tanh",
"WeightedSum",
"SquaredL2Distance",
]
class _GPUDataParallelMetaClass(type):
"""A meta class to patch method in order to distribute them over multiple
GPUs.
"""
_devices = []
@staticmethod
def _data_parallel_wrapper(op):
@wraps(op)
def wrapped(cls, blob_in, blob_out, *args, **kwargs):
# Helpers to extract a device specific blob or a global blob
def self_or_item(d, key):
if isinstance(d, dict):
assert key in d
return d[key]
return d
def get_input(gpu_id):
if isinstance(blob_in, list):
return [self_or_item(blob, gpu_id) for blob in blob_in]
return self_or_item(blob_in, gpu_id)
def get_output(gpu_id):
return self_or_item(blob_out, gpu_id)
# If we have explicit device scope, we do not parallelize
if cls.explicit_scope():
return op(
cls,
blob_in,
blob_out,
*args,
**kwargs)
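            # Otherwise replay the op once per GPU, under that GPU's name
            # scope and device scope, and collect the per-GPU outputs in a
            # dict keyed by gpu_id.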
devices = _GPUDataParallelMetaClass._devices
results = {}
for gpu_id in devices:
with core.NameScope("gpu_{}".format(gpu_id)):
device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
with core.DeviceScope(device):
result = op(
cls,
get_input(gpu_id),
get_output(gpu_id),
*args,
**kwargs)
results[gpu_id] = result
return results
return wrapped
def __new__(meta, classname, bases, class_dict):
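        """Wraps every DATAPARALLEL_OPS method found on the base class and
        on the new class with _data_parallel_wrapper."""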
assert len(bases) == 1, "Expects only one base class"
base = bases[0]
assert base is cnn.CNNModelHelper, "Base class should be CNNModelHelper"
new_class_dict = {}
for name, attr in base.__dict__.items():
if name not in DATAPARALLEL_OPS:
continue
attr = _GPUDataParallelMetaClass._data_parallel_wrapper(attr)
new_class_dict[name] = attr
for name, attr in class_dict.items():
if name in new_class_dict:
continue
if isinstance(attr, FunctionType):
if name in DATAPARALLEL_OPS:
new_class_dict[name] = \
_GPUDataParallelMetaClass._data_parallel_wrapper(attr)
else:
new_class_dict[name] = attr
return super(_GPUDataParallelMetaClass, meta).__new__(
meta, classname, bases, new_class_dict)
@six.add_metaclass(_GPUDataParallelMetaClass)
class GPUDataParallelModel(cnn.CNNModelHelper):
"""A helper class that extends CNNModelHelper to support multi GPUs
data parallel training.
"""
def __init__(self, devices, *args, **kwargs):
        assert len(devices) >= 1, "Should have at least 1 GPU device"
        assert len(devices) <= workspace.NumCudaDevices(), \
            "Requested # of devices {} is greater than the # of GPUs {}".\
            format(len(devices), workspace.NumCudaDevices())
_GPUDataParallelMetaClass._devices = devices
self._devices = devices
self._explicit_scope = False
self._gradient_reduce_all_added = False
self._mpi_comm = None
super(GPUDataParallelModel, self).__init__(*args, **kwargs)
def explicit_scope(self):
return self._explicit_scope
def _call(self, name, *args, **kwargs):
return super(GPUDataParallelModel, self).__getattr__(
name)(*args, **kwargs)
# TODO(denisy): try out decorators to avoid this code below
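    # These thin forwarders give the metaclass a concrete method to wrap:
    # CNNModelHelper exposes these ops only via __getattr__, which the
    # metaclass cannot patch, so each forwarder is defined here and then
    # wrapped by _data_parallel_wrapper.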
def Accuracy(self, *args, **kwargs):
return self._call("Accuracy", *args, **kwargs)
def Adam(self, *args, **kwargs):
return self._call("Adam", *args, **kwargs)
def AveragedLoss(self, *args, **kwargs):
return self._call("AveragedLoss", *args, **kwargs)
def Cast(self, *args, **kwargs):
return self._call("Cast", *args, **kwargs)
def LabelCrossEntropy(self, *args, **kwargs):
return self._call("LabelCrossEntropy", *args, **kwargs)
def LearningRate(self, *args, **kwargs):
return self._call("LearningRate", *args, **kwargs)
def Print(self, *args, **kwargs):
return self._call("Print", *args, **kwargs)
def Scale(self, *args, **kwargs):
return self._call("Scale", *args, **kwargs)
def Snapshot(self, *args, **kwargs):
return self._call("Snapshot", *args, **kwargs)
def Softmax(self, *args, **kwargs):
return self._call("Softmax", *args, **kwargs)
def StopGradient(self, *args, **kwargs):
return self._call("StopGradient", *args, **kwargs)
def Sum(self, *args, **kwargs):
return self._call("Sum", *args, **kwargs)
def Summarize(self, *args, **kwargs):
return self._call("Summarize", *args, **kwargs)
def Tanh(self, *args, **kwargs):
return self._call("Tanh", *args, **kwargs)
def WeightedSum(self, *args, **kwargs):
return self._call("WeightedSum", *args, **kwargs)
def SquaredL2Distance(self, *args, **kwargs):
return self._call("SquaredL2Distance", *args, **kwargs)
def SetMPIComm(self, mpi_comm):
self._mpi_comm = mpi_comm
def FinalizeSetup(self):
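        """Moves both nets to the GPU and adds the parameter-sync ops."""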
self.param_init_net.RunAllOnGPU()
self.RunAllOnGPU()
# If MPI enabled, broadcast params from master
        if self._mpi_comm is not None:
self._AddMPIParameterSync()
# Setup sync of initial params
self._SyncInitialParams()
def AddGradientOperators(self, params, *args, **kwargs):
def create_grad(param):
return self.ConstantFill(param, str(param) + "_grad", value=1.0)
param_grad = {}
# Explicitly need to create gradients on each GPU
for param in params:
if not isinstance(param, dict):
grad = create_grad(param)
param_grad[str(param)] = str(grad)
else:
for gpu_id in self._devices:
device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
with core.DeviceScope(device):
assert gpu_id in param
p = param[gpu_id]
g = create_grad(p)
param_grad[str(p)] = str(g)
return super(GPUDataParallelModel, self).AddGradientOperators(
param_grad, *args, **kwargs)
def AddWeightDecay(self, weight_decay):
if weight_decay == 0.0:
return
        assert weight_decay > 0.0
self._explicit_scope = True
assert \
self._gradient_reduce_all_added, \
"Weight decay must be done after gradient sync between gpus"
for gpu_id in self._devices:
with core.NameScope("gpu_{}".format(gpu_id)):
device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
with core.DeviceScope(device):
wd = self.param_init_net.ConstantFill([], 'wd', shape=[1],
value=weight_decay)
ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1],
value=1.0)
# Only update parameters that belong to the current GPU
params = self._CurrentScopeParams()
# Take only params that are weights
print("Adding weigth-decay for gpu {}.".format(gpu_id))
gpu_weights = [p for p in params if p in self.weights]
for w in gpu_weights:
                        # Computes grad = 1.0 * grad + wd * w, i.e.
                        # grad += weight_decay * w
grad = self.param_to_grad[w]
self.net.WeightedSum([grad, ONE, w, wd], grad)
self._explicit_scope = False
def _Broadcast(self, net, param):
# TODO(akyrola): replace with NCCLBroadcast when it's working
# Copy params from gpu_0 to other
for gpu_idx in self._devices[1:]:
device_opt = core.DeviceOption(caffe2_pb2.CUDA, gpu_idx)
with core.DeviceScope(device_opt):
net.Copy(
"gpu_{}/{}".format(self._devices[0], param),
"gpu_{}/{}".format(gpu_idx, param)
)
def _SyncInitialParams(self):
unique_param_names = set(
stripParamName(p)
for p in self.params
)
self._explicit_scope = True
for param in unique_param_names:
self._Broadcast(self.param_init_net, param)
self._explicit_scope = False
def _AddMPIParameterSync(self):
# Sync from master
unique_param_names = set(
stripParamName(p)
for p in self.params
)
self._explicit_scope = True
# Should this be done in GPU 0 scope?
for param_name in unique_param_names:
param = "gpu_{}/{}".format(self._devices[0], param_name)
self.param_init_net.Broadcast(
inputs=[self._mpi_comm, param],
outputs=[param],
engine='MPI'
)
self._explicit_scope = False
def _AllReduceGradients(self):
self._gradient_reduce_all_added = True
if self._mpi_comm is None:
self._AllReduceGradientsSingleHost()
else:
self._AllReduceGradientsWithMPI()
def _AllReduceGradientsWithMPI(self):
self._explicit_scope = True
unique_grads_names = set(
stripParamName(grad)
for grad in self.param_to_grad.values()
)
# Step 1: sum gradients from local GPUs to master GPU
last_out = None
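        # last_out chains the MPI Allreduce ops through control_input so
        # they run one at a time, in the same order on every host.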
master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, self._devices[0])
# Note: sorted order to ensure each host puts the operators in
# same order.
for grad_name in sorted(unique_grads_names):
grads_group = [
grad
for grad in self.param_to_grad.values()
if stripParamName(grad) == grad_name
]
master_grad = "gpu_{}/{}".format(self._devices[0], grad_name)
assert master_grad in grads_group
# Remark: NCCLReduce does not support in-place modifications
# so we need a temporary gradient blob
reduced_grad = "gpu_{}/{}_red".format(
self._devices[0],
grad_name
)
with core.DeviceScope(master_device_opt):
self.ConstantFill(master_grad, reduced_grad, value=0.0)
self.net.NCCLReduce(grads_group, reduced_grad)
# Step 2: allreduce over MPI to all hosts, between master GPUs
self.net.Allreduce(
inputs=[self._mpi_comm, reduced_grad],
outputs=[master_grad],
engine='MPI',
control_input=None if last_out is None else [last_out],
)
last_out = master_grad
# Step 3: broadcast locally
self._Broadcast(self.net, grad_name)
self._explicit_scope = False
def _AllReduceGradientsSingleHost(self):
"""Performs NCCL AllReduce to distribute gradients to all the GPUs."""
if len(self._devices) == 1:
return
# Take only params that have gradient associated with them.
unique_grads_names = set(
stripParamName(grad)
for grad in self.param_to_grad.values()
)
# Now we need to Allreduce gradients on all the GPUs.
# Pick GPU #0 as a master GPU.
self._explicit_scope = True
master_device_opt = core.DeviceOption(caffe2_pb2.CUDA, self._devices[0])
with core.DeviceScope(master_device_opt):
# Group by grads for reduce.
for grad_name in unique_grads_names:
grads_group = [
grad
for grad in self.param_to_grad.values()
if stripParamName(grad) == grad_name
]
assert len(grads_group) == len(self._devices), \
"Each GPU from {}, should have a copy of {}.".format(
self._devices, grad_name)
self.NCCLAllreduce(grads_group, grads_group)
self._explicit_scope = False
def _BuildLR(self, base_lr, policy="fixed", **other_lr_params):
"""A helper to create learning rate."""
ITER = self.Iter("ITER")
        # Note: since we are minimizing, we are doing "descent", so the
        # learning rate is expected to be negative.
LR = self.net.LearningRate(
[ITER],
"LR",
base_lr=base_lr,
policy=policy,
**other_lr_params
)
return LR
def _BuildSGD(self, params, base_lr, policy="fixed", **other_lr_params):
"""A helper to construct gradient update for SGD."""
base_lr = base_lr / len(self._devices)
LR = self._BuildLR(base_lr, policy, **other_lr_params)
ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
for param in params:
grad = self.param_to_grad[param]
if isinstance(grad, core.GradientSlice):
self.ScatterWeightedSum(
[param, ONE, grad.indices, grad.values, LR], param
)
else:
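                # param = 1.0 * param + LR * grad; with a negative LR this
                # is a gradient-descent step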
self.WeightedSum([param, ONE, grad, LR], param)
def _CurrentScopeParams(self):
return [
param
for param in self.param_to_grad.keys()
if str(param).startswith(scope.NAMESCOPE)
]
def SGD(self, base_lr, policy="fixed", **other_lr_params):
"""Adds SGD optimizer to the model."""
self._AllReduceGradients()
# Create update params operators.
self._explicit_scope = True
for gpu_id in self._devices:
with core.NameScope("gpu_{}".format(gpu_id)):
device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
with core.DeviceScope(device):
# Only update parameters that belong to the current GPU
params = self._CurrentScopeParams()
# Add optimizer update operators
self._BuildSGD(params, base_lr, policy, **other_lr_params)
self._explicit_scope = False
def CustomSGD(
self,
paramup_build_fn,
base_lr,
lr_policy,
weight_decay,
**other_lr_pars
):
"""Custom parameter update function"""
self._AllReduceGradients()
self.AddWeightDecay(weight_decay)
# Run parameter update on each machine
self._explicit_scope = True
for gpu_id in self._devices:
with core.NameScope("gpu_{}".format(gpu_id)):
device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
with core.DeviceScope(device):
LR = self._BuildLR(base_lr, lr_policy, **other_lr_pars)
params = self._CurrentScopeParams()
paramup_build_fn(self, params, LR)
self._explicit_scope = False
def ExecOnEachDevice(self, fn, *args, **kwargs):
self._explicit_scope = True
for gpu_id in self._devices:
with core.NameScope("gpu_{}".format(gpu_id)):
device = core.DeviceOption(caffe2_pb2.CUDA, gpu_id)
with core.DeviceScope(device):
fn(self, *args, **kwargs)
self._explicit_scope = False
# A helper function to extract a parameter's name
def stripParamName(param):
# Format is "a/b/c/d" -> d
name = str(param)
sep = scope._NAMESCOPE_SEPARATOR
return name[name.rindex(sep) + 1:]
def SetupMPICluster(num_replicas, role, job_path):
from caffe2.python import mpi
print("Initing library")
dyndep.InitOpsLibrary('@/caffe2/caffe2/mpi:mpi_ops')
print("Setup peers")
mpi.SetupPeers(
replicas=int(num_replicas),
role=role,
job_path=job_path
)
print("Create mpi_init net")
mpi_init_net = core.Net('mpi_init')
print("Create commonworld")
mpi_comm = mpi_init_net.CreateCommonWorld(
inputs=[],
outputs=['comm_world'],
engine='MPI'
)
print("Run mpi_init net")
workspace.RunNetOnce(mpi_init_net)
print("Finished MPI setup")
return mpi_comm