| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| import networkx as nx |
| import collections |
| import time |
| import copy |
| from caffe2.python import workspace |
| |
| import logging |
| |
# Module logger; INFO level so the optimization summaries below are emitted.
log = logging.getLogger("memonger")
log.setLevel(logging.INFO)

# Live range of a blob expressed as op indices in a linearized op list:
#   defined -- index of the first op that outputs the blob (None if no op does)
#   used    -- index of the last op that reads the blob (None if no op does)
LiveRange = collections.namedtuple(
    'LiveRange',
    ("defined", "used"),
)
| |
| |
def share_grad_blobs(net, losses, param_grads, namescope):
    """
    Implements similar optimization as Torch's shareGradInput():
    for the gradients that are passed between layers, share blobs between
    operators when possible. This yields significant memory savings with
    deep networks.

    Args:
        net: object whose Proto() returns the net definition. The definition
            is deep-copied before being modified.
        losses: collection of loss blobs; the traversal starts from the
            gradient ops that consume them.
        param_grads: collection of blob names that must never be shared
            (checked against str(blob)).
        namescope: name prefix of the blobs to optimize. A trailing "/" is
            appended when it is missing.

    Returns an optimized protobuf (assign to net._net)
    """
    def is_grad_blob(b):
        name = str(b)
        # Note: need to look at _{namescope} pattern as it matches
        # to handle the auto-split gradients
        in_scope = name.startswith((namescope, "_" + namescope))
        return "_grad" in name and in_scope and name not in param_grads

    def is_grad_op(op):
        # TODO: something smarter
        return (any(is_grad_blob(inp) for inp in op.input) or
                any(is_grad_blob(out) for out in op.output))

    start_time = time.time()
    # Logger.warn is a deprecated alias; Logger.warning is the supported name.
    log.warning("NOTE: Executing *experimental* memonger to " +
                "optimize gradient memory")

    # The helpers above read `namescope` lazily, so they see the normalized
    # value because they are only called after this point.
    if not namescope.endswith("/"):
        namescope += "/"

    # Collect ops that have something to do with
    # gradients
    netproto = copy.deepcopy(net.Proto())
    grad_ops = [op for op in netproto.op if is_grad_op(op)]

    # Create mapping from blobs to ops
    blobs_to_ops = collections.defaultdict(list)   # blob -> consumer op idxs
    blob_input_count = collections.defaultdict(int)  # blob -> reads seen
    op_inputs = collections.defaultdict(int)       # op idx -> tracked inputs
    op_visit_count = collections.defaultdict(int)  # op idx -> inputs satisfied
    for i, op in enumerate(grad_ops):
        for inp in op.input:
            if is_grad_blob(inp) or inp in losses:
                # Ignore in-place transformation ops (self cycles)
                if inp not in op.output:
                    blobs_to_ops[inp].append(i)
                    op_inputs[i] += 1

    # Traverse operators starting from the loss blobs.
    # Keep tabs on when blobs are seen first and last, and also
    # when operators have their input satisfied. Share blobs only
    # under same branch, avoiding problems with parallel workers.
    output_blobs = set()
    mapping = {}

    def descend(op_idx, free_blobs):
        cur_op = grad_ops[op_idx]

        # Blobs read for the last time by this op. They are released into
        # free_blobs only after this op's outputs have been assigned, so an
        # output cannot be mapped onto one of this op's own inputs.
        # NOTE(review): every gradient input is counted here, including
        # in-place ones that were not recorded in blobs_to_ops -- confirm
        # that this is intended.
        new_free_blobs = set()
        for inp in cur_op.input:
            if is_grad_blob(inp):
                blob_input_count[inp] += 1
                if blob_input_count[inp] == len(blobs_to_ops[inp]):
                    # Free the blob that actually backs this input.
                    new_free_blobs.add(mapping.get(inp, inp))

        for outp in cur_op.output:
            if is_grad_blob(outp):
                if outp not in output_blobs and free_blobs:
                    # First seen this blob as output, can assign to a free blob
                    freeb = next(iter(free_blobs))
                    mapping[outp] = freeb
                    free_blobs.remove(freeb)

                output_blobs.add(outp)

        free_blobs.update(new_free_blobs)

        first_branch = True
        for outp in cur_op.output:
            for inp_op_idx in blobs_to_ops[outp]:
                op_visit_count[inp_op_idx] += 1

                # Descend only if we have satisfied all inputs
                if op_visit_count[inp_op_idx] == op_inputs[inp_op_idx]:
                    # Only the first branch inherits the free blobs; other
                    # branches start with an empty pool.
                    free_blobs_fwd = free_blobs if first_branch else set()
                    first_branch = False
                    descend(inp_op_idx, free_blobs_fwd)

    # Start DFS from the losses
    for loss in losses:
        for op_idx in blobs_to_ops[loss]:
            descend(op_idx, set())

    # Rename the shared blobs
    shared_blobs = set(mapping.values())
    renamed = {
        blob: namescope + "__m{}_".format(j)
        for j, blob in enumerate(shared_blobs)
    }

    # Final mapping: every remapped blob points at the new shared name ...
    final_mapping = {
        blob: renamed[target] for blob, target in mapping.items()
    }

    # ... and so do the originators (the blobs whose storage is reused).
    final_mapping.update(renamed)
    log.info("Remapping {} blobs, using {} shared".format(
        len(final_mapping), len(renamed),
    ))
    apply_assignments(netproto, final_mapping)
    log.info("Gradient memory optimization took {} secs".format(
        time.time() - start_time),
    )
    return netproto
