Rewrite memonger DAG in C++.

Summary: This diff replaces the core of the memonger-for-DAG algorithm, _compute_blob_recycling_for_dag, with a C++ implementation.

Reviewed By: akyrola

Differential Revision: D5544219

fbshipit-source-id: 9f868880c8d0eb997ad3dd39433f9d0b9216d303
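
For reference, here is a minimal sketch (not part of this diff) of how the new
C++ entry point is called; net is an existing NetDef, and the heads, op indices,
and shareable blob names below are purely hypothetical:

    // Illustrative only: hypothetical heads, op indices, and shareable blobs.
    caffe2::NetDef optimized = caffe2::memonger::compute_blob_recycling_for_dag(
        net,
        /* heads */ {"loss"},
        /* op_indices */ {0, 1, 2},
        /* shareable_blob_names */ {"fc1_grad", "fc2_grad"},
        /* namescope */ "",
        /* dont_share_blob_names */ {},
        /* blob_shapes */ {});
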
diff --git a/caffe2/core/memonger.cc b/caffe2/core/memonger.cc
index 490afaf..c093fb2 100644
--- a/caffe2/core/memonger.cc
+++ b/caffe2/core/memonger.cc
@@ -64,7 +64,8 @@
           mapping[inp] = inp;
 
           // Safety check to prevent double-memongering nets.
-          string shared_blob = "__m" + to_string(renaming.size()) + "_shared";
+          string shared_blob =
+              "__m" + caffe2::to_string(renaming.size()) + "_shared";
           if (all_blobs.find(shared_blob) != all_blobs.end()) {
             LOG(INFO) << "Net was already memongered!";
             return net;
@@ -119,5 +120,403 @@
   LOG(INFO) << "optimized net using " << renaming.size() << " shared blobs";
   return optim_net;
 }
+
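+// Computes a blob recycling by traversing the computation DAG. The resulting
+// net can be executed safely on a DAGNet.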
+class ComputeBlobRecyclingForDag {
+ public:
+  explicit ComputeBlobRecyclingForDag(const int size)
+      : op_inputs_(size),
+        op_visited_count_(size),
+        op_token_deposit_(size),
+        op_visited_(size, false) {}
+  NetDef OptimizeNet(
+      const NetDef& net,
+      const std::vector<string>& heads,
+      const std::vector<int>& op_indices,
+      const std::unordered_set<string>& shareable_blob_names,
+      const string& namescope,
+      const std::unordered_set<string>& dont_share_blob_names,
+      const std::unordered_map<string, vector<int>>& blob_shapes) {
+    // Construct the set of input blobs.
+    std::unordered_set<string> heads_blobs_set(heads.begin(), heads.end());
+
+    // Construct the set of output blobs we want to optimize.
+    for (const int op_index : op_indices) {
+      for (const auto& output : net.op(op_index).output()) {
+        optim_op_outputs_.insert(output);
+      }
+    }
+
+    // Compute each operator's in-degree (op_inputs_) and initialize how many
+    // ops share each input blob (share_counts_).
+    // Note: We have to handle the cases where output blobs are shared.
+    std::unordered_map<string, int> blob_seen;
+    for (const int op_index : op_indices) {
+      for (const auto& input : net.op(op_index).input()) {
+        if (has_key(shareable_blob_names, input) ||
+            has_key(heads_blobs_set, input)) {
+          if (has_key(optim_op_outputs_, input)) {
+            CAFFE_ENFORCE(
+                blob_seen.find(input) != blob_seen.end(),
+                "Input ",
+                input,
+                " was not output by an op before");
+            op_inputs_[op_index] += blob_seen[input];
+          } else {
+            share_counts_[input] = 1;
+          }
+          blob_to_ops_[input].push_back(op_index);
+        }
+      }
+      for (const auto& output : net.op(op_index).output()) {
+        blob_seen[output] += 1;
+      }
+    }
+
+    // The main recursive call. Here we start a DFS over the operator graph
+    // from the input blobs.
+    for (const auto& input_blob : heads) {
+      for (const int op_index : blob_to_ops_[input_blob]) {
+        if (!op_visited_[op_index]) {
+          vector<std::pair<int, string>> free_blobs;
+          std::unordered_set<int> tokens{tokens_counter_++};
+          process_op(
+              net,
+              shareable_blob_names,
+              namescope,
+              dont_share_blob_names,
+              blob_shapes,
+              op_index,
+              &free_blobs,
+              &tokens);
+        }
+      }
+    }
+
+    // Rename mapped blobs.
+    std::unordered_map<string, string> renamed;
+    int name_idx = 0;
+    std::unordered_set<string> mapped_blobs_set;
+    for (const auto& mapped_blob : mapping_) {
+      mapped_blobs_set.insert(mapped_blob.second);
+      if (has_key(optim_op_outputs_, mapped_blob.second)) {
+        if (renamed.find(mapped_blob.second) == renamed.end()) {
+          renamed.insert(
+              {mapped_blob.second,
+               namescope + "__m" + caffe2::to_string(name_idx++) + "_shared"});
+        }
+      } else {
+        renamed.insert({mapped_blob.second, mapped_blob.second});
+      }
+    }
+
+    // Recursively rename mapped_blobs.
+    mapping_.insert(renamed.begin(), renamed.end());
+    bool had_changes = true;
+    while (had_changes) {
+      had_changes = false;
+      for (const auto& mapped_blob : mapping_) {
+        if (has_key(renamed, mapped_blob.second) &&
+            renamed[mapped_blob.second] != mapped_blob.second) {
+          renamed[mapped_blob.first] = renamed[mapped_blob.second];
+          mapping_[mapped_blob.first] = renamed[mapped_blob.first];
+          had_changes = true;
+        }
+      }
+    }
+
+    NetDef optimized_net = net;
+    // Rename optimized_net blobs.
+    for (int i = 0; i < optimized_net.op_size(); ++i) {
+      for (int j = 0; j < optimized_net.op(i).input_size(); ++j) {
+        const string& input_name =
+            get_blob_or_mapped_blob(optimized_net.op(i).input(j));
+        optimized_net.mutable_op(i)->set_input(j, input_name);
+      }
+
+      for (int j = 0; j < optimized_net.op(i).output_size(); ++j) {
+        auto output_name =
+            get_blob_or_mapped_blob(optimized_net.op(i).output(j));
+        optimized_net.mutable_op(i)->set_output(j, output_name);
+      }
+    }
+
+    LOG(INFO) << "Remapping " << mapping_.size() << " using "
+              << mapped_blobs_set.size() << " shared blobs.";
+    if (floats_saved_ > 0) {
+      LOG(INFO) << "Memoger saved approximately : "
+                << (floats_saved_ * 4.0 / 1024.0 / 1024.0) << " MB.";
+    }
+
+    return optimized_net;
+  }
+
+ private:
+  template <typename K, typename V>
+  inline bool has_key(const std::unordered_map<K, V>& in_map, const K& key) {
+    return in_map.find(key) != in_map.end();
+  }
+
+  template <typename K>
+  inline bool has_key(const std::unordered_set<K>& in_set, const K& key) {
+    return in_set.find(key) != in_set.end();
+  }
+
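+  // Main recursive function. We start the recursion from the head blobs and
+  // only descend into an operator once all of its inputs have been satisfied,
+  // i.e. all of its parent operators have been visited.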
+  void process_op(
+      const NetDef& net,
+      const std::unordered_set<string>& shareable_blob_names,
+      const string& namescope,
+      const std::unordered_set<string>& dont_share_blob_names,
+      const std::unordered_map<string, vector<int>>& blob_shapes,
+      int op_index,
+      std::vector<std::pair<int, string>>* free_blobs,
+      std::unordered_set<int>* tokens) {
+    // The tokens we hold now are the union of the tokens this operator is
+    // currently holding and the tokens pushed down from its parents.
+    tokens->insert(
+        op_token_deposit_[op_index].begin(), op_token_deposit_[op_index].end());
+    op_token_deposit_[op_index].clear();
+    CAFFE_ENFORCE(!op_visited_[op_index]);
+    op_visited_[op_index] = true;
+
+    const OperatorDef& current_op = net.op(op_index);
+
+    // The set of freed input blobs by processing current op.
+    std::vector<std::pair<int, string>> new_free_blobs;
+    std::unordered_set<string> new_free_blobs_set;
+
+    // Now update blob tokens.
+    for (const auto& input : current_op.input()) {
+      const auto& actual_blob = get_blob_or_mapped_blob(input);
+      req_tokens_[actual_blob].insert(tokens->begin(), tokens->end());
+      if (actual_blob != input) {
+        req_tokens_[input].insert(tokens->begin(), tokens->end());
+      }
+    }
+    for (const auto& output : current_op.output()) {
+      const auto& actual_blob = get_blob_or_mapped_blob(output);
+      req_tokens_[actual_blob].insert(tokens->begin(), tokens->end());
+      if (actual_blob != output) {
+        req_tokens_[output].insert(tokens->begin(), tokens->end());
+      }
+    }
+
+    // Increment blob count and check if we can free input blobs.
+    for (const auto& input : current_op.input()) {
+      if (has_key(shareable_blob_names, input)) {
+        blob_input_count_[input]++;
+        if (blob_input_count_[input] == blob_to_ops_[input].size()) {
+          const string& actual_blob = get_blob_or_mapped_blob(input);
+          if (!has_key(dont_share_blob_names, actual_blob)) {
+            new_free_blobs.emplace_back(
+                -share_counts_[actual_blob], actual_blob);
+            new_free_blobs_set.insert(actual_blob);
+          }
+        }
+      }
+    }
+
+    // Check if we can recycle a free blob and use it as the output blob.
+    for (const auto& output : current_op.output()) {
+      if (has_key(shareable_blob_names, output) &&
+          !has_key(processed_output_blobs_, output) &&
+          !has_key(new_free_blobs_set, output)) {
+        const string freed_blob =
+            get_free_blob(output, blob_shapes, tokens, free_blobs);
+        if (freed_blob != "") {
+          req_tokens_[freed_blob].insert(tokens->begin(), tokens->end());
+          share_counts_[freed_blob]++;
+          mapping_[output] = freed_blob;
+        }
+        processed_output_blobs_.insert(output);
+      }
+    }
+
+    // Insert new freed blobs.
+    std::unordered_set<string> free_blob_set;
+    for (const auto& free_blob : *free_blobs) {
+      free_blob_set.insert(free_blob.second);
+    }
+    for (const auto& new_free_blob : new_free_blobs) {
+      if (!has_key(free_blob_set, new_free_blob.second)) {
+        free_blobs->push_back(new_free_blob);
+        if (blob_shapes.size() > 0) {
+          if (!has_key(blob_sizes_, new_free_blob.second)) {
+            blob_sizes_.insert(
+                {new_free_blob.second,
+                 infer_blob_size(new_free_blob.second, blob_shapes)});
+          }
+        }
+        std::push_heap(
+            free_blobs->begin(),
+            free_blobs->end(),
+            std::greater<std::pair<int, string>>());
+      }
+    }
+
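+    // Count the downstream branches. If this op fans out, each descent below
+    // gets an extra fresh token, so blobs freed on one branch cannot be
+    // reused on a parallel branch until the branches join again.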
+    int num_branches = 0;
+    for (const auto& output : current_op.output()) {
+      num_branches += blob_to_ops_[output].size();
+    }
+
+    for (const auto& output : current_op.output()) {
+      for (const auto& input_op_index : blob_to_ops_[output]) {
+        op_visited_count_[input_op_index]++;
+        if (op_visited_count_[input_op_index] == op_inputs_[input_op_index]) {
+          std::unordered_set<int> new_tokens;
+          new_tokens.insert(tokens->begin(), tokens->end());
+          if (num_branches > 1) {
+            new_tokens.insert(tokens_counter_++);
+          }
+          process_op(
+              net,
+              shareable_blob_names,
+              namescope,
+              dont_share_blob_names,
+              blob_shapes,
+              input_op_index,
+              free_blobs,
+              &new_tokens);
+        } else {
+          if (!op_visited_[input_op_index]) {
+            op_token_deposit_[input_op_index].insert(
+                tokens->begin(), tokens->end());
+          }
+        }
+      }
+    }
+  }
+
+  inline int infer_blob_size(
+      const string& blob_name,
+      const std::unordered_map<string, vector<int>>& blob_shapes) {
+    const auto& blob_shapes_iter = blob_shapes.find(blob_name);
+    if (blob_shapes_iter == blob_shapes.end()) {
+      return 0;
+    }
+    int size = 1;
+    for (int i = 0; i < blob_shapes_iter->second.size(); ++i) {
+      size *= blob_shapes_iter->second[i];
+    }
+    return size;
+  }
+
+  inline string get_blob_or_mapped_blob(const string& blob_name) {
+    auto mapped_blob = mapping_.find(blob_name);
+    if (mapped_blob == mapping_.end()) {
+      return blob_name;
+    } else {
+      return mapped_blob->second;
+    }
+  }
+
+  // Returns true if we currently hold all of the tokens required by the blob.
+  inline bool can_use_blob(
+      const string& blob_name,
+      std::unordered_set<int>* tokens) {
+    for (const int token : req_tokens_[blob_name]) {
+      if (tokens->find(token) == tokens->end()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Returns the name of the blob that we are going to map blob_name into.
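+  // There are two strategies: without blob_shapes, pop candidates from a
+  // priority queue that prefers blobs that have been shared before; with
+  // blob_shapes, pick a candidate based on the inferred blob sizes.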
+  inline string get_free_blob(
+      const string& blob_name,
+      const std::unordered_map<string, vector<int>>& blob_shapes,
+      std::unordered_set<int>* tokens,
+      std::vector<std::pair<int, string>>* free_blobs) {
+    string freed_blob = "";
+    if (blob_shapes.size() == 0) {
+      std::vector<std::pair<int, string>> cant_use_blobs;
+      while (free_blobs->size() > 0) {
+        std::pop_heap(
+            free_blobs->begin(),
+            free_blobs->end(),
+            std::greater<std::pair<int, string>>());
+        const auto cand_free_blob = free_blobs->back();
+        free_blobs->pop_back();
+        if (can_use_blob(cand_free_blob.second, tokens)) {
+          freed_blob = cand_free_blob.second;
+          break;
+        } else {
+          cant_use_blobs.push_back(cand_free_blob);
+        }
+      }
+      for (const auto& cant_use_blob : cant_use_blobs) {
+        free_blobs->push_back(cant_use_blob);
+        std::push_heap(
+            free_blobs->begin(),
+            free_blobs->end(),
+            std::greater<std::pair<int, string>>());
+      }
+    } else {
+      // Heuristic to choose the most suitably sized free blob for the output.
+      const int blob_size = infer_blob_size(blob_name, blob_shapes);
+      int best_size = -1;
+      int free_blob_index = -1;
+      for (int i = 0; i < free_blobs->size(); ++i) {
+        const string& cb_name = (*free_blobs)[i].second;
+        if (can_use_blob(cb_name, tokens)) {
+          CAFFE_ENFORCE(blob_sizes_.find(cb_name) != blob_sizes_.end());
+          const int cand_bz = blob_sizes_[cb_name];
+          if (cand_bz >= best_size) {
+            if (best_size < blob_size || best_size >= cand_bz) {
+              best_size = cand_bz;
+              free_blob_index = i;
+            }
+          }
+        }
+      }
+      if (free_blob_index != -1) {
+        floats_saved_ += best_size;
+        freed_blob = (*free_blobs)[free_blob_index].second;
+        free_blobs->erase(free_blobs->begin() + free_blob_index);
+      }
+    }
+    return freed_blob;
+  }
+
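+  // Tokens are used to keep track of dependencies: a free blob can be
+  // recycled by an op only if the op holds all tokens required by that blob.
+  // A fresh token is created for each root descent and whenever the DFS
+  // branches.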
+  int tokens_counter_ = 1;
+  int floats_saved_ = 0;
+  // blob_name -> Op edges.
+  std::unordered_map<string, std::vector<int>> blob_to_ops_;
+  // Number of times each blob has been consumed by a visited op so far.
+  std::unordered_map<string, int> blob_input_count_;
+  // Op in-degree.
+  std::vector<int> op_inputs_;
+  // Number of each op's inputs that have been satisfied so far.
+  std::vector<int> op_visited_count_;
+  // Number of times each blob has been shared; used to prioritize recycling
+  // candidates.
+  std::unordered_map<string, int> share_counts_;
+  // Inferred blob sizes (number of elements), when blob_shapes is provided.
+  std::unordered_map<string, int> blob_sizes_;
+  // Tokens that must be held by an op before it may recycle the blob.
+  std::unordered_map<string, std::unordered_set<int>> req_tokens_;
+  // Tokens deposited at an op by its parents, collected when it is visited.
+  std::vector<std::unordered_set<int>> op_token_deposit_;
+  std::unordered_set<string> optim_op_outputs_;
+  std::unordered_map<string, string> mapping_;
+  // The set of output blobs we already processed.
+  std::unordered_set<string> processed_output_blobs_;
+  std::vector<bool> op_visited_;
+};
+
+NetDef compute_blob_recycling_for_dag(
+    const NetDef& net,
+    const std::vector<string>& heads,
+    const std::vector<int>& op_indices,
+    const std::unordered_set<string>& shareable_blob_names,
+    const string& namescope,
+    const std::unordered_set<string>& dont_share_blob_names,
+    const std::unordered_map<string, vector<int>>& blob_shapes) {
+  ComputeBlobRecyclingForDag memonger(net.op_size());
+  return memonger.OptimizeNet(
+      net,
+      heads,
+      op_indices,
+      shareable_blob_names,
+      namescope,
+      dont_share_blob_names,
+      blob_shapes);
 }
-}
+
+} // memonger
+} // caffe2
diff --git a/caffe2/core/memonger.h b/caffe2/core/memonger.h
index 20ac81f..fe65ae5 100644
--- a/caffe2/core/memonger.h
+++ b/caffe2/core/memonger.h
@@ -1,6 +1,8 @@
 #ifndef CAFFE2_CORE_MEMONGER_H_
 #define CAFFE2_CORE_MEMONGER_H_
 
+#include <unordered_set>
+
 #include "caffe2/core/common.h"
 #include "caffe2/core/workspace.h"
 #include "caffe2/proto/caffe2.pb.h"
@@ -11,7 +13,17 @@
 NetDef optimize_inference_net(
     const NetDef& net,
     const std::set<string>& static_blobs);
-}
-}
+
+NetDef compute_blob_recycling_for_dag(
+    const NetDef& net,
+    const std::vector<string>& heads,
+    const std::vector<int>& op_indices,
+    const std::unordered_set<string>& shareable_blob_names,
+    const string& namescope,
+    const std::unordered_set<string>& dont_share_blob_names,
+    const std::unordered_map<string, vector<int>>& blob_shapes);
+
+} // memonger
+} // caffe2
 
 #endif
diff --git a/caffe2/python/memonger.py b/caffe2/python/memonger.py
index d9f1755..b06ed0e 100644
--- a/caffe2/python/memonger.py
+++ b/caffe2/python/memonger.py
@@ -8,13 +8,11 @@
 import networkx as nx
 import collections
 import time
-import heapq
 import copy
 from caffe2.python import workspace
 from caffe2.proto import caffe2_pb2
 import enum
 import logging
-import numpy as np
 from future.utils import viewitems, viewvalues
 import caffe2.python._import_c_extension as C
 
@@ -74,17 +72,39 @@
     activations = set(activations[:-2])
 
     # Gradient ops
-    grad_ops = [op for op in netproto.op if is_grad_op(op)]
-    return _compute_blob_recycling_for_dag(
-        netproto,
-        losses,
-        grad_ops,
-        lambda b: is_grad_blob(b) or (share_activations and b in activations),
-        namescope,
-        {} if dont_share_blobs is None else dont_share_blobs,
-        blob_shapes
+    grad_op_indices = []
+    for idx, op in enumerate(netproto.op):
+        if is_grad_op(op):
+            grad_op_indices.append(idx)
+
+    shared_blobs = set()
+    for op in net.Proto().op:
+        for b in list(op.input) + list(op.output):
+            if is_grad_blob(b) or (share_activations and b in activations):
+                shared_blobs.add(b)
+    start_time = time.time()
+    optim_str = C.memonger_compute_blob_recycling_for_dag(
+        netproto.SerializeToString(),
+        [str(s).encode('utf-8') for s in losses],
+        grad_op_indices,
+        set(str(s).encode('utf-8') for s in shared_blobs),
+        namescope.encode('utf-8'),
+        set() if dont_share_blobs is None else dont_share_blobs,
+        {} if blob_shapes is None else blob_shapes
     )
 
+    log.info("Memonger memory optimization took {} secs".format(
+        time.time() - start_time),
+    )
+
+    optim = caffe2_pb2.NetDef()
+    optim.ParseFromString(optim_str)
+    assert verify_graph_equality(net.Proto(), optim), \
+        "Memonger graph is not equal to original."
+    assert verify_inplace_blobs(net.Proto(), optim), \
+        "Inplace assignments differ in memonger net."
+    return optim
+
 
 def optimize_inference_for_dag(net, input_blobs, namescope=""):
     netproto = copy.deepcopy(net.Proto())
@@ -94,311 +114,48 @@
     def is_activation_blob(b):
         return b not in external_input and b not in external_output
 
+    activation_blobs = set()
     seen_as_output = set()
     ops = list(net.Proto().op)
+    op_indices = [index for index, op in enumerate(net.Proto().op)]
 
     # Sanity check: check that all external inputs are properlyh accounted
     # and that no gradient ops are included in 'net'
     for op in ops:
         for b in op.input:
-            if is_activation_blob(b) and b not in seen_as_output:
-                assert False, "{} not in external input".format(b)
+            if is_activation_blob(b):
+                activation_blobs.add(b)
+                if b not in seen_as_output:
+                    assert False, "{} not in external input".format(b)
+        for b in op.output:
+            if is_activation_blob(b):
+                activation_blobs.add(b)
         seen_as_output = seen_as_output.union(set(op.output))
         assert not op.is_gradient_op, \
             "You can only pass inference-only nets to optimize_inference_for_dag"
-
-    return _compute_blob_recycling_for_dag(
-        netproto, input_blobs, ops, is_activation_blob,
-        namescope, set(), None,
+    start_time = time.time()
+    optim_str = C.memonger_compute_blob_recycling_for_dag(
+        netproto.SerializeToString(),
+        [str(s).encode('utf-8') for s in input_blobs],
+        op_indices,
+        set(str(s).encode('utf-8') for s in activation_blobs),
+        namescope.encode('utf-8'),
+        set(),
+        {}
     )
 
-
-def _compute_blob_recycling_for_dag(
-    netproto, heads, ops, is_shareable,
-    namescope, dont_share_blobs, blob_shapes=None,
-):
-    '''
-    Computes a blob recycling by traversing the computation DAG. The resulting
-    model can be executed safely on a DAGNet.
-    '''
-    start_time = time.time()
-    if dont_share_blobs is not None:
-        dont_share_blobs = set([str(b) for b in dont_share_blobs])
-
-    # Create mapping from blobs to ops
-    origproto = copy.deepcopy(netproto)
-    blobs_to_ops = collections.defaultdict(lambda: [])
-    blob_input_count = collections.defaultdict(lambda: 0)
-    op_inputs = collections.defaultdict(lambda: 0)
-    op_visit_count = collections.defaultdict(lambda: 0)
-    share_counts = collections.defaultdict(lambda: 0)
-    req_tokens = collections.defaultdict(lambda: set())
-    op_token_deposit = [set() for _ in ops]
-
-    blob_sizes = {} if blob_shapes is not None else None
-
-    # First figure out which of the shareable blobs
-    # are 'internal' to the optimization. For example, if optimizing
-    # only gradient ops, then activation blobs will be 'external' as they
-    # are not output by these ops.
-    optim_op_outputs = set()
-    for op in ops:
-        optim_op_outputs.update(set(op.output))
-
-    blob_seen = collections.defaultdict(lambda: 0)
-    for i, op in enumerate(ops):
-        for inp in op.input:
-            if is_shareable(inp) or inp in heads:
-                if inp in optim_op_outputs:
-                    blobs_to_ops[inp].append(i)
-                    assert blob_seen[inp] > 0, \
-                        "Input {} was not output by an op before".format(inp)
-                    op_inputs[i] += blob_seen[inp]
-                else:
-                    # For external blobs, we don't increase the op_inputs
-                    # count.
-                    blobs_to_ops[inp].append(i)
-                    share_counts[inp] = 1
-        for outp in op.output:
-            blob_seen[outp] += 1
-
-    output_blobs = set()
-    mapping = {}
-    unknown_shapes = set()
-
-    # Helper function to return blob size based on shape inference.
-    # If we don't have shape inference available, return 0.
-    def infer_blob_size(b):
-        if b in blob_shapes:
-            return np.prod(blob_shapes[b])
-        else:
-            unknown_shapes.add(b)
-        return 0
-
-    global token_seq
-    token_seq = 0
-
-    # Creates a next "token". Tokens are used to to keep track of
-    # dependendencies: a blob can be replaced by another only if that
-    # blob "holds" all tokens currently in scope.
-    def next_token():
-        global token_seq
-        token_seq += 1
-        return token_seq
-
-    saved_count = 0
-
-    # Main recursive function. We start recursion from the "heads" and
-    # only descend on an operator when all its inputs have been 'satisfied'.
-    # That is, all parent operators have been visited.
-    def descend(op_idx, free_blobs, tokens):
-        # Check if there are tokens left at this operator from a
-        # parent operator.
-        tokens = tokens.union(op_token_deposit[op_idx])
-        op_token_deposit[op_idx] = None
-        cur_op = ops[op_idx]
-
-        # new_free_blobs contains the blobs that we will release after
-        # visiting this op
-        new_free_blobs = set()
-        saved = 0
-
-        # Update the tokens assigned to blobs to be union of the
-        # tokens we are currently holding and the tokens already held
-        # by that blob.
-        for b in list(cur_op.input) + list(cur_op.output):
-            actual_blob = b if b not in mapping else mapping[b]
-            req_tokens[b] = req_tokens[b].union(tokens)
-            if actual_blob != b:
-                # This blob has been assigned to another (recycled) blob,
-                # so update the token holdings of the recycled blob.
-                req_tokens[actual_blob] = req_tokens[actual_blob].union(tokens)
-
-        # Check each input and increment the counters for each of the input
-        # blobs.
-        for inp in cur_op.input:
-            if is_shareable(inp):
-                blob_input_count[inp] += 1
-                if blob_input_count[inp] == len(blobs_to_ops[inp]):
-                    # This input blob has been now consumed, so we
-                    # can release it to be recycled. If it was replaced
-                    # by another recycled blob, release the recycled blob
-                    # instead.
-                    actual_blob = inp if inp not in mapping else mapping[inp]
-                    if actual_blob not in dont_share_blobs:
-                        new_free_blobs.add(
-                            (-share_counts[actual_blob], actual_blob),
-                        )
-
-        def can_be_used(blob, cur_tokens):
-            # Do we have all required tokens, and this one
-            # was not released in this op?
-            for (_cnt, b) in new_free_blobs:
-                if b == blob:
-                    return False
-            return len(req_tokens[blob] - cur_tokens) == 0
-
-        # Check each output to see if we see the output the first time (i.e
-        # it is created by this op). if it is then, we can replace it with
-        # a recycled blob, if available.
-        for outp in cur_op.output:
-            if is_shareable(outp):
-                if outp not in output_blobs:
-                    # First seen this blob as output, can assign to a free blob
-                    freeb = None
-
-                    # We have two algorithms for choosing the blob to replace
-                    # this one. One that uses size information and another
-                    # that uses a priority queue that prefers blobs that are
-                    # have been shared before.
-                    if blob_sizes is None:
-                        put_back = []
-                        while len(free_blobs) > 0:
-                            (negcnt, cand_freeb) = heapq.heappop(free_blobs)
-                            if can_be_used(cand_freeb, tokens):
-                                freeb = cand_freeb
-                                break
-                            else:
-                                put_back.append((negcnt, cand_freeb))
-                        for cnt, b in put_back:
-                            heapq.heappush(free_blobs, (cnt, b))
-                    else:
-                        bsize = infer_blob_size(outp)
-                        best_blob = None
-                        best_size = -1
-
-                        # Heuristic to choose the most suitably sized blob
-                        for b in free_blobs:
-                            if can_be_used(b, tokens):
-                                sz = blob_sizes[b]
-                                if sz >= best_size:
-                                    if best_size < bsize or best_size >= sz:
-                                        best_size = sz
-                                        best_blob = b
-
-                        freeb = best_blob
-
-                        if freeb is not None:
-                            free_blobs.remove(freeb)
-                            saved += bsize
-
-                    # "freeb" is the blob output to be replaced with. We
-                    # update its tokens to include the tokens being held
-                    # now.
-                    if freeb is not None:
-                        req_tokens[freeb] = req_tokens[freeb].union(tokens)
-                        mapping[outp] = freeb
-                        share_counts[freeb] += 1
-
-                output_blobs.add(outp)
-
-        # Process blobs released during this op visit. Depending
-        # on whether we have blob sizes or not, we store the list
-        # of free blobs differently (NOTE: this should be unified).
-        for (cnt, nf) in new_free_blobs:
-            already_inserted = False
-            # Note: we prevent double insertion, but it can
-            # happen because of parallel branches. Token management
-            # ensures free blobs are handled correctly.
-            if blob_sizes is None:
-                for _c, b in free_blobs:
-                    if b == nf:
-                        already_inserted = True
-                if not already_inserted:
-                    heapq.heappush(free_blobs, (cnt, nf))
-            else:
-                if nf not in blob_sizes:
-                    blob_sizes[nf] = infer_blob_size(outp)
-                if nf in free_blobs:
-                    already_inserted = True
-                if not already_inserted:
-                    free_blobs.append(nf)
-
-        num_branches = 0
-        # Count branches
-        for outp in cur_op.output:
-            for _ in blobs_to_ops[outp]:
-                num_branches += 1
-
-        # Here we process each output again and see if we can descend
-        # down the operator graph.
-        for outp in cur_op.output:
-            for inp_op_idx in blobs_to_ops[outp]:
-                op_visit_count[inp_op_idx] += 1
-
-                # Descend only if we have satisfied all inputs
-                if op_visit_count[inp_op_idx] == op_inputs[inp_op_idx]:
-                    assert inp_op_idx != op_idx
-                    new_tokens = tokens
-                    if num_branches > 1:
-                        # Optimization
-                        new_tokens = tokens.union(set([next_token()]))
-                    saved_desc = descend(
-                        inp_op_idx,
-                        free_blobs,
-                        new_tokens,
-                    )
-                    saved += saved_desc
-
-                else:
-                    # Leave my tokens here so that they can be grabbed
-                    # when we visit the operator (after all inputs have been
-                    # satisfied).
-                    if op_token_deposit[inp_op_idx] is not None:
-                        op_token_deposit[inp_op_idx] = \
-                            op_token_deposit[inp_op_idx].union(tokens)
-
-        return saved
-
-    # Start DFS from the heads' (losses or inputs)
-    for head_blob in heads:
-        for op_idx in blobs_to_ops[head_blob]:
-            if op_token_deposit[op_idx] is not None:
-                saved = descend(op_idx, [], set([next_token()]))
-                saved_count += saved
-
-    # Rename the shared blobs
-    shared_blobs = set(viewvalues(mapping))
-    renamed = {}
-    for j, b in enumerate(shared_blobs):
-        if b in optim_op_outputs:
-            renamed[b] = namescope + "__m{}_shared".format(j)
-        else:
-            renamed[b] = b
-
-    # Update the mapping recursively
-    mapping.update(renamed)
-    had_changes = True
-    while had_changes:
-        had_changes = False
-        for k, v in mapping.items():
-            if v in renamed and renamed[v] != v:
-                renamed[k] = renamed[v]
-                mapping[k] = renamed[k]
-                had_changes = True
-
-    shared_blobs = set(mapping.values())
-
-    if saved_count > 0:
-        log.info("Remapping {} blobs, using {} shared; saved apprx {} MB".format(
-            len(mapping), len(shared_blobs), int(saved_count * 4 / 1024 / 1024),
-        ))
-        log.info("Could not infer sizes for: {}".format(unknown_shapes))
-    else:
-        log.info("Remapping {} blobs, using {} shared".format(
-            len(mapping), len(shared_blobs),
-        ))
-
-    apply_assignments(netproto, mapping)
     log.info("Memonger memory optimization took {} secs".format(
         time.time() - start_time),
     )
-    assert verify_graph_equality(origproto, netproto), \
-        "Memonger graph is not equal to original."
-    assert verify_inplace_blobs(origproto, netproto), \
-        "Inplace assignments differ in memonger net."
-    return netproto
 
+    optim = caffe2_pb2.NetDef()
+    optim.ParseFromString(optim_str)
+
+    assert verify_graph_equality(net.Proto(), optim), \
+        "Memonger graph is not equal to original."
+    assert verify_inplace_blobs(net.Proto(), optim), \
+        "Inplace assignments differ in memonger net."
+    return optim
 
 def _find_source_nodes(g):
     ''' Return nodes without predecessors '''
@@ -927,7 +684,8 @@
 def optimize_inference_fast(net, static_blobs):
     optim = caffe2_pb2.NetDef()
     optim_str = C.memonger_optimize_inference_net(
-        net.SerializeToString(), [str(s).encode('utf-8') for s in static_blobs]
+        net.SerializeToString(),
+        [str(s).encode('utf-8') for s in static_blobs]
     )
     optim.ParseFromString(optim_str)
     return optim
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index d8f7b08..92d6881 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -907,9 +907,35 @@
         return py::bytes(protob);
       });
   m.def(
+      "memonger_compute_blob_recycling_for_dag",
+      [](const py::bytes& net_def,
+         const std::vector<string>& input_blobs,
+         const std::vector<int>& op_indices,
+         const std::unordered_set<string>& shareable_blob_names,
+         const string& namescope,
+         const std::unordered_set<string>& dont_share_blob_names,
+         const std::unordered_map<string, vector<int>>& blob_shapes) {
+        py::gil_scoped_release g;
+        NetDef net;
+        CAFFE_ENFORCE(
+            ParseProtobufFromLargeString(net_def.cast<std::string>(), &net));
+        NetDef optimized_proto =
+            caffe2::memonger::compute_blob_recycling_for_dag(
+                net,
+                input_blobs,
+                op_indices,
+                shareable_blob_names,
+                namescope,
+                dont_share_blob_names,
+                blob_shapes);
+        std::string protob;
+        CAFFE_ENFORCE(optimized_proto.SerializeToString(&protob));
+        return py::bytes(protob);
+      });
+  m.def(
       "memonger_optimize_inference_net",
       [](const py::bytes& net_def,
-         const std::vector<std::string> static_blobs) {
+         const std::vector<std::string>& static_blobs) {
         NetDef def;
         CAFFE_ENFORCE(
             ParseProtobufFromLargeString(net_def.cast<std::string>(), &def));