cleanups to alias analysis (#21221)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21221
ghimport-source-id: 778e7317bbe874d35a903d89af5e0bc9721c8680

Reviewed By: jamesr66a

Differential Revision: D15592313

Pulled By: suo

fbshipit-source-id: d6f6d2be8cd80b40dd26d0bb3be30f074e356105
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index 024383f..569941a 100644
--- a/torch/csrc/jit/passes/alias_analysis.cpp
+++ b/torch/csrc/jit/passes/alias_analysis.cpp
@@ -338,25 +338,10 @@
       if (tryRegisteredAnalysis(node)) {
         return;
       }
-      AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
+      TORCH_INTERNAL_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
   }
 
   const auto& schema = node->schema();
-  if (schema.is_vararg() || schema.is_varret()) {
-    const auto hasMutableOutputs = std::any_of(
-        node->outputs().cbegin(),
-        node->outputs().cend(),
-        [](const Value* output) { return shouldAnnotate(output); });
-
-    // We don't have alias info for this node. Either schematize it, or
-    // add it an analyze* method for it.
-    if (hasMutableOutputs) {
-      throw script::ErrorReport(node->sourceRange())
-          << "Alias information not found for node. File a bug report.\n"
-          << "Node: " << *node << "\n";
-    }
-  }
-
   // see [custom operator aliasing]
   if (!node->kind().is_aten() && !node->kind().is_prim()) {
     return analyzeConservative(node);
@@ -379,10 +364,12 @@
     }
 
     // Do sanity checks on the alias annotation
-    // - We don't support composite types for alias analysis yet.
-    AT_ASSERT(formal->containedTypes().size() == 0);
-    // - Doesn't make sense for a value to start annotated as a wildcard.
-    AT_ASSERT(!formal->isWildcardBefore());
+    TORCH_INTERNAL_ASSERT(
+        formal->containedTypes().size() == 0,
+        "Composite types for alias analysis not yet supported");
+    TORCH_INTERNAL_ASSERT(
+        !formal->isWildcardBefore(),
+        "Doesn't make sense for a input value to begin as a wildcard");
 
     const auto& formalAlias = formal->beforeSet();
 
@@ -401,11 +388,15 @@
 
     // Now deal with sets after the '->'
     if (formal->isWildcardAfter()) {
+      TORCH_INTERNAL_ASSERT(
+          formal->afterSets().size() == 1,
+          "If the after set contains a wildcard, "
+          "there should be no other alias sets specified.");
       setWildcard(actualValue);
     } else {
       // We don't understand anything else in the after yet, so assert there's
       // been no change.
-      AT_ASSERT(formal->beforeSets() == formal->afterSets());
+      TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
     }
   }
 
@@ -424,10 +415,16 @@
       continue;
     }
 
-    // We don't support composite types for alias analysis yet.
-    AT_ASSERT(formal->containedTypes().size() == 0);
+    TORCH_INTERNAL_ASSERT(
+        formal->containedTypes().size() == 0,
+        "Composite types for alias analysis not yet supported");
+    TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
 
-    if (formal->isWildcardBefore() || formal->isWildcardAfter()) {
+    if (formal->isWildcardBefore()) {
+      TORCH_INTERNAL_ASSERT(
+          formal->beforeSets().size() == 1,
+          "If an output is a wildcard, "
+          "there should be no other alias sets specified.");
       setWildcard(actual);
       continue;
     }
@@ -493,8 +490,8 @@
   const auto loopCarriedInputs = node->inputs().slice(2); // skip max, cond
   const auto blockInputs = bodyBlock->inputs().slice(1); // skip trip
   const auto blockOutputs = bodyBlock->outputs().slice(1); // skip trip
-  AT_ASSERT(loopCarriedInputs.size() == blockInputs.size());
-  AT_ASSERT(blockOutputs.size() == node->outputs().size());
+  TORCH_INTERNAL_ASSERT(loopCarriedInputs.size() == blockInputs.size());
+  TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());
 
   // Run alias analysis on the loop body, iterating until the block output
   // alias info converges.
@@ -521,10 +518,11 @@
 
   analyze(subgraphBlock);
 
-  // TODO(suo): the subgraph outputs and node outputs are NOT NECESSARILY the
+  // Note: the subgraph outputs and node outputs are NOT NECESSARILY the
   // same length. Autodifferentiation maybe capture additional outputs in the
   // subgraph block.