| |
| |
def topological_sort_traversal(g):
    """Order the nodes of graph `g` using networkx's topological sort.

    Returns the result of nx.topological_sort(g) unchanged.
    """
    node_order = nx.topological_sort(g)
    return node_order
| |
| |
def compute_ranges(linearized_ops):
    """Compute the live range of every blob touched by `linearized_ops`.

    Each op is expected to expose `input` and `output` sequences of blob
    names. The index of an op in `linearized_ops` is its position in time.

    Returns a defaultdict mapping blob name -> LiveRange, where `defined` is
    the smallest index of an op outputting the blob and `used` is the largest
    index of an op reading it. A field stays None when no op writes/reads
    the blob.
    """
    live_ranges = collections.defaultdict(
        lambda: LiveRange(defined=None, used=None))

    for position, op in enumerate(linearized_ops):
        # Reads push the end of the range forward.
        for name in op.input:
            previous = live_ranges[name].used
            last_use = position if previous is None else max(previous, position)
            live_ranges[name] = live_ranges[name]._replace(used=last_use)

        # Writes pull the start of the range backward.
        for name in op.output:
            previous = live_ranges[name].defined
            first_def = (
                position if previous is None else min(previous, position))
            live_ranges[name] = live_ranges[name]._replace(defined=first_def)

    return live_ranges
| |
| |
def is_compatible(candidate_range, assignment, static_blobs):
    """Tell whether a blob with `candidate_range` may join `assignment`.

    `assignment` is a non-empty list of (name, LiveRange) pairs; only its
    last entry is inspected. The candidate is compatible when that last blob
    is not static, all needed range endpoints are known, and the candidate is
    defined strictly after the last blob's final use.
    """
    last_name, last_range = assignment[-1]
    if last_name in static_blobs:
        return False

    endpoint_unknown = (
        candidate_range.defined is None
        or last_range.defined is None
        or last_range.used is None
    )
    if endpoint_unknown:
        return False

    return candidate_range.defined > last_range.used
| |
| |
def compute_blob_assignments(assignments):
    """Flatten assignment groups into a blob -> canonical blob mapping.

    Each group is a list of (blob, range) pairs. Groups holding a single
    blob are skipped since nothing is shared. In every other group all
    members (the last one included) map to the name of the last member.
    """
    canonical = {}
    for group in assignments:
        if len(group) == 1:
            continue
        target = group[-1][0]
        canonical.update((member, target) for member, _range in group)
    return canonical
| |
| |
def compute_assignments(ranges, static_blobs):
    """Greedily group blobs whose live ranges allow sharing one buffer.

    `ranges` maps blob name -> LiveRange. Blobs are visited in order of
    their last use and appended to the first existing group that
    is_compatible() accepts; otherwise they start a new group.

    Returns a list of groups, each a list of (name, LiveRange) pairs.
    """
    def last_use_key(item):
        # If LiveRange.used is None, then the blob is never used and could
        # be consumed externally. Sort these to the end of the list as
        # opposed to the beginning so that they can be shared as well.
        live_range = item[1]
        return (live_range.used is None, live_range.used)

    ordered = sorted(list(ranges.items()), key=last_use_key)

    groups = []
    for (name, live_range) in ordered:
        for group in groups:
            if is_compatible(live_range, group, static_blobs):
                group.append((name, live_range))
                break
        else:
            # No existing group can host this blob.
            groups.append([(name, live_range)])
    return groups
| |
| |
def compute_interference_graph(ops):
    """Build a directed graph of data dependencies between `ops`.

    Nodes are op indices (the op itself is stored in the `op` node
    attribute). An edge i -> j is added when op j reads at least one blob
    that op i outputs; the shared blob names are stored in the `deps` edge
    attribute. Asserts that the resulting graph is acyclic.
    """
    graph = nx.DiGraph()
    for idx, op in enumerate(ops):
        graph.add_node(idx, op=op)

    for i, parent_op in enumerate(ops):
        for j, child_op in enumerate(ops):
            if i == j:
                continue
            # Blobs produced by the parent and consumed by the child.
            deps = set(child_op.input).intersection(parent_op.output)
            if deps:
                graph.add_edge(i, j, deps=deps)

    assert nx.is_directed_acyclic_graph(graph), child_op
    return graph
| |
| |
# Result bundle of optimize_interference():
#   net              -- the rewritten copy of the net
#   assignments      -- groups of (blob, LiveRange) pairs sharing storage
#   blob_assignments -- mapping from blob name to its canonical blob name
Optimization = collections.namedtuple(
    'Optimization',
    ('net', 'assignments', 'blob_assignments'),
)
| |
| |
def apply_assignments(net, blob_assignments):
    """Rename blobs of every op in `net` in place.

    Each entry of op.input and op.output that is a key of
    `blob_assignments` is replaced by the mapped name; other entries are
    left as they are. Nothing is returned.
    """
    def rename_in_place(names):
        # Assign element by element so the container itself is kept.
        for position, original in enumerate(names):
            names[position] = blob_assignments.get(original, original)

    for op in net.op:
        rename_in_place(op.input)
        rename_in_place(op.output)
| |
| |
def optimize_interference(net, static_blobs,
                          ordering_function=topological_sort_traversal):
    """
    1) Build the dependency graph of the ops and linearize it with
       `ordering_function` (a topological sort by default).
    2) Generate use-def ranges for each `blob` in that linearized order.
    3) Assign blobs to `canonical blobs`
    4) Rename blobs to canonical blobs

    Works on a deep copy of `net`; blobs named in `static_blobs` are never
    followed by another blob in a sharing group. Returns an Optimization
    tuple holding the rewritten net, the sharing groups and the blob ->
    canonical blob mapping.
    """
    optimized = copy.deepcopy(net)
    dependency_graph = compute_interference_graph(optimized.op)
    node_order = ordering_function(dependency_graph)
    linearized_ops = [optimized.op[idx] for idx in node_order]

    # Reorder ops in net based on the computed linearized order.
    # If the graph has multiple topological orderings and if the NetDef's
    # ordering differs from the order used to compute ranges, then the
    # runtime might end up overwriting blobs before they are used.
    del optimized.op[:]
    optimized.op.extend(linearized_ops)

    live_ranges = compute_ranges(linearized_ops)
    groups = compute_assignments(live_ranges, static_blobs)
    canonical_names = compute_blob_assignments(groups)
    apply_assignments(optimized, canonical_names)

    return Optimization(
        net=optimized,
        assignments=groups,
        blob_assignments=canonical_names,
    )
| |
# Memory totals reported by compute_statistics():
#   baseline_nbytes  -- bytes used when every blob has its own storage
#   optimized_nbytes -- bytes used when each sharing group uses one buffer
Statistics = collections.namedtuple(
    'Statistics',
    ('baseline_nbytes', 'optimized_nbytes'),
)
| |
| |
def compute_statistics(assignments):
    """Compare memory use before and after blob sharing.

    Args:
        assignments: list of sharing groups, each a list of (blob, range)
            pairs. Every blob is fetched from the workspace and sized via
            the `nbytes` attribute of the fetched value.

    Returns:
        Statistics where `baseline_nbytes` is the sum of the sizes of all
        distinct blobs and `optimized_nbytes` is the sum, over the groups,
        of the largest blob in each group.
    """
    def blob_nbytes(blob):
        return workspace.FetchBlob(blob).nbytes

    blob_bytes = {
        blob: blob_nbytes(blob) for assignment in assignments
        for (blob, _) in assignment}
    # dict.iteritems() does not exist on Python 3; values() works on both
    # Python 2 and 3 and is all that is needed here.
    baseline_nbytes = sum(blob_bytes.values())
    # A group needs a buffer as large as its largest member.
    optimized_nbytes = sum(
        max(blob_bytes[blob] for (blob, _) in assignment)
        for assignment in assignments)
    return Statistics(
        baseline_nbytes=baseline_nbytes,
        optimized_nbytes=optimized_nbytes)