Lazily initialize alias db in remove_mutation opt (#55949)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55949
Test Plan: Imported from OSS
Reviewed By: bertmaher
Differential Revision: D27793881
fbshipit-source-id: eebde5b5142d8fecfee4756604d313b0da809882
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index c5d9bba..1f7a7b3 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -207,7 +207,7 @@
writeRegistry_ = nullptr;
// initialize the write cache
- writtenToLocationsIndex_ = buildWrittenToLocationsIndex();
+ buildWrittenToLocationsIndex();
GRAPH_DEBUG(toString());
}
@@ -1694,13 +1694,13 @@
return *maybe_wildcardElement;
}
-MemoryLocations AliasDb::buildWrittenToLocationsIndex() const {
+void AliasDb::buildWrittenToLocationsIndex() {
MemoryLocations ret;
for (const auto& pr : *writeIndex_) {
const auto& writtenLocs = pr.second;
ret |= writtenLocs;
}
- return ret;
+ writtenToLocationsIndex_ = ret;
}
void Lint(const AliasDb* db) {
diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h
index 3e90836..cd888ad 100644
--- a/torch/csrc/jit/ir/alias_analysis.h
+++ b/torch/csrc/jit/ir/alias_analysis.h
@@ -269,7 +269,7 @@
c10::optional<TWriteIndex> writeIndex_;
// Collection of all memory locations that are written to.
c10::optional<MemoryLocations> writtenToLocationsIndex_;
- MemoryLocations buildWrittenToLocationsIndex() const;
+ void buildWrittenToLocationsIndex();
std::unordered_set<const Value*> wildcards_;
diff --git a/torch/csrc/jit/passes/remove_mutation.cpp b/torch/csrc/jit/passes/remove_mutation.cpp
index 33ec923..e562f82 100644
--- a/torch/csrc/jit/passes/remove_mutation.cpp
+++ b/torch/csrc/jit/passes/remove_mutation.cpp
@@ -21,7 +21,7 @@
// if the output isn't contained or alias by the inputs to its node, it's
// unique
return !unhandled_node &&
- !aliasDb_->mayContainAlias(v->node()->inputs(), v) &&
+ !getOrCreateAliasDb()->mayContainAlias(v->node()->inputs(), v) &&
!(v->node()->kind() == prim::Param);
}
@@ -86,7 +86,7 @@
// In order to safely remove a mutation, the creation of a tensor and its
// subsequent mutation need to be one atomic operation
- return aliasDb_->moveBeforeTopologicallyValid(
+ return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
mutated_value->node(), mutating_op);
}
@@ -119,7 +119,8 @@
return false;
}
- return aliasDb_->moveBeforeTopologicallyValid(if_node, mutating_op);
+ return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
+ if_node, mutating_op);
}
bool MutationRemover::RemoveListMutation(Block* block) {
@@ -175,14 +176,13 @@
bool has_output = (node->outputs().size() > 0);
if (has_output) {
node->output()->replaceAllUsesWith(mutated_value);
- aliasDb_->writeIndex_->erase(node);
+ getOrCreateAliasDb()->writeIndex_->erase(node);
}
node->destroy();
// TODO: don't strictly need to reset write cache, evaluate on models
- aliasDb_->writtenToLocationsIndex_ =
- aliasDb_->buildWrittenToLocationsIndex();
+ getOrCreateAliasDb()->buildWrittenToLocationsIndex();
}
return changed;
@@ -254,21 +254,21 @@
// same aliasing relationships as the original x.
// To avoid rebuilding the entire alias db, we can replace
// the memory dag element of x with x0.
- aliasDb_->replaceWithNewValue(mutated_value, new_node->output());
+ getOrCreateAliasDb()->replaceWithNewValue(
+ mutated_value, new_node->output());
// it is an invariant that all mutable types have an element in the memory
// dag so we must regive x an alias db element. We have already verified
// that the mutated value is a fresh alias with a single use.
- aliasDb_->createValue(mutated_value);
+ getOrCreateAliasDb()->createValue(mutated_value);
// We must erase the destroyed node from the AliasDb lists of writes
- aliasDb_->writeIndex_->erase(node);
+ getOrCreateAliasDb()->writeIndex_->erase(node);
node->destroy();
// now that we have removed a mutating op, the write cache is stale
// TODO: don't strictly need to reset write cache, evaluate on models
- aliasDb_->writtenToLocationsIndex_ =
- aliasDb_->buildWrittenToLocationsIndex();
+ getOrCreateAliasDb()->buildWrittenToLocationsIndex();
}
return changed;
diff --git a/torch/csrc/jit/passes/remove_mutation.h b/torch/csrc/jit/passes/remove_mutation.h
index 7488968..0e69f36 100644
--- a/torch/csrc/jit/passes/remove_mutation.h
+++ b/torch/csrc/jit/passes/remove_mutation.h
@@ -14,7 +14,6 @@
std::shared_ptr<Graph> graph,
c10::optional<std::function<bool(Node*)>> mutation_filter = c10::nullopt)
: aliasDb_(nullptr), graph_(std::move(graph)) {
- aliasDb_ = torch::make_unique<AliasDb>(graph_);
mutation_filter_ = mutation_filter;
}
@@ -63,8 +62,8 @@
return false;
}
auto inputs = n->inputs();
- if (!aliasDb_->writesToAlias(n, {inputs.at(0)}) ||
- aliasDb_->writesToAlias(
+ if (!getOrCreateAliasDb()->writesToAlias(n, {inputs.at(0)}) ||
+ getOrCreateAliasDb()->writesToAlias(
n, {inputs.slice(1).begin(), inputs.slice(1).end()})) {
return false;
}
@@ -88,6 +87,13 @@
// return true if graph is modified
bool RemoveTensorMutation(Block* block);
+ AliasDb* getOrCreateAliasDb() {
+ if (!aliasDb_) {
+ aliasDb_ = std::make_unique<AliasDb>(graph_);
+ }
+ return aliasDb_.get();
+ }
+
c10::optional<std::function<bool(Node*)>> mutation_filter_;
std::unique_ptr<AliasDb> aliasDb_ = nullptr;
std::shared_ptr<Graph> graph_;