-  AT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size());
+  TORCH_INTERNAL_ASSERT(
+      subgraphBlock->outputs().size() >= node->outputs().size());
   for (size_t i = 0; i < node->outputs().size(); i++) {
     makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
   }
@@ -600,7 +598,7 @@
 // SetAttr: writes to the `self` field
 void AliasDb::analyzeSetAttr(Node* node) {
   const auto self = node->inputs().at(0);
-  AT_ASSERT(self->type()->kind() == TypeKind::ClassType);
+  TORCH_INTERNAL_ASSERT(self->type()->kind() == TypeKind::ClassType);
   registerWrite(self, node);
   // Also the value being set must become a wildcard.
   const auto newValue = node->inputs().at(1);
@@ -615,8 +613,6 @@
     setWildcard(input);
   }
 
-  // TODO(suo): we can make the more refined assumption that outputs may only
-  // alias any input.
   for (const auto output : node->outputs()) {
     setWildcard(output);
   }
@@ -627,7 +623,7 @@
 // TODO: tuples are treated differently since we actually compare the contained
 // values for aliasing, so we don't need wildcards.
 void AliasDb::analyzeContainerConstruct(Node* node) {
-  AT_ASSERT(
+  TORCH_INTERNAL_ASSERT(
       node->kind() == prim::ListConstruct ||
       node->kind() == prim::DictConstruct);
 
@@ -658,7 +654,7 @@
 // Register the fact that `from` is a pointer to `to`
 void AliasDb::makePointerTo(const Value* from, const Value* to) {
   if (!shouldAnnotate(from)) {
-    AT_ASSERT(!shouldAnnotate(to));
+    TORCH_INTERNAL_ASSERT(!shouldAnnotate(to));
     return;
   }
 
@@ -674,7 +670,7 @@
   }
 
   // At this point, we should be dealing with two mutable types.
-  AT_ASSERT(shouldAnnotate(from) && shouldAnnotate(to));
+  TORCH_INTERNAL_ASSERT(shouldAnnotate(from) && shouldAnnotate(to));
 
   auto fromEl = getOrCreateElement(from);
   auto toEl = getOrCreateElement(to);
@@ -689,7 +685,7 @@
     return;
   }
 
-  AT_ASSERT(isContainerType(container->type()));
+  TORCH_INTERNAL_ASSERT(isContainerType(container->type()));
 
   auto elemEl = getOrCreateElement(elem);
   auto contEl = getOrCreateElement(container);
@@ -757,7 +753,7 @@
 
 // Make each value in the `from` list point to its partner in the `to` list
 void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
-  AT_ASSERT(to.size() == from.size());
+  TORCH_INTERNAL_ASSERT(to.size() == from.size());
   for (size_t i = 0; i < to.size(); i++) {
     makePointerTo(from[i], to[i]);
   }
@@ -928,7 +924,7 @@
   // outside), then return nullptr. Since we can only reorder nodes within a
   // block, `target` would be irrelevant.
   static Node* findSameBlock(Node* target, Node* n) {
-    AT_ASSERT(target->owningGraph() == n->owningGraph());
+    TORCH_INTERNAL_ASSERT(target->owningGraph() == n->owningGraph());
     if (target->owningBlock() == n->owningBlock()) {
       return target;
     } else {
@@ -969,7 +965,7 @@
     Node* movePoint,
     MoveSide moveSide,
     bool dryRun) {
-  AT_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
+  TORCH_INTERNAL_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
   if (toMove == movePoint) {
     return true;
   }
@@ -1033,7 +1029,7 @@
   }
 
   // 3. Execute the move
-  AT_ASSERT(curNode == movePoint);
+  TORCH_INTERNAL_ASSERT(curNode == movePoint);
   if (splitToMoveAndDeps) {
     // Move `toMove`
     move(toMove, movePoint, moveSide);
@@ -1113,12 +1109,12 @@
       prim::GetAttr,
       prim::SetAttr,
       prim::profile,
+      prim::Print,
       aten::wait,
   };
 
   // Operators that should not be used by alias analysis
   const static std::unordered_set<Symbol> purposefully_not_handled = {
-      prim::Print,
       prim::Load,
       prim::Store,
       prim::Drop,