Lazily initialize alias db in remove_mutation opt (#55949)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55949

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D27793881

fbshipit-source-id: eebde5b5142d8fecfee4756604d313b0da809882
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index c5d9bba..1f7a7b3 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -207,7 +207,7 @@
   writeRegistry_ = nullptr;
 
   // initialize the write cache
-  writtenToLocationsIndex_ = buildWrittenToLocationsIndex();
+  buildWrittenToLocationsIndex();
   GRAPH_DEBUG(toString());
 }
 
@@ -1694,13 +1694,13 @@
   return *maybe_wildcardElement;
 }
 
-MemoryLocations AliasDb::buildWrittenToLocationsIndex() const {
+void AliasDb::buildWrittenToLocationsIndex() {
   MemoryLocations ret;
   for (const auto& pr : *writeIndex_) {
     const auto& writtenLocs = pr.second;
     ret |= writtenLocs;
   }
-  return ret;
+  writtenToLocationsIndex_ = ret;
 }
 
 void Lint(const AliasDb* db) {
diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h
index 3e90836..cd888ad 100644
--- a/torch/csrc/jit/ir/alias_analysis.h
+++ b/torch/csrc/jit/ir/alias_analysis.h
@@ -269,7 +269,7 @@
   c10::optional<TWriteIndex> writeIndex_;
   // Collection of all memory locations that are written to.
   c10::optional<MemoryLocations> writtenToLocationsIndex_;
-  MemoryLocations buildWrittenToLocationsIndex() const;
+  void buildWrittenToLocationsIndex();
 
   std::unordered_set<const Value*> wildcards_;
 
diff --git a/torch/csrc/jit/passes/remove_mutation.cpp b/torch/csrc/jit/passes/remove_mutation.cpp
index 33ec923..e562f82 100644
--- a/torch/csrc/jit/passes/remove_mutation.cpp
+++ b/torch/csrc/jit/passes/remove_mutation.cpp
@@ -21,7 +21,7 @@
   // if the output isn't contained or alias by the inputs to its node, it's
   // unique
   return !unhandled_node &&
-      !aliasDb_->mayContainAlias(v->node()->inputs(), v) &&
+      !getOrCreateAliasDb()->mayContainAlias(v->node()->inputs(), v) &&
       !(v->node()->kind() == prim::Param);
 }
 
@@ -86,7 +86,7 @@
 
   // In order to safely remove a mutation, the creation of a tensor and its
   // subsequent mutation need to be one atomic operation
-  return aliasDb_->moveBeforeTopologicallyValid(
+  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
       mutated_value->node(), mutating_op);
 }
 
@@ -119,7 +119,8 @@
     return false;
   }
 
-  return aliasDb_->moveBeforeTopologicallyValid(if_node, mutating_op);
+  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
+      if_node, mutating_op);
 }
 
 bool MutationRemover::RemoveListMutation(Block* block) {
@@ -175,14 +176,13 @@
     bool has_output = (node->outputs().size() > 0);
     if (has_output) {
       node->output()->replaceAllUsesWith(mutated_value);
-      aliasDb_->writeIndex_->erase(node);
+      getOrCreateAliasDb()->writeIndex_->erase(node);
     }
 
     node->destroy();
 
     // TODO: don't strictly need to reset write cache, evaluate on models
-    aliasDb_->writtenToLocationsIndex_ =
-        aliasDb_->buildWrittenToLocationsIndex();
+    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
   }
 
   return changed;
@@ -254,21 +254,21 @@
     // same aliasing relationships as the original x.
     // To avoid rebuilding the entire alias db, we can replace
     // the memory dag element of x with x0.
-    aliasDb_->replaceWithNewValue(mutated_value, new_node->output());
+    getOrCreateAliasDb()->replaceWithNewValue(
+        mutated_value, new_node->output());
 
     // it is an invariant that all mutable types have an element in the memory
     // dag so we must regive x an alias db element. We have already verified
     // that the mutated value is a fresh alias with a single use.
-    aliasDb_->createValue(mutated_value);
+    getOrCreateAliasDb()->createValue(mutated_value);
 
     // We must erase the destroyed node from the AliasDb lists of writes
-    aliasDb_->writeIndex_->erase(node);
+    getOrCreateAliasDb()->writeIndex_->erase(node);
     node->destroy();
 
     // now that we have removed a mutating op, the write cache is stale
     // TODO: don't strictly need to reset write cache, evaluate on models
-    aliasDb_->writtenToLocationsIndex_ =
-        aliasDb_->buildWrittenToLocationsIndex();
+    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
   }
 
   return changed;
diff --git a/torch/csrc/jit/passes/remove_mutation.h b/torch/csrc/jit/passes/remove_mutation.h
index 7488968..0e69f36 100644
--- a/torch/csrc/jit/passes/remove_mutation.h
+++ b/torch/csrc/jit/passes/remove_mutation.h
@@ -14,7 +14,6 @@
       std::shared_ptr<Graph> graph,
       c10::optional<std::function<bool(Node*)>> mutation_filter = c10::nullopt)
       : aliasDb_(nullptr), graph_(std::move(graph)) {
-    aliasDb_ = torch::make_unique<AliasDb>(graph_);
     mutation_filter_ = mutation_filter;
   }
 
@@ -63,8 +62,8 @@
       return false;
     }
     auto inputs = n->inputs();
-    if (!aliasDb_->writesToAlias(n, {inputs.at(0)}) ||
-        aliasDb_->writesToAlias(
+    if (!getOrCreateAliasDb()->writesToAlias(n, {inputs.at(0)}) ||
+        getOrCreateAliasDb()->writesToAlias(
             n, {inputs.slice(1).begin(), inputs.slice(1).end()})) {
       return false;
     }
@@ -88,6 +87,13 @@
   // return true if graph is modified
   bool RemoveTensorMutation(Block* block);
 
+  AliasDb* getOrCreateAliasDb() {
+    if (!aliasDb_) {
+      aliasDb_ = std::make_unique<AliasDb>(graph_);
+    }
+    return aliasDb_.get();
+  }
+
   c10::optional<std::function<bool(Node*)>> mutation_filter_;
   std::unique_ptr<AliasDb> aliasDb_ = nullptr;
   std::shared_ptr<Graph> graph_;