| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| from collections import namedtuple |
| from collections import OrderedDict |
| |
| from caffe2.proto import caffe2_pb2 |
| from collections import defaultdict |
| from caffe2.python import scope, utils, workspace |
| import numpy as np |
| |
| import caffe2.python._import_c_extension as C |
| |
| GlobalInit = C.global_init |
| |
| # Convenience redirections to functions inside scope. |
| DeviceScope = scope.DeviceScope |
| NameScope = scope.NameScope |
| |
| |
| # Bring datatype enums to the main namespace |
| class DataType: |
| pass |
| |
| |
| def _InitDataType(): |
| for name, value in caffe2_pb2.TensorProto.DataType.items(): |
| setattr(DataType, name, value) |
| |
| _InitDataType() |
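| |
| # The enum values mirror caffe2_pb2.TensorProto.DataType, e.g. |
| # DataType.FLOAT == caffe2_pb2.TensorProto.FLOAT. |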
| |
| # Python 2 and 3 compatibility: test if basestring exists |
| try: |
| basestring = basestring # NOQA |
| except NameError: |
| # This is python3 so we define basestring. |
| basestring = str |
| |
| |
| def _GetRegisteredOperators(): |
| return set(s.decode() for s in workspace.RegisteredOperators()) |
| |
| _REGISTERED_OPERATORS = _GetRegisteredOperators() |
| |
| |
| def RefreshRegisteredOperators(): |
| global _REGISTERED_OPERATORS |
| _REGISTERED_OPERATORS = _GetRegisteredOperators() |
| |
| |
| def IsOperator(op_type): |
| return (op_type in _REGISTERED_OPERATORS) |
| |
| |
| def IsOperatorWithEngine(op_type, engine): |
| return (op_type + "_ENGINE_" + engine in _REGISTERED_OPERATORS) |
| |
| |
| def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None): |
| option = caffe2_pb2.DeviceOption() |
| option.device_type = device_type |
| option.cuda_gpu_id = cuda_gpu_id |
| if random_seed is not None: |
| option.random_seed = random_seed |
| return option |
| |
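| # Illustrative usage sketch (assumes a CUDA-enabled build): |
| #     gpu_opt = DeviceOption(caffe2_pb2.CUDA, cuda_gpu_id=0, random_seed=137) |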
| |
| GradientSlice = namedtuple('GradientSlice', ['indices', 'values']) |
| |
| |
| class BlobReference(object): |
| """A wrapper around a blob in a net. |
| |
| BlobReference gives us a way to refer to the network that the blob is |
| generated from. Note that blobs are, essentially, just strings in the |
| current workspace. |
| """ |
| |
| def __init__(self, name, net=None): |
| """Initializes a blob reference. |
| |
| Note that this does not prepend the namescope. If needed, use |
| ScopedBlobReference() to prepend the current namescope. |
| """ |
| self._name = name |
| self._from_net = net |
| # meta allows helper functions to put whatever metainformation needed |
| # there. |
| self.meta = {} |
| |
| def __hash__(self): |
| return hash(self._name) |
| |
| def __eq__(self, other): |
| if isinstance(other, basestring): |
| return self._name == other |
| elif isinstance(other, BlobReference): |
| return self._name == other._name |
| else: |
| return False |
| |
| def __ne__(self, other): |
| return not(self == other) |
| |
| def __str__(self): |
| return self._name |
| |
| def __repr__(self): |
| return 'BlobReference("{}")'.format(self._name) |
| |
| def __add__(self, other): |
| if not isinstance(other, basestring): |
| raise RuntimeError('Cannot add BlobReference to a non-string.') |
| return BlobReference(self._name + other, self._from_net) |
| |
| def __radd__(self, other): |
| if not isinstance(other, basestring): |
| raise RuntimeError('Cannot add a non-string to BlobReference.') |
| return BlobReference(other + self._name, self._from_net) |
| |
| def Net(self): |
| return self._from_net |
| |
| def GetNameScope(self): |
| return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1] |
| |
| def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs): |
| """Internal function that routes the operator generation to the |
| network's __getattr__ function. |
| """ |
| inputs = [] if inputs is None else inputs |
| if isinstance(inputs, (BlobReference, basestring)): |
| inputs = [inputs] |
| # add self to the input list. |
| inputs.insert(0, self) |
| return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs) |
| |
| def __getattr__(self, op_type): |
| """A wrapper allowing one to initiate operators from a blob reference. |
| |
| Example: for a blob reference b that comes from network n, doing |
| b.Relu(...) |
| is equivalent to doing |
| net.Relu([b], ...) |
| """ |
| if op_type.startswith('__'): |
| raise AttributeError('Attribute {} not found.'.format(op_type)) |
| if self._from_net is None: |
| raise RuntimeError( |
| 'You cannot use a blob reference that does not have a net ' |
| 'source to create operators. Create the operator from an ' |
| 'explicit net object.') |
| if not IsOperator(op_type): |
| raise RuntimeError( |
| 'Method ' + op_type + ' is not a registered operator.' |
| ) |
| return lambda *args, **kwargs: self._CreateAndAddToNet( |
| op_type, *args, **kwargs) |
| |
| |
| def ScopedName(name): |
| """prefix the name with the current scope.""" |
| return scope.CurrentNameScope() + name |
| |
| |
| def ScopedBlobReference(name, *args, **kwargs): |
| """Returns a blob reference with scope prefixed.""" |
| return BlobReference(ScopedName(name), *args, **kwargs) |
| |
| |
| def _RectifyInputOutput(blobs, net=None): |
| """A helper function to rectify the input or output of the CreateOperator |
| interface. |
| """ |
| if isinstance(blobs, basestring): |
| # If blobs is a single string, prepend scope.CurrentNameScope() |
| # and put it as a list. |
| # TODO(jiayq): enforce using BlobReference instead of raw strings. |
| return [ScopedBlobReference(blobs, net=net)] |
| elif type(blobs) is BlobReference: |
| # If blob is a BlobReference, simply put it as a list. |
| return [blobs] |
| elif type(blobs) in (list, tuple): |
| # If blob is a list, we go through it and type check. |
| rectified = [] |
| for blob in blobs: |
| if isinstance(blob, basestring): |
| rectified.append(ScopedBlobReference(blob, net=net)) |
| elif type(blob) is BlobReference: |
| rectified.append(blob) |
| else: |
| raise TypeError( |
| "I/O blob #{} of unsupported type: {} of type {}" |
| .format(len(rectified), str(blob), type(blob))) |
| return rectified |
| else: |
| raise TypeError( |
| "Unknown input/output type: %s of type %s." % |
| (str(blobs), type(blobs)) |
| ) |
| |
| |
| def CreateOperator( |
| operator_type, |
| inputs, |
| outputs, |
| name='', |
| control_input=None, |
| device_option=None, |
| arg=None, |
| engine=None, |
| **kwargs |
| ): |
| """A function wrapper that allows one to create operators based on the |
| operator type. The type should be a string corresponding to an operator |
| registered with Caffe2. |
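| |
| Illustrative usage (a sketch; assumes "Relu" is a registered operator and |
| the blobs "X" and "Y" exist in the caller's workspace): |
| op = CreateOperator("Relu", ["X"], ["Y"]) |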
| """ |
| operator = caffe2_pb2.OperatorDef() |
| operator.type = operator_type |
| operator.name = name |
| # Add rectified inputs and outputs |
| inputs = _RectifyInputOutput(inputs) |
| outputs = _RectifyInputOutput(outputs) |
| operator.input.extend([str(i) for i in inputs]) |
| operator.output.extend([str(o) for o in outputs]) |
| if control_input: |
| control_input = _RectifyInputOutput(control_input) |
| operator.control_input.extend([str(i) for i in control_input]) |
| # Set device option: |
| # (1) If device_option is explicitly set, use device_option. |
| # (2) If not, but scope.CurrentDeviceScope() is set, |
| # then we use scope.CurrentDeviceScope(). |
| # (3) Otherwise, do not set device option. |
| if device_option is not None: |
| operator.device_option.CopyFrom(device_option) |
| elif scope.CurrentDeviceScope() is not None: |
| operator.device_option.CopyFrom(scope.CurrentDeviceScope()) |
| if engine is not None: |
| operator.engine = engine |
| # The random seed is defined in the device option, so we need to handle |
| # it specially. |
| if 'random_seed' in kwargs: |
| operator.device_option.random_seed = kwargs['random_seed'] |
| del kwargs['random_seed'] |
| # Add given arguments that do not need parsing |
| if arg is not None: |
| operator.arg.extend(arg) |
| # Add all other arguments |
| for key, value in kwargs.items(): |
| operator.arg.add().CopyFrom(utils.MakeArgument(key, value)) |
| |
| if workspace.IsImmediate(): |
| workspace.RunOperatorImmediate(operator) |
| return operator |
| |
| |
| def CreatePythonOperator(f, inputs, outputs, grad_f=None, *args, **kwargs): |
| token = C.register_python_op(f) |
| if grad_f: |
| C.register_python_gradient_op(token, grad_f) |
| kwargs["token"] = token |
| return CreateOperator("Python", inputs, outputs, *args, **kwargs) |
| |
| |
| def GetIndexFromGradientList(g_list, name): |
| """A helper function to get the index from a gradient list, None if not |
| matching.""" |
| for i, g in enumerate(g_list): |
| if g == name: |
| return i |
| elif type(g) is GradientSlice: |
| if (g.indices == name or g.values == name): |
| return i |
| return None |
| |
| |
| OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions']) |
| GradGenMeta = namedtuple('GradGenMeta', ['grad_op', 'idx', 'gradient']) |
| |
| |
| class IR(object): |
| """A simple IR class to keep track of all intermediate representations used |
| in the gradient computation. |
| """ |
| |
| def __init__(self, operators): |
| # The IR class holds several pieces of metadata from the forward pass: |
| # a) ssa: a list of [op, in_versions, out_versions] recording the |
| # input and the output version of each operator, similar |
| # to a normal SSA form. |
| # b) input_usages: a dictionary that, for each blob and each of its |
| # versions, records the indices of the operators that use it as an |
| # input. |
| # c) frontier: maintains the current version of each blob in the |
| # workspace after the execution of all the ops added to the IR so |
| # far. This is useful because if a gradient tries to access an |
| # earlier version of a blob, we can sanity-check that it is no |
| # longer there, and thus throw an error. |
| # d) gradient_frontier: maps the name of each blob to the version that |
| # its gradient corresponds to. |
| # e) gradient_generators: for each blob and each of its versions, maps |
| # to a list of operators that generate its gradient, together with |
| # the gradient name. |
| self.ssa = [] |
| self.input_usages = defaultdict(lambda: defaultdict(list)) |
| self.frontier = defaultdict(int) |
| self.gradient_frontier = {} |
| self.gradient_generators = defaultdict(lambda: defaultdict(list)) |
| |
| for op in operators: |
| self.Play(op) |
| |
| def Play(self, op): |
| """"Adds an op to the current IR, and update the internal states to |
| reflect the blobs and versions after the execution of the op. |
| """ |
| # For input, they are the current version in the dict. |
| in_versions = {} |
| for s in op.input: |
| in_versions[s] = self.frontier[s] |
| self.input_usages[s][self.frontier[s]].append(len(self.ssa)) |
| # For output, they are the current version plus one. If this is a |
| # newly created blob, its version starts with zero. |
| out_versions = {} |
| for s in op.output: |
| if s in self.frontier: |
| self.frontier[s] += 1 |
| out_versions[s] = self.frontier[s] |
| # Add to SSA for bookkeeping. |
| self.ssa.append(OpSSA(op, in_versions, out_versions)) |
| |
| def CheckGradientOperators( # NOQA |
| self, fwd_op_idx, gradient_ops, g_output, g_input): |
| """Checks if the gradient operators can be correctly carried out.""" |
| forward_op, in_versions, out_versions = self.ssa[fwd_op_idx] |
| locally_generated_blobs = [] |
| |
| for grad_op in gradient_ops: |
| # (1) for inputs: |
| # (1a) If it is a dense or sparse gradient name, it should match the |
| # version of the corresponding output. |
| # (1b) If it is an output name, the current version should match the |
| # version when the operator was run. |
| # (1c) If it is an input name, the current version should match the |
| # version when the operator was run. |
| # (1d) If it is none of the above, it should be a blob that is |
| # generated locally by one of the previous gradient operators. |
| for s in grad_op.input: # (1) |
| original_index = GetIndexFromGradientList(g_output, s) |
| if original_index is not None: # (1a) |
| original_name = forward_op.output[original_index] |
| if (out_versions[original_name] != |
| self.gradient_frontier[original_name]): |
| raise RuntimeError( |
| 'Gradient name "%s" is expected to correspond ' |
| 'to version %d of "%s", but currently we have ' |
| 'version %d.' % ( |
| s, out_versions[original_name], |
| original_name, |
| self.gradient_frontier[original_name])) |
| elif s in out_versions: # (1b) |
| if self.frontier[s] != out_versions[s]: |
| raise RuntimeError( |
| 'Gradient operator needs output "%s" at version' |
| ' %d, but currently we have version %d.' % ( |
| s, out_versions[s], |
| self.frontier[s] |
| ) |
| ) |
| elif s in in_versions: # (1c) |
| if (self.frontier[s] != in_versions[s]): |
| raise RuntimeError( |
| 'Gradient operator needs input "%s" at version ' |
| '%d, but currently we have version %d.' % ( |
| s, in_versions[s], |
| self.frontier[s] |
| ) |
| ) |
| else: # (1d) |
| if s not in locally_generated_blobs: |
| raise RuntimeError( |
| 'Blob name "%s" not in the scope of operator: ' |
| '%s\nand is not generated by any of the local ' |
| 'gradient operators.' % (s, str(forward_op)) |
| ) |
| # (2) for outputs: we will simply add them to locally generated |
| # blobs. We will also record the output to gradient_generators for |
| # bookkeeping, if the output corresponds to the input of a gradient. |
| for i, s in enumerate(grad_op.output): # (2) |
| locally_generated_blobs.append(s) |
| input_index = GetIndexFromGradientList(g_input, s) |
| if input_index is not None: |
| input_name = forward_op.input[input_index] |
| input_version = in_versions[input_name] |
| self.gradient_generators[input_name][input_version].append( |
| GradGenMeta(grad_op, i, g_input[input_index])) |
| |
| # (3) for ops (e.g., Add, Sum, Sub) which have gradient outputs directly |
| # passed from inputs (not computed from gradient ops), we create a |
| # GradGenMeta with None grad_op and idx so that gradient_generators |
| # knows where the gradients are coming from. This is needed for creating |
| # the Sum op that accumulates the gradients from multiple parents. |
| for input_index, g in enumerate(g_input): |
| if not g or str(g) in [str(b) for b in locally_generated_blobs]: |
| continue |
| input_name = forward_op.input[input_index] |
| input_version = in_versions[input_name] |
| self.gradient_generators[input_name][input_version].append( |
| GradGenMeta(None, 0, g)) |
| |
| # Finally, for the gradients specified in g_input, we update the |
| # gradient frontier to reflect the input versions that the gradients |
| # correspond to. |
| for i, g in enumerate(g_input): |
| if g is not None: |
| input_name = forward_op.input[i] |
| input_version = in_versions[input_name] |
| self.gradient_frontier[input_name] = input_version |
| |
| def _GetSumOpOutputName(self, generator, input_name): |
| sum_op_output = None |
| for grad_op, idx, _ in generator: |
| if grad_op and not sum_op_output: |
| sum_op_output = grad_op.output[idx] |
| return sum_op_output or input_name + '_grad' |
| |
| def _MakeSumOp(self, input_name, input_version): |
| generator = self.gradient_generators[input_name][input_version] |
| sum_op_input = [] |
| sum_op_output = self._GetSumOpOutputName(generator, input_name) |
| current = 0 |
| for grad_op, idx, g in generator: |
| if grad_op: |
| grad_op.output[idx] = ('_' + grad_op.output[idx] + |
| '_autosplit_{}'.format(current)) |
| sum_op_input.append(grad_op.output[idx]) |
| current += 1 |
| else: |
| if str(sum_op_output) == str(g): |
| raise RuntimeError( |
| 'The gradient output of an empty gradient op cannot ' |
| 'be the same as the normal name of the current ' |
| 'input gradient.') |
| sum_op_input.append(g) |
| |
| sum_op = CreateOperator("Sum", sum_op_input, sum_op_output) |
| for g in generator: |
| if g.grad_op: |
| if g.grad_op.HasField('device_option'): |
| sum_op.device_option.CopyFrom(g.grad_op.device_option) |
| break |
| |
| return sum_op |
| |
| def _VerifyGradientGenerators(self, generator): |
| # (1) check if we are dealing with dense gradients. Sparse gradients |
| # do not support automatic aggregation yet. |
| if any(type(g[2]) is GradientSlice for g in generator): |
| raise RuntimeError( |
| 'Automatic gradient aggregation does not work with sparse ' |
| 'gradients yet.') |
| |
| # If, among all the operators that used this input, none or only one |
| # produced a gradient, then no additional sum needs to be carried |
| # out. |
| if len(generator) < 2: |
| return False |
| |
| all_gradient_names = [] |
| all_device_options = [] |
| for g in generator: |
| if g.grad_op: |
| all_gradient_names.append(g.grad_op.output[g.idx]) |
| all_device_options.append(g.grad_op.device_option) |
| # Check if all grad names are the same. |
| if len(set(all_gradient_names)) > 1: |
| raise RuntimeError('Unexpected behavior: not all grad output ' |
| 'names are the same.') |
| # Check if all grad op device options are the same. |
| if len(all_device_options) >= 2 and not all( |
| d == all_device_options[0] for d in all_device_options[1:]): |
| raise RuntimeError('Unexpected behavior: not all grad ops ' |
| 'have the same device option.') |
| return True |
| |
| def DoGradientAccumulation(self, fwd_op_idx): |
| """For each input name in the forward op, check if we will need to |
| add gradient accumulation. If so, do gradient accumulation and return |
| the list of gradient operators. |
| |
| The criteria for doing gradient accumulation are: |
| (1) the specific input version has been used by multiple operators. |
| (2) the current fwd_op_idx is the first to use that input, i.e. in the |
| backward pass, is the last to optionally generate the gradient for |
| the op. |
| (3) For the operators that used the input, their gradient operators |
| have generated more than 1 gradient. |
| |
| When accumulating gradients, our current solution is to rename all the |
| created gradients with an internal intermediate name, and then add a |
| Sum() operator that adds up all the gradients. This may use more memory |
| due to intermediate storage, but is usually the fastest approach as one |
| can do one single sum for multiple intermediate gradients. |
| """ |
| forward_op, in_versions, out_versions = self.ssa[fwd_op_idx] |
| additional_sum_ops = [] |
| grad_map = {} |
| for i, input_name in enumerate(set(forward_op.input)): |
| input_version = in_versions[input_name] |
| input_usage = self.input_usages[input_name][input_version] |
| if (len(input_usage) <= 1 or fwd_op_idx != input_usage[0]): |
| # We do not need to do gradient accumulation yet. |
| continue |
| generator = self.gradient_generators[input_name][input_version] |
| try: |
| if not self._VerifyGradientGenerators(generator): |
| continue |
| except RuntimeError as err: |
| raise RuntimeError( |
| "Gradients for param ''{}'' failed to verity: {}".format( |
| input_name, |
| err |
| ) |
| ) |
| |
| # Finally, let's create the sum operator. |
| sum_op = self._MakeSumOp(input_name, input_version) |
| additional_sum_ops.append(sum_op) |
| grad_map[input_name] = sum_op.output[0] |
| return additional_sum_ops, grad_map |
| |
| def _GetInitGradients(self, ys): |
| input_to_grad = {} |
| gradient_ops = [] |
| for y, g in ys.items(): |
| if g is None: |
| autograd_op = CreateOperator( |
| "ConstantFill", [y], [str(y) + "_autogen_grad"], |
| value=1.0) |
| gradient_ops.append(autograd_op) |
| g = autograd_op.output[0] |
| # Since the C++ gradient registry does not have notion of |
| # NameScopes, we will convert all references to strings. |
| input_to_grad[str(y)] = ( |
| GradientSlice(str(g[0]), str(g[1])) |
| if isinstance(g, GradientSlice) else str(g)) |
| |
| return input_to_grad, gradient_ops |
| |
| def _GenerateGradientsForForwardOp( |
| self, forward_op_idx, input_to_grad): |
| new_input_to_grad = {} |
| gradient_ops = [] |
| forward_op, in_versions, out_versions = self.ssa[forward_op_idx] |
| g_output = list( |
| input_to_grad.get(name, None) for name in forward_op.output) |
| if not all(g is None for g in g_output): |
| gradient_ops, g_input = GradientRegistry.GetGradientForOp( |
| forward_op, g_output) |
| # Checks if the gradient operators are legal |
| self.CheckGradientOperators( |
| forward_op_idx, gradient_ops, g_output, g_input) |
| # Record the gradient map to all_input_to_grad. |
| for name, grad in zip(forward_op.input, g_input): |
| new_input_to_grad[name] = grad |
| |
| return new_input_to_grad, gradient_ops |
| |
| def GetBackwardPass(self, ys): |
| """Gets the backward pass that computes the derivatives of given blobs. |
| |
| Inputs: |
| ys: a list or a dictionary specifying what blobs we want to compute |
| derivatives of. If the input is a list, we will automatically |
| generate their gradients with all-one values; if the input is a |
| dictionary, for any dictionary entries that are not None, we will |
| take the corresponding blobs as their gradients; for all those |
| that are None, we will auto-fill them with 1. |
| """ |
| if isinstance(ys, list): |
| ys = dict((y, None) for y in ys) |
| elif not isinstance(ys, dict): |
| raise TypeError("ys should either be a list or a dict.") |
| |
| # Set the gradient frontier with the initialized external |
| # gradients. |
| for y, _ in ys.items(): |
| self.gradient_frontier[y] = self.frontier[y] |
| |
| all_input_to_grad, all_gradient_ops = self._GetInitGradients(ys) |
| |
| # (2) Having done the virtual play above, we now play the ops |
| # backwards, creating the gradients along the path. Note that although |
| # we are playing it backwards, we cannot refer to variables that are |
| # at a version older than the current one because it has already been |
| # overwritten. |
| for forward_op_idx in reversed(range(len(self.ssa))): |
| input_to_grad, gradient_ops = self._GenerateGradientsForForwardOp( |
| forward_op_idx, all_input_to_grad) |
| all_input_to_grad.update(input_to_grad) |
| all_gradient_ops += gradient_ops |
| |
| # If a blob is used by multiple operators, do gradient accumulation. |
| additional_sum_ops, grad_map = self.DoGradientAccumulation( |
| forward_op_idx) |
| # This line ensures that, if some of the operators in an accumulation |
| # have not produced gradients, they still do not overwrite the |
| # general all_input_to_grad map. |
| all_input_to_grad.update(grad_map) |
| all_gradient_ops += additional_sum_ops |
| |
| # (3) Post-processing. |
| # After we have done computation for each op, we now have the gradient |
| # operators ready. For the output map, we will convert everything to |
| # BlobReferences for easier handling in python. |
| all_input_to_grad_out = {} |
| for key, val in all_input_to_grad.items(): |
| if val is not None: |
| all_input_to_grad_out[BlobReference(key)] = ( |
| BlobReference(val) if isinstance(val, basestring) else |
| GradientSlice(BlobReference(val[0]), BlobReference(val[1]))) |
| return all_gradient_ops, all_input_to_grad_out |
| |
| |
| class GradientRegistry(object): |
| """GradientRegistry holds the mapping from operators to their gradients.""" |
| gradient_registry_ = {} |
| |
| @classmethod |
| def RegisterGradient(cls, op_type): |
| """A decorator for registering gradient mappings.""" |
| |
| def Wrapper(func): |
| cls.gradient_registry_[op_type] = func |
| return func |
| |
| return Wrapper |
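| |
| # Illustrative registration sketch (hypothetical "Foo" operator); the |
| # registered function must return (gradient_ops, g_input): |
| # |
| # @GradientRegistry.RegisterGradient("Foo") |
| # def AddFooGradient(op, g_output): |
| # grad_op = CreateOperator( |
| # "FooGradient", |
| # list(op.input) + list(g_output), |
| # [s + "_grad" for s in op.input]) |
| # return grad_op, [s + "_grad" for s in op.input] |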
| |
| @classmethod |
| def _GetGradientForOpCC(cls, op_def, g_output): |
| # TODO(tulloch) - Propagate GradientWrapper up through the stack. |
| def from_untyped(grad): |
| if grad is None: |
| w = C.GradientWrapper() |
| assert w.is_empty() |
| return w |
| try: |
| (indices, values) = grad |
| w = C.GradientWrapper() |
| w.indices = indices |
| w.values = values |
| assert w.is_sparse() |
| return w |
| except ValueError: |
| w = C.GradientWrapper() |
| w.dense = grad |
| assert w.is_dense() |
| return w |
| |
| g_output = [from_untyped(grad) for grad in g_output] |
| grad_defs_str, g_input = C.get_gradient_defs( |
| op_def.SerializeToString(), g_output) |
| |
| def to_untyped(grad_wrapper): |
| if grad_wrapper.is_empty(): |
| return None |
| if grad_wrapper.is_sparse(): |
| return GradientSlice(grad_wrapper.indices, grad_wrapper.values) |
| assert grad_wrapper.is_dense() |
| return grad_wrapper.dense |
| |
| g_input = [to_untyped(grad_wrapper) for grad_wrapper in g_input] |
| grad_defs = [] |
| for grad_def_str in grad_defs_str: |
| grad_def = caffe2_pb2.OperatorDef() |
| grad_def.ParseFromString(grad_def_str) |
| grad_defs.append(grad_def) |
| return grad_defs, g_input |
| |
| @classmethod |
| def GetGradientForOp(cls, op, g_output): |
| try: |
| gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output) |
| except Exception as e: |
| # Not supported in C++; will try python registration next. |
| |
| try: |
| gradient_ops, g_input = cls.gradient_registry_[op.type]( |
| op, g_output) |
| except KeyError: |
| raise Exception( |
| "No gradient registered for {}. ".format(op.type) + |
| "Exception from creating the gradient op: {}.".format(e)) |
| |
| if gradient_ops is None: |
| return [], g_input |
| if type(gradient_ops) is not list: |
| gradient_ops = [gradient_ops] |
| return gradient_ops, g_input |
| |
| @classmethod |
| def GetBackwardPass(cls, operators, ys): |
| """Gets the backward pass for the list of operators. |
| |
| Args: |
| operators: a list of operators constituting the forward pass. |
| ys: a list or a dictionary specifying what blobs we want to compute |
| derivatives of. If the input is a list, we will automatically |
| generate their gradients with all-one values; if the input is a |
| dictionary, for any dictionary entries that are not None, we'll |
| take the corresponding blobs as their gradients; for all those |
| that are None, we will auto-fill them with 1. |
| Returns: |
| gradient_ops: a list of gradient operators to run. |
| all_input_to_grads: a map from input to their corresponding |
| gradients. |
| """ |
| ir = IR(operators) |
| return ir.GetBackwardPass(ys) |
| |
| |
| def get_ssa(net, blob_versions=None): |
| """ |
| Given a net, return a structure containing the version of each input and |
| output blob used by each operator. |
| |
| Args: |
| net: either a Net or a NetDef |
| blob_versions: (optional) map with current version number for given |
| blob names. If not provided or blob not found, start |
| from version 0. |
| Returns: |
| Tuple (ssa, blob_versions) |
| ssa: list of tuples (versioned_inputs, versioned_outputs) |
| for each op in the net. A versioned input is a tuple |
| (blob_name, version). |
| blob_versions: updated map with latest version of each blob found in |
| the net. |
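| |
| Illustrative example: for a net containing the single op "Y = Relu(X)" |
| with "X" declared as an external input, this returns |
| ssa == [([('X', 0)], [('Y', 1)])] and blob_versions == {'X': 0, 'Y': 1}. |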
| """ |
| if blob_versions is None: |
| blob_versions = {} |
| if isinstance(net, list): |
| return [get_ssa(n, blob_versions) for n in net], blob_versions |
| proto = net.Proto() if isinstance(net, Net) else net |
| assert isinstance(proto, caffe2_pb2.NetDef) |
| for i in proto.external_input: |
| if i not in blob_versions: |
| blob_versions[str(i)] = 0 |
| ssa = [] |
| for op in proto.op: |
| if not proto.external_input: |
| for i in op.input: |
| if i not in blob_versions: |
| blob_versions[i] = 0 |
| inputs = [(str(i), blob_versions.get(str(i), 0)) for i in op.input] |
| for o in op.output: |
| blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1 |
| outputs = [(str(o), blob_versions[str(o)]) for o in op.output] |
| ssa.append((inputs, outputs)) |
| return ssa, blob_versions |
| |
| |
| def get_undefined_blobs(ssa): |
| """ |
| Given a ssa in the format produced by get_ssa(), return a set of blobs that |
| are used before they are defined, which corresponds to inputs at version 0. |
| """ |
| undef_blobs = set() |
| for inputs, outputs in ssa: |
| undef_blobs |= set(name for (name, ver) in inputs if ver == 0) |
| return undef_blobs |
| |
| |
| def get_output_producers(ssa): |
| """ |
| Given a ssa in the format produced by get_ssa(), returns a map from |
| versioned blob into the operator index that produces that version of |
| the blob. A versioned blob is a tuple (blob_name, version). |
| """ |
| producers = {} |
| for i, (inputs, outputs) in enumerate(ssa): |
| for o in outputs: |
| producers[o] = i |
| return producers |
| |
| |
| def get_op_ids_in_path(ssa, blob_versions, inputs, outputs): |
| """ |
| Given a ssa and blob_versions as produced by get_ssa(), returns the list |
| of op indices that are necessary in order to generate the blobs in |
| `outputs`, given blobs in `inputs`. |
| The `inputs` are assumed to be given at their latest version. |
| """ |
| inputs_set = set((str(i), blob_versions[str(i)]) for i in inputs) |
| producers = get_output_producers(ssa) |
| queue = [(str(o), blob_versions[str(o)]) for o in outputs] |
| used_op_ids = set() |
| while len(queue) > 0: |
| o = queue.pop() |
| if (o not in inputs_set) and (o in producers): |
| op_id = producers[o] |
| used_op_ids |= {op_id} |
| inputs, _ = ssa[op_id] |
| queue.extend(inputs) |
| return sorted(used_op_ids) |
| |
| |
| def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None): |
| """ |
| Clone the given Net, binding its input schema to the given `inputs` record. |
| Blob names defined by the net are prepended with the given `prefix`. |
| |
| Args: |
| net: the net to clone |
| name: the name of the new net |
| prefix: the prefix to prepend to local blob names |
| blob_remap: (optional) dict with additional blob name remapping. |
| inputs: (optional) input record that will provide actual input |
| values for the cloned net. Must be compatible with the |
| net's input schema. |
| Returns: |
| Tuple (cloned_net, blob_remap) |
| cloned_net: the cloned Net |
| blob_remap: a map from original blob names into remapped blob names |
| """ |
| from caffe2.python import schema |
| assert isinstance(net, Net) |
| if blob_remap is None: |
| blob_remap = {} |
| if inputs is not None: |
| assert isinstance(inputs, schema.Field) |
| original = net.input_record() |
| assert original is not None |
| # TODO(azzolini): improve schema type checking |
| assert set(original.field_names()) == set(inputs.field_names()), ( |
| 'Schemas do not match.') |
| original_mapping = dict(zip(original.field_names(), |
| original.field_blobs())) |
| for a, b in zip(inputs.field_names(), inputs.field_blobs()): |
| blob_remap[str(original_mapping[a])] = str(b) |
| proto = net.Proto() |
| ssa, blob_versions = get_ssa(proto) |
| undef_blobs = get_undefined_blobs(ssa) |
| |
| for blob in blob_versions.keys(): |
| if blob in blob_remap: |
| continue |
| elif blob in undef_blobs: |
| blob_remap[blob] = blob |
| else: |
| blob_remap[blob] = prefix + blob |
| return net.Clone(name, blob_remap), blob_remap |
| |
| |
| def _get_blob_ref(blob_name_or_ref): |
| return ( |
| blob_name_or_ref if isinstance(blob_name_or_ref, BlobReference) |
| else BlobReference(blob_name_or_ref) |
| ) |
| |
| class Net(object): |
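| """A Net wraps a caffe2_pb2.NetDef protocol buffer and allows operators |
| to be added to it via attribute access (see __getattr__). |
| |
| Illustrative usage (a sketch; assumes "Relu" is a registered operator): |
| net = Net("my_net") |
| y = net.Relu("X", "Y") # adds a Relu op and returns a BlobReference |
| """ |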
| _net_names_used = set() |
| operator_registry_ = {} |
| |
| @staticmethod |
| def _get_next_net_name(basename): |
| name = basename |
| next_idx = 1 |
| while name in Net._net_names_used: |
| name = basename + '_' + str(next_idx) |
| next_idx += 1 |
| Net._net_names_used |= set([name]) |
| return name |
| |
| def __init__(self, name_or_proto): |
| """ |
| Create a Net. |
| Args: |
| name_or_proto: If a NetDef is provided, clone it. Otherwise, |
| create an empty net with the given name. |
| """ |
| self._input_record = None |
| self._output_record = None |
| self._attr_dict = defaultdict(list) |
| if type(name_or_proto) is caffe2_pb2.NetDef: |
| proto = name_or_proto |
| # We are initializing the network from a NetDef, so we copy the given |
| # proto into our own NetDef. |
| self._net = caffe2_pb2.NetDef() |
| self._net.CopyFrom(proto) |
| # Set the next name index properly. |
| existing_names = set( |
| sum( |
| [list(op.input) for op in self._net.op], [] |
| ) + sum( |
| [list(op.output) for op in self._net.op], [] |
| ) |
| ) |
| prefix_len = len(self._net.name + '_blob_') |
| autogen_indices = [] |
| for s in existing_names: |
| if s.startswith(self._net.name + '_blob_'): |
| try: |
| autogen_indices.append(int(s[prefix_len:])) |
| except ValueError: |
| pass |
| if len(autogen_indices): |
| self._next_name_index = max(autogen_indices) + 1 |
| else: |
| self._next_name_index = 0 |
| else: |
| self._net = caffe2_pb2.NetDef() |
| self._net.name = name_or_proto |
| self._next_name_index = 0 |
| |
| # make sure that this net name hasn't been used before |
| self._net.name = Net._get_next_net_name(self._net.name) |
| |
| def AppendNet(self, net): |
| assert isinstance(net, Net) |
| self.Proto().op.extend(net.Proto().op) |
| self.Proto().external_input.extend( |
| [i for i in net.Proto().external_input |
| if i not in self.Proto().external_input]) |
| self.Proto().external_output.extend( |
| [o for o in net.Proto().external_output |
| if o not in self.Proto().external_output]) |
| return self |
| |
| def LogInfo(self, *msg_or_blobs): |
| for msg_or_blob in msg_or_blobs: |
| if not isinstance(msg_or_blob, BlobReference): |
| blob = self.GivenTensorStringFill( |
| [], self.NextName('log'), |
| shape=[], values=[msg_or_blob]) |
| else: |
| blob = msg_or_blob |
| self.Print(blob, []) |
| |
| def add_attribute(self, name, obj): |
| """ |
| Add `obj` to the list of attributes in this net under the given `name`. |
| Attributes are user-defined objects and have no pre-defined semantics. |
| """ |
| self._attr_dict[name].append(obj) |
| |
| def get_attributes(self, name): |
| """ |
| Returns the list of attributes in this net for a given `name`. |
| Attributes are user-defined objects added with `add_attribute`. |
| """ |
| return self._attr_dict.get(name, []) |
| |
| def Name(self): |
| return self._net.name |
| |
| def __str__(self): |
| return self.Name() |
| |
| def Const(self, array, blob_out=None, dtype=None): |
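| """Adds a fill operator producing a constant blob from `array`. |
| |
| Illustrative usage (a sketch; assumes numpy is imported as np): |
| w = net.Const(np.array([1.0, 2.0], dtype=np.float32), 'w') |
| """ |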
| if isinstance(array, bool): |
| return self.ConstantFill( |
| [], |
| blob_out or 1, |
| dtype=DataType.BOOL, |
| value=array) |
| |
| if dtype is None: |
| array = np.array(array) |
| else: |
| array = np.array(array, dtype=dtype) |
| |
| def do_set(operator): |
| return operator( |
| [], |
| blob_out or 1, |
| shape=array.shape, |
| values=array.flatten().tolist()) |
| |
| if array.dtype == np.int32: |
| return do_set(self.GivenTensorIntFill) |
| elif array.dtype == np.int64: |
| return do_set(self.GivenTensorInt64Fill) |
| elif array.dtype == np.str: |
| return do_set(self.GivenTensorStringFill) |
| else: |
| return do_set(self.GivenTensorFill) |
| |
| def BlobIsDefined(self, blob): |
| """ |
| Returns true if the given BlobReference is produced as output of |
| an operator in this net, or if it is provided as an external input. |
| """ |
| blob_name = str(blob) |
| for input in self._net.external_input: |
| if input == blob_name: |
| return True |
| for op in self._net.op: |
| for output in op.output: |
| if output == blob_name: |
| return True |
| return False |
| |
| def UsesBlob(self, blob): |
| """ |
| Returns true iff the given BlobReference is used by any operator |
| of this net, or if it is one of the external inputs of the net. |
| """ |
| blob_name = str(blob) |
| for op in self._net.op: |
| for input in op.input: |
| if input == blob_name: |
| return True |
| for input in self._net.external_input: |
| if input == blob_name: |
| return True |
| return False |
| |
| def GetBlobRef(self, blob_name): |
| """ |
| Given the name of a blob defined by this net, return a BlobReference |
| to it. If the blob is neither produced by any op in this net nor one |
| of its external inputs, raises KeyError. |
| """ |
| blob_name = str(blob_name) |
| if not self.BlobIsDefined(blob_name): |
| raise KeyError('Net does not define blob %s' % blob_name) |
| return BlobReference(blob_name, self) |
| |
| def Clone(self, name, blob_remap=None, op_id_mask=None, remap_funcs=None): |
| """ |
| Clone this net. |
| Args: |
| name: name of the cloned net |
| blob_remap: optional map from original blob names to new names |
| op_id_mask: optional list of operator indices to include in |
| the cloned net. If not provided, all ops are included. |
| """ |
| if remap_funcs is None: |
| remap_funcs = {} |
| proto = self._net |
| new_proto = caffe2_pb2.NetDef() |
| new_proto.CopyFrom(proto) |
| new_proto.name = name |
| if blob_remap is None and op_id_mask is None: |
| return Net(new_proto) |
| |
| if blob_remap is None: |
| blob_remap = {} |
| if op_id_mask is None: |
| op_id_mask = range(0, len(proto.op)) |
| |
| def remap_list(proto_list): |
| new_list = [blob_remap.get(b, b) for b in proto_list] |
| del proto_list[:] |
| proto_list.extend(new_list) |
| |
| def remap_op(op): |
| new_op = caffe2_pb2.OperatorDef() |
| new_op.CopyFrom(op) |
| remap_list(new_op.input) |
| remap_list(new_op.output) |
| if new_op.type in remap_funcs: |
| remap_funcs[new_op.type](new_op, (name + '/') if name else '') |
| return new_op |
| |
| del new_proto.op[:] |
| new_proto.op.extend(remap_op(proto.op[op_id]) for op_id in op_id_mask) |
| remap_list(new_proto.external_input) |
| remap_list(new_proto.external_output) |
| new_net = Net(new_proto) |
| |
| from caffe2.python import schema |
| if self._input_record: |
| new_net._input_record = schema.from_blob_list( |
| self._input_record, |
| [ |
| BlobReference(str(blob_remap[str(blob)]), net=new_net) |
| for blob in self._input_record.field_blobs() |
| ], |
| ) |
| if self._output_record: |
| new_net._output_record = schema.from_blob_list( |
| self._output_record, |
| [ |
| BlobReference(str(blob_remap[str(blob)]), net=new_net) |
| for blob in self._output_record.field_blobs() |
| ], |
| ) |
| new_net._attr_dict.update(self._attr_dict) |
| return new_net |
| |
| def ClonePartial(self, name, inputs, outputs, remap_funcs=None): |
| """ |
| Clone this net, including only ops that are necessary in order to |
| compute `outputs` given `inputs`. Return references to the cloned |
| outputs. Internal blobs (blobs that are produced and consumed inside |
| the net but not used as outputs) will be remapped to avoid name |
| conflict. |
| |
| Args: |
| name: the name of the cloned net |
| inputs: map where the keys correspond to BlobReferences in the |
| original net, and the values correspond to external inputs |
| in the partially cloned net. If `inputs` is a list of |
| (original, new) pairs, it is treated as an ordered mapping; if |
| it is a plain list, input names are not remapped. |
| outputs: outputs to be produced by the cloned net. |
| |
| Returns: |
| Tuple (new_net, new_outputs) |
| new_net: a new Net object. |
| new_outputs: list of BlobReferences corresponding to the |
| outputs produced by new_net. |
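| |
| Illustrative usage (a sketch; assumes 'hidden' and 'pred' are blobs |
| defined by this net): |
| pred_net, (pred,) = net.ClonePartial( |
| 'pred_net', {'hidden': 'hidden_input'}, ['pred']) |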
| """ |
| input_is_pair_list = isinstance(inputs, list) and all( |
| isinstance(i, tuple) and len(i) == 2 for i in inputs) |
| inputs = ( |
| inputs if isinstance(inputs, (dict, OrderedDict)) else |
| OrderedDict(inputs) if input_is_pair_list else |
| OrderedDict(zip(inputs, inputs))) |
| for output in outputs: |
| assert self.BlobIsDefined(output) |
| input_names = {str(k): str(v) for k, v in inputs.items()} |
| output_names = [str(o) for o in outputs] |
| proto = self._net |
| ssa, blob_versions = get_ssa(proto) |
| used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs) |
| disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs) |
| assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, ( |
| 'Cannot partially clone net: some of the ops required would ' + |
| 'generate the given input.') |
| |
| sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids] |
| undef_blobs = get_undefined_blobs(sub_ssa) - set(input_names.keys()) |
| prefix = (name + '/') if name else '' |
| |
| def remap(blob_name): |
| if blob_name in input_names: |
| return input_names[blob_name] |
| elif blob_name in undef_blobs: |
| return blob_name |
| else: |
| return prefix + blob_name |
| |
| blob_mapping = {b: remap(b) for b in blob_versions.keys()} |
| new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs) |
| new_in = [ |
| blob_mapping[i] for i in input_names.keys()] + list(undef_blobs) |
| new_out = [blob_mapping[o] for o in output_names] |
| del new_net.Proto().external_input[:] |
| new_net.Proto().external_input.extend(new_in) |
| del new_net.Proto().external_output[:] |
| new_net.Proto().external_output.extend(new_out) |
| return new_net, [new_net.GetBlobRef(o) for o in new_out] |
| |
| def Proto(self): |
| return self._net |
| |
| def NextName(self, prefix=None, output_id=None): |
| """Returns the next name to be used, if you do not want to explicitly |
| name your blob.""" |
| if prefix: |
| output_name_base = self._net.name + '/' + prefix |
| output_name = output_name_base |
| if output_id is not None: |
| output_name += ':' + str(output_id) |
| index = 2 |
| while self.BlobIsDefined(str(ScopedBlobReference(output_name))): |
| output_name = output_name_base + '_' + str(index) |
| if output_id is not None: |
| output_name += ':' + str(output_id) |
| index += 1 |
| else: |
| output_name = self._net.name + '_blob_' + str(self._next_name_index) |
| self._next_name_index += 1 |
| return str(output_name) |
| |
| def AddGradientOperators(self, ys, skip=0): |
| """Add the gradient for operators in the net. |
| |
| Inputs: |
| ys: a list or a dictionary specifying what blobs we want to compute |
| derivatives of. If the input is a list, we will automatically |
| generate their gradients with all-one values; if the input is a |
| dictionary, for any dictionary entries that are not None, we will |
| take the corresponding blobs as their gradients; for all those |
| that are None, we will auto-fill them with 1. |
| skip: skips the first n operators. This is provided mainly because a |
| lot of nets may use the first few operators for data generation |
| and similar tasks that do not need gradients. |
| |
| Outputs: |
| returns a map from the blob name in the input network to a blob |
| containing gradient or a GradientSlice in case of sparse gradient |
| |
| Currently, this is hard-coded for float operators if there are branches |
| (i.e. a blob is used as input to multiple operators). This is because |
| the gradient accumulation (Sum) is float only right now. |
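| |
| Illustrative usage (a sketch; assumes the "SquaredL2Distance" and |
| "AveragedLoss" operators and the blobs "pred" and "label" are available |
| in the caller's setup): |
| dist = net.SquaredL2Distance(["pred", "label"], "dist") |
| loss = dist.AveragedLoss([], "loss") |
| grad_map = net.AddGradientOperators([loss]) |
| # grad_map maps input blobs to their gradient blobs (or GradientSlices). |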
| """ |
| |
| grad_ops, input_to_grad = GradientRegistry.GetBackwardPass( |
| self._net.op[skip:], ys) |
| # Check if in immediate mode: the grad_ops are actually being produced |
| # by C++ and bypass the CreateOperator() call, so in immediate mode |
| # we will have to explicitly run them. |
| if workspace.IsImmediate(): |
| for op in grad_ops: |
| workspace.RunOperatorImmediate(op) |
| self._net.op.extend(grad_ops) |
| return input_to_grad |
| |
| def AddExternalInput(self, input): |
| input_name = str(input) |
| assert input_name not in self._net.external_input, ( |
| 'Net already contains an input named %s' % input_name) |
| self._net.external_input.extend([input_name]) |
| return _get_blob_ref(input_name) |
| |
| def AddExternalOutput(self, output): |
| assert isinstance(output, BlobReference) |
| assert self.BlobIsDefined(output) |
| self.Proto().external_output.extend([str(output)]) |
| return output |
| |
| @property |
| def external_inputs(self): |
| return [_get_blob_ref(x) for x in self._net.external_input] |
| |
| @property |
| def external_outputs(self): |
| return [_get_blob_ref(x) for x in self._net.external_output] |
| |
| def set_input_record(self, input_record): |
| from caffe2.python import schema |
| assert self._input_record is None, ( |
| 'Input schema cannot be reset') |
| if not input_record.has_blobs(): |
| self._input_record = schema.NewRecord(self, input_record) |
| else: |
| self._input_record = input_record |
| for blob in input_record.field_blobs(): |
| if blob not in self.external_inputs: |
| self.AddExternalInput(blob) |
| return self._input_record |
| |
| def set_output_record(self, record): |
| assert self._output_record is None, ( |
| 'Output record cannot be reset') |
| for blob in record.field_blobs(): |
| assert self.BlobIsDefined(blob) |
| for blob in record.field_blobs(): |
| self.AddExternalOutput(blob) |
| self._output_record = record |
| |
| def input_record(self): |
| return self._input_record |
| |
| def output_record(self): |
| return self._output_record |
| |
| def DeduplicateGradientSlices(self, g): |
| assert isinstance(g, GradientSlice) |
| unique, remapping = self.Unique([g.indices], 2, engine='SparseHash') |
| sum_g = self.UnsortedSegmentSum([g.values, remapping], 1) |
| return GradientSlice(indices=unique, values=sum_g) |
| |
| def RunAllOnGPU(self, gpu_id=0, use_cudnn=False): |
| """A convenient function to run everything on the GPU.""" |
| device_option = caffe2_pb2.DeviceOption() |
| device_option.device_type = caffe2_pb2.CUDA |
| device_option.cuda_gpu_id = gpu_id |
| self._net.device_option.CopyFrom(device_option) |
| if use_cudnn: |
| for op in self._net.op: |
| op.engine = "CUDNN" |
| |
| def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs): |
| """A helper function to create an operator and add it to self. |
| """ |
| inputs = _RectifyInputOutput(inputs) |
| for input in inputs: |
| if not self.BlobIsDefined(input): |
| assert input.Net() != self |
| self.AddExternalInput(input) |
| if outputs is None: |
| # If we do not specify an output, we will assume that this op |
| # produces a single output. |
| outputs = self.NextName(prefix=op_type) |
| elif type(outputs) is int: |
| # In this case, we will auto-fill the given number of outputs |
| # with auto-generated names. |
| outputs = [ |
| self.NextName(prefix=op_type, output_id=i) |
| for i in range(outputs)] |
| outputs = _RectifyInputOutput(outputs, net=self) |
| op = CreateOperator(op_type, inputs, outputs, **kwargs) |
| self._net.op.extend([op]) |
| if len(op.output) == 0: |
| return |
| elif len(op.output) == 1: |
| return BlobReference(str(op.output[0]), self) |
| else: |
| return tuple(BlobReference(str(o), self) for o in op.output) |
| |
| def __getattr__(self, op_type): |
| if op_type.startswith('__'): |
| raise AttributeError('Attribute {} not found.'.format(op_type)) |
| if not IsOperator(op_type): |
| raise RuntimeError( |
| 'Method ' + op_type + ' is not a registered operator.' |
| ) |
| return lambda *args, **kwargs: self._CreateAndAddToSelf( |
| op_type, *args, **kwargs) |
| |
| def Python(self, f, grad_f=None): |
| assert(IsOperator('Python')) |
| token = C.register_python_op(f) |
| if grad_f: |
| C.register_python_gradient_op(token, grad_f) |
| return lambda *args, **kwargs: self._CreateAndAddToSelf( |
| 'Python', token=token, *args, **kwargs) |
| |
| |
| def get_net_name(netlike): |
| if isinstance(netlike, Net): |
| return netlike.Proto().name |
| elif isinstance(netlike, caffe2_pb2.NetDef): |
| return netlike.name |
| else: |
| return netlike |
| |
| |
| def output_to_list(op_output): |
| """ |
| Ensures that the output of an operator is a list. |
| Use when an operator has a variable number of outputs, but a list of |
| outputs is desired even when number of outputs is 1. |
| |
| Args: |
| op_output: Either a BlobReference or an iterable of BlobReferences. |
| |
| Returns: |
| A list of BlobReferences. |
| """ |
| assert type(op_output) in (list, tuple, BlobReference) |
| return ( |
| [op_output] |
| if isinstance(op_output, BlobReference) else list(op_output)) |
| |
| |
| def _add_net_to_dict(net_dict, net): |
| name = get_net_name(net) |
| if name in net_dict: |
| assert net_dict[name] is None or net == net_dict[name], ( |
| 'Different nets with same name: ' + name) |
| return False |
| else: |
| net_dict[name] = net if isinstance(net, Net) else None |
| return True |
| |
| |
| class ExecutionStep(object): |
| _step_names_used = set() |
| |
| @staticmethod |
| def _get_next_step_name(basename): |
| name = basename |
| next_idx = 1 |
| while name in ExecutionStep._step_names_used: |
| name = basename + '_' + str(next_idx) |
| next_idx += 1 |
| ExecutionStep._step_names_used |= set([name]) |
| return name |
| |
| def __init__(self, name, nets=None, num_iter=None): |
| self._step = caffe2_pb2.ExecutionStep() |
| self._step.name = name or ExecutionStep._get_next_step_name('step') |
| self._net_dict = OrderedDict() |
| self._is_used = False |
| self._substeps = [] |
| if nets is not None: |
| if type(nets) is Net: |
| nets = [nets] |
| for net in nets: |
| if _add_net_to_dict(self._net_dict, net): |
| self._step.network.extend([get_net_name(net)]) |
| if num_iter is not None: |
| self._step.num_iter = num_iter |
| |
| def get_net(self, name): |
| return self._net_dict[name] |
| |
| def Name(self): |
| return self._step.name |
| |
| def __str__(self): |
| return self._step.name |
| |
| def _assert_can_mutate(self): |
| assert not self._is_used, ( |
| 'Cannot mutate a step that has already been added to a plan/step.') |
| |
| def _notify_is_used(self): |
| self._is_used = True |
| |
| def Proto(self): |
| return self._step |
| |
| def HasNets(self): |
| return self._step.network is not None and ( |
| len(self._step.network) > 0) |
| |
| def HasSubsteps(self): |
| return self._step.substep is not None and ( |
| len(self._step.substep) > 0) |
| |
| def Nets(self): |
| return self._net_dict.values() |
| |
| def Substeps(self): |
| return self._substeps |
| |
| def SetIter(self, num_iter): |
| self._assert_can_mutate() |
| self._step.num_iter = num_iter |
| |
| def SetOnlyOnce(self, only_once): |
| self._assert_can_mutate() |
| self._step.only_once = only_once |
| |
| def SetShouldStopBlob(self, should_stop_blob): |
| assert isinstance(should_stop_blob, BlobReference), ( |
| "expects BlobReference here, got {}".format(type(should_stop_blob))) |
| self._assert_can_mutate() |
| self._step.should_stop_blob = str(should_stop_blob) |
| |
| def SetReportNet(self, report_net, report_interval): |
| self._assert_can_mutate() |
| _add_net_to_dict(self._net_dict, report_net) |
| self._step.report_net = get_net_name(report_net) |
| self._step.report_interval = report_interval |
| |
| def AddSubstep(self, substep): |
| self._assert_can_mutate() |
| assert not self.HasNets(), 'Cannot have both network and substeps.' |
| if isinstance(substep, ExecutionStep): |
| substep._notify_is_used() |
| if not substep.HasNets() and not substep.HasSubsteps(): |
| return self |
| for net in substep.Nets(): |
| _add_net_to_dict(self._net_dict, net) |
| self._substeps.append(substep) |
| proto = substep.Proto() |
| else: |
| proto = substep |
| self._step.substep.add().CopyFrom(proto) |
| return self |
| |
| def SetConcurrentSubsteps(self, concurrent_substeps): |
| self._assert_can_mutate() |
| assert not self.HasNets(), 'Cannot have both network and substeps.' |
| self._step.concurrent_substeps = concurrent_substeps |
| |
| def AddNet(self, net): |
| self._assert_can_mutate() |
| assert not self.HasSubsteps(), 'Cannot have both network and substeps.' |
| assert isinstance(net, Net) |
| _add_net_to_dict(self._net_dict, net) |
| self._step.network.extend([get_net_name(net)]) |
| return self |
| |
| def get_all_attributes(self, name): |
| """ |
| Return the list of all attributes under the given `name`, present in |
| all of the nets used in this execution step and its children. |
| """ |
| objs = [] |
| for net in self._net_dict.values(): |
| objs += net.get_attributes(name) |
| return objs |
| |
| |
| def add_nets_in_order(step, net_list): |
| proto = step.Proto() |
| for substep in step.Substeps(): |
| add_nets_in_order(substep, net_list) |
| for net in proto.network: |
| if net not in net_list: |
| net_list.append(net) |
| # FIXME(azzolini): This is actually wrong. Report nets should be |
| # instantiated first since they may run before any substep is run. |
| # However, currently, Reporter depends on this behavior. |
| if proto.report_net and proto.report_net not in net_list: |
| net_list.append(proto.report_net) |
| |
| |
| class Plan(object): |
| def __init__(self, name_or_step): |
| self._plan = caffe2_pb2.PlanDef() |
| self._net_dict = OrderedDict() |
| if isinstance(name_or_step, ExecutionStep): |
| self._plan.name = name_or_step.Name() |
| self.AddStep(name_or_step) |
| elif isinstance(name_or_step, basestring): |
| self._plan.name = name_or_step |
| else: |
| raise ValueError('name_or_step must be a string or ExecutionStep') |
| |
| def __str__(self): |
| return self._plan.name |
| |
| def Proto(self): |
| return self._plan |
| |
| def AddNets(self, nets): |
| for net in nets: |
| if _add_net_to_dict(self._net_dict, net): |
| assert isinstance(net, Net) |
| self._plan.network.add().CopyFrom(net.Proto()) |
| |
| def Nets(self): |
| return self._net_dict.values() |
| |
| def AddStep(self, step): |
| assert isinstance(step, ExecutionStep) |
| step._notify_is_used() |
| if not step.HasNets() and not step.HasSubsteps(): |
| return |
| self._plan.execution_step.add().CopyFrom(step.Proto()) |
| # nets need to be added to the plan in order of usage |
| net_list = [] |
| add_nets_in_order(step, net_list) |
| self.AddNets([step.get_net(n) for n in net_list]) |
| |
| def get_all_attributes(self, name): |
| """ |
| Return the list of all attributes under the given `name`, present in |
| all of the nets used in this plan. |
| """ |
| objs = [] |
| for net in self._net_dict.values(): |
| objs += net.get_attributes(name) |
| return objs |
| |
| |
| def to_execution_step(step_or_nets, default_name=None): |
| from caffe2.python.net_builder import NetBuilder |
| if isinstance(step_or_nets, ExecutionStep): |
| return step_or_nets |
| |
| stop_blob = None |
| if isinstance(step_or_nets, NetBuilder): |
| stop_blob = step_or_nets._stop_blob |
| step_or_nets = step_or_nets.get() |
| return execution_step( |
| default_name, step_or_nets, should_stop_blob=stop_blob) |
| |
| |
| def execution_step(default_name, |
| steps_or_nets, |
| num_iter=None, |
| report_net=None, |
| report_interval=None, |
| concurrent_substeps=None, |
| should_stop_blob=None, |
| only_once=None): |
| """ |
| Helper for creating an ExecutionStep. |
| - steps_or_nets can be: |
| - None |
| - Net |
| - ExecutionStep |
| - list<Net> |
| - list<ExecutionStep> |
| - should_stop_blob is either None or a scalar boolean blob. |
| - This blob is checked AFTER every substep/subnet. |
| - If specified and true, then this step will return immediately. |
| - Be sure to handle race conditions if setting from concurrent threads. |
| - if no should_stop_blob or num_iter is provided, defaults to num_iter=1 |
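| |
| Illustrative usage (a sketch; assumes `train_net` is a Net): |
| step = execution_step('train', train_net, num_iter=100) |
| plan = Plan(step) |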
| """ |
| assert should_stop_blob is None or num_iter is None, ( |
| 'Cannot set both should_stop_blob and num_iter.') |
| if should_stop_blob is None and num_iter is None: |
| num_iter = 1 |
| |
| step = ExecutionStep(default_name) |
| if should_stop_blob is not None: |
| step.SetShouldStopBlob(should_stop_blob) |
| if num_iter is not None: |
| step.SetIter(num_iter) |
| if only_once is not None: |
| step.SetOnlyOnce(only_once) |
| if concurrent_substeps is not None: |
| step.SetConcurrentSubsteps(concurrent_substeps) |
| if report_net is not None: |
| assert report_interval is not None |
| step.SetReportNet(report_net, report_interval) |
| |
| if isinstance(steps_or_nets, ExecutionStep): |
| step.AddSubstep(steps_or_nets) |
| elif isinstance(steps_or_nets, Net): |
| step.AddNet(steps_or_nets) |
| elif isinstance(steps_or_nets, list): |
| if all(isinstance(x, Net) for x in steps_or_nets): |
| for x in steps_or_nets: |
| step.AddNet(x) |
| else: |
| for x in steps_or_nets: |
| step.AddSubstep(to_execution_step(x)) |
| elif steps_or_nets: |
| raise ValueError( |
| 'steps_or_nets must be a step, a net, or a list of nets or steps.') |
| return step |