cleanups to alias analysis (#21221)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21221
ghimport-source-id: 778e7317bbe874d35a903d89af5e0bc9721c8680
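This cleans up alias_analysis.cpp: AT_ASSERT is replaced with TORCH_INTERNAL_ASSERT (with messages added to the less obvious checks), the special-cased vararg/varret error is removed, wildcard annotations are now asserted to be the only alias set on a formal, and prim::Print is moved out of the purposefully_not_handled set and into the ops handled by special-case analysis.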
Reviewed By: jamesr66a
Differential Revision: D15592313
Pulled By: suo
fbshipit-source-id: d6f6d2be8cd80b40dd26d0bb3be30f074e356105
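The main mechanical change is the assertion macro: TORCH_INTERNAL_ASSERT accepts an optional message after the condition, so a failure reports the violated invariant rather than just the stringified expression. A minimal sketch of the before/after pattern, taken from one of the checks in the diff below:

    // Before: condition only; a failure prints just the expression.
    AT_ASSERT(formal->containedTypes().size() == 0);

    // After: same condition, plus a message stating the invariant.
    TORCH_INTERNAL_ASSERT(
        formal->containedTypes().size() == 0,
        "Composite types for alias analysis not yet supported");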
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index 024383f..569941a 100644
--- a/torch/csrc/jit/passes/alias_analysis.cpp
+++ b/torch/csrc/jit/passes/alias_analysis.cpp
@@ -338,25 +338,10 @@
if (tryRegisteredAnalysis(node)) {
return;
}
- AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
+ TORCH_INTERNAL_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
}
const auto& schema = node->schema();
- if (schema.is_vararg() || schema.is_varret()) {
- const auto hasMutableOutputs = std::any_of(
- node->outputs().cbegin(),
- node->outputs().cend(),
- [](const Value* output) { return shouldAnnotate(output); });
-
- // We don't have alias info for this node. Either schematize it, or
- // add it an analyze* method for it.
- if (hasMutableOutputs) {
- throw script::ErrorReport(node->sourceRange())
- << "Alias information not found for node. File a bug report.\n"
- << "Node: " << *node << "\n";
- }
- }
-
// see [custom operator aliasing]
if (!node->kind().is_aten() && !node->kind().is_prim()) {
return analyzeConservative(node);
@@ -379,10 +364,12 @@
}
// Do sanity checks on the alias annotation
- // - We don't support composite types for alias analysis yet.
- AT_ASSERT(formal->containedTypes().size() == 0);
- // - Doesn't make sense for a value to start annotated as a wildcard.
- AT_ASSERT(!formal->isWildcardBefore());
+ TORCH_INTERNAL_ASSERT(
+ formal->containedTypes().size() == 0,
+ "Composite types for alias analysis not yet supported");
+ TORCH_INTERNAL_ASSERT(
+ !formal->isWildcardBefore(),
+ "Doesn't make sense for a input value to begin as a wildcard");
const auto& formalAlias = formal->beforeSet();
@@ -401,11 +388,15 @@
// Now deal with sets after the '->'
if (formal->isWildcardAfter()) {
+ TORCH_INTERNAL_ASSERT(
+ formal->afterSets().size() == 1,
+ "If the after set contains a wildcard, "
+ "there should be no other alias sets specified.");
setWildcard(actualValue);
} else {
// We don't understand anything else in the after yet, so assert there's
// been no change.
- AT_ASSERT(formal->beforeSets() == formal->afterSets());
+ TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
}
}
@@ -424,10 +415,16 @@
continue;
}
- // We don't support composite types for alias analysis yet.
- AT_ASSERT(formal->containedTypes().size() == 0);
+ TORCH_INTERNAL_ASSERT(
+ formal->containedTypes().size() == 0,
+ "Composite types for alias analysis not yet supported");
+ TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
- if (formal->isWildcardBefore() || formal->isWildcardAfter()) {
+ if (formal->isWildcardBefore()) {
+ TORCH_INTERNAL_ASSERT(
+ formal->beforeSets().size() == 1,
+ "If an output is a wildcard, "
+ "there should be no other alias sets specified.");
setWildcard(actual);
continue;
}
@@ -493,8 +490,8 @@
const auto loopCarriedInputs = node->inputs().slice(2); // skip max, cond
const auto blockInputs = bodyBlock->inputs().slice(1); // skip trip
const auto blockOutputs = bodyBlock->outputs().slice(1); // skip trip
- AT_ASSERT(loopCarriedInputs.size() == blockInputs.size());
- AT_ASSERT(blockOutputs.size() == node->outputs().size());
+ TORCH_INTERNAL_ASSERT(loopCarriedInputs.size() == blockInputs.size());
+ TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());
// Run alias analysis on the loop body, iterating until the block output
// alias info converges.
@@ -521,10 +518,11 @@
analyze(subgraphBlock);
- // TODO(suo): the subgraph outputs and node outputs are NOT NECESSARILY the
+ // Note: the subgraph outputs and node outputs are NOT NECESSARILY the
// same length. Autodifferentiation may capture additional outputs in the
// subgraph block.
- AT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size());
+ TORCH_INTERNAL_ASSERT(
+ subgraphBlock->outputs().size() >= node->outputs().size());
for (size_t i = 0; i < node->outputs().size(); i++) {
makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
}
@@ -600,7 +598,7 @@
// SetAttr: writes to the `self` field
void AliasDb::analyzeSetAttr(Node* node) {
const auto self = node->inputs().at(0);
- AT_ASSERT(self->type()->kind() == TypeKind::ClassType);
+ TORCH_INTERNAL_ASSERT(self->type()->kind() == TypeKind::ClassType);
registerWrite(self, node);
// Also the value being set must become a wildcard.
const auto newValue = node->inputs().at(1);
@@ -615,8 +613,6 @@
setWildcard(input);
}
- // TODO(suo): we can make the more refined assumption that outputs may only
- // alias any input.
for (const auto output : node->outputs()) {
setWildcard(output);
}
@@ -627,7 +623,7 @@
// TODO: tuples are treated differently since we actually compare the contained
// values for aliasing, so we don't need wildcards.
void AliasDb::analyzeContainerConstruct(Node* node) {
- AT_ASSERT(
+ TORCH_INTERNAL_ASSERT(
node->kind() == prim::ListConstruct ||
node->kind() == prim::DictConstruct);
@@ -658,7 +654,7 @@
// Register the fact that `from` is a pointer to `to`
void AliasDb::makePointerTo(const Value* from, const Value* to) {
if (!shouldAnnotate(from)) {
- AT_ASSERT(!shouldAnnotate(to));
+ TORCH_INTERNAL_ASSERT(!shouldAnnotate(to));
return;
}
@@ -674,7 +670,7 @@
}
// At this point, we should be dealing with two mutable types.
- AT_ASSERT(shouldAnnotate(from) && shouldAnnotate(to));
+ TORCH_INTERNAL_ASSERT(shouldAnnotate(from) && shouldAnnotate(to));
auto fromEl = getOrCreateElement(from);
auto toEl = getOrCreateElement(to);
@@ -689,7 +685,7 @@
return;
}
- AT_ASSERT(isContainerType(container->type()));
+ TORCH_INTERNAL_ASSERT(isContainerType(container->type()));
auto elemEl = getOrCreateElement(elem);
auto contEl = getOrCreateElement(container);
@@ -757,7 +753,7 @@
// Make each value in the `from` list point to its partner in the `to` list
void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
- AT_ASSERT(to.size() == from.size());
+ TORCH_INTERNAL_ASSERT(to.size() == from.size());
for (size_t i = 0; i < to.size(); i++) {
makePointerTo(from[i], to[i]);
}
@@ -928,7 +924,7 @@
// outside), then return nullptr. Since we can only reorder nodes within a
// block, `target` would be irrelevant.
static Node* findSameBlock(Node* target, Node* n) {
- AT_ASSERT(target->owningGraph() == n->owningGraph());
+ TORCH_INTERNAL_ASSERT(target->owningGraph() == n->owningGraph());
if (target->owningBlock() == n->owningBlock()) {
return target;
} else {
@@ -969,7 +965,7 @@
Node* movePoint,
MoveSide moveSide,
bool dryRun) {
- AT_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
+ TORCH_INTERNAL_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
if (toMove == movePoint) {
return true;
}
@@ -1033,7 +1029,7 @@
}
// 3. Execute the move
- AT_ASSERT(curNode == movePoint);
+ TORCH_INTERNAL_ASSERT(curNode == movePoint);
if (splitToMoveAndDeps) {
// Move `toMove`
move(toMove, movePoint, moveSide);
@@ -1113,12 +1109,12 @@
prim::GetAttr,
prim::SetAttr,
prim::profile,
+ prim::Print,
aten::wait,
};
// Operators that should not be used by alias analysis
const static std::unordered_set<Symbol> purposefully_not_handled = {
- prim::Print,
prim::Load,
prim::Store,
prim::Drop,