[JIT] Functional Graph Pass (#33020)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33020

This is a pass to create functional blocks. The other PRs in the stack help avoid some of the limitations that are often found in graphs. It's possible that this would work well with a graph that is frozen. Follow-up work items that will help this pass:

- We don't currently have any capacity in alias analysis to tell whether a Value that came from the wildcard set "re-escapes" back into the wildcard set.
- More comments on the semantics of the graph and correctness conditions
- We could consider using dynamic dag if the perf of this is a limitation.
- Potentially make Functional Graphs into Functional Blocks instead, so that we do not repeatedly copy constants and so that the IR is easier to read.

Test Plan: Imported from OSS

Differential Revision: D20603188

Pulled By: eellison

fbshipit-source-id: 6822a6e65f4cc2676f8f6445fe8aa1cb858ebeeb
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 32cdffc..2a325c9 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -34,6 +34,7 @@
   _(prim, Expand) /* onnx */         \
   _(prim, FusionGroup)               \
   _(prim, CudaFusionGroup)           \
+  _(prim, FunctionalGraph)           \
   _(prim, DifferentiableGraph)       \
   _(prim, If)                        \
   _(prim, Jump) /* debug */          \
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index b9dec0b..85cb135 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -399,6 +399,7 @@
     ${TORCH_SRC_DIR}/csrc/jit/passes/constant_pooling.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/common_subexpression_elimination.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/create_autodiff_subgraphs.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/passes/create_functional_graphs.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/inline_autodiff_subgraphs.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/insert_guards.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/inliner.cpp
diff --git a/test/jit/test_functional_blocks.py b/test/jit/test_functional_blocks.py
new file mode 100644
index 0000000..cb6b369
--- /dev/null
+++ b/test/jit/test_functional_blocks.py
@@ -0,0 +1,41 @@
+import os
+import sys
+
+import torch
+from torch.testing import FileCheck
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == '__main__':
+    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+                       "\tpython test/test_jit.py TESTNAME\n\n"
+                       "instead.")
+
+class TestFunctionalBlocks(JitTestCase):
+    def test_simple_no_merge(self):
+        def fn(x, y, z):
+            x = x + 1
+            y = y + 1
+            z = z + 1
+            z.add_(2)
+            z = z * z
+            y = y * z
+            if y < 2:
+                y = y + 5
+            return x + y + z
+
+        graph = torch.jit.script(fn).graph
+        self.run_pass('create_functional_graphs', graph)
+
+        # all uses of x and y should be sunk: no mention between the first one and the FunctionalGraph
+        FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(r"%x").run(graph)
+        FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(r"%y").run(graph)
+
+        # Outputs that escape scope are not allowed, so the final add stays right after the FunctionalGraph
+        FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(graph)
+
+        # z + 1, z.add_(2) and z * z are considered non-functional and stay ahead of the FunctionalGraph
+        FileCheck().check("add").check("add_").check("mul").check("FunctionalGraph").run(graph)
diff --git a/test/test_jit.py b/test/test_jit.py
index 762e9a2..88a710b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -23,6 +23,7 @@
 from jit.test_builtins import TestBuiltins, TestTensorBuiltins  # noqa: F401
 from jit.test_unsupported_ops import TestUnsupportedOps  # noqa: F401
 from jit.test_freezing import TestFreezing  # noqa: F401
+from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
 
 # Torch
 from torch import Tensor
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index de4a153..9ce1fb0 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -117,6 +117,7 @@
     "torch/csrc/jit/passes/constant_propagation.cpp",
     "torch/csrc/jit/passes/constant_pooling.cpp",
     "torch/csrc/jit/passes/create_autodiff_subgraphs.cpp",
+    "torch/csrc/jit/passes/create_functional_graphs.cpp",
     "torch/csrc/jit/passes/dead_code_elimination.cpp",
     "torch/csrc/jit/passes/erase_number_types.cpp",
     "torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp",
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 957b254..ef9cb6d 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -327,6 +327,7 @@
       return analyzeLoop(node);
     case prim::FusionGroup:
     case prim::CudaFusionGroup:
+    case prim::FunctionalGraph:
     case prim::DifferentiableGraph:
       return analyzeSubgraph(node);
     case prim::fork:
diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h
index 67b8f95..82b6469 100644
--- a/torch/csrc/jit/ir/alias_analysis.h
+++ b/torch/csrc/jit/ir/alias_analysis.h
@@ -87,6 +87,8 @@
   // reads from.
   TORCH_API bool isMutable(Node* n) const;
 
+  TORCH_API bool escapesScope(const at::ArrayRef<Value*>& vs) const;
+
   // Is it safe to change whether `a` and `b` alias each other ?
   TORCH_API bool safeToChangeAliasingRelationship(
       const at::ArrayRef<Value*>& a,
@@ -148,8 +150,6 @@
   // Is this a value which will not alias
   bool nonAliasingValue(const Value* elem) const;
 
-  bool escapesScope(const at::ArrayRef<Value*>& vs) const;
-
   /**
    * Special analysis methods
    */
diff --git a/torch/csrc/jit/passes/create_functional_graphs.cpp b/torch/csrc/jit/passes/create_functional_graphs.cpp
new file mode 100644
index 0000000..8f76b5d
--- /dev/null
+++ b/torch/csrc/jit/passes/create_functional_graphs.cpp
@@ -0,0 +1,235 @@
+#include <torch/csrc/jit/passes/create_functional_graphs.h>
+#include <c10/util/Exception.h>
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
+#include <torch/csrc/utils/memory.h>
+
+
+#include <cstddef>
+#include <limits>
+
+namespace torch {
+namespace jit {
+
+namespace {
+
+struct FunctionalGraphSlicer {
+  explicit FunctionalGraphSlicer(std::shared_ptr<Graph> graph)
+      : graph_(std::move(graph)) {} // the graph that run() slices
+
+  // Alternates analysis and slicing until a round changes nothing or the limit.
+  void run() {
+    // TODO: more sane strategy
+    const size_t kMaxIterations = 4;
+    bool changed = true;
+    for (size_t i = 0; i < kMaxIterations && changed; ++i) {
+      // Slicing mutates the graph, which invalidates the AliasDb and leaves
+      // stale pointers in the analysis sets, so rebuild all of them per round.
+      functional_nodes_.clear();
+      mutated_values_.clear();
+      aliasDb_ = torch::make_unique<AliasDb>(graph_);
+      AnalyzeFunctionalSubset(graph_->block());
+      changed = CreateFunctionalGraphsImpl(graph_->block());
+    }
+  }
+
+ private:
+  bool isEmptyFunctionalGraph(Node* n) {
+    const auto subgraph = n->g(attr::Subgraph); // subgraph attribute of `n`
+    return subgraph->inputs().empty() && subgraph->outputs().empty();
+  }
+
+  // Count non-constant nodes (recursively) into *num, up to minSubgraphSize_.
+  void nonConstNodes(Block* block, size_t* num) {
+    for (Node* node : block->nodes()) {
+      if (*num >= minSubgraphSize_) {
+        break;
+      }
+      if (node->kind() != prim::Constant) {
+        ++*num;
+        for (Block* nested : node->blocks()) {
+          nonConstNodes(nested, num);
+        }
+      }
+    }
+  }
+
+  // Unmerges `n` (returning true) if it has < minSubgraphSize_ non-const nodes.
+  bool inlineIfTooSmall(Node* n) {
+    AT_ASSERT(n->kind() == prim::FunctionalGraph);
+    size_t num_nodes = 0;
+    nonConstNodes(SubgraphUtils::getSubgraph(n)->block(), &num_nodes);
+    if (num_nodes >= minSubgraphSize_) {
+      return false;
+    }
+    SubgraphUtils::unmergeSubgraph(n);
+    return true;
+  }
+
+  // Walks `block` in reverse and groups its functional nodes into
+  // prim::FunctionalGraph nodes: a functional node is merged into the current
+  // group if AliasDb allows moving it there, otherwise it starts a new group.
+  // Non-functional nodes are skipped, but their nested blocks are processed
+  // recursively. Groups that end up too small are unmerged again. Returns
+  // true if any functional node was grouped, here or in a nested block.
+  bool CreateFunctionalGraphsImpl(Block* block) {
+    bool changed = false;
+    std::vector<Node*> functional_graph_nodes;
+
+    Node* functional_subgraph_node =
+        graph_->createWithSubgraph(prim::FunctionalGraph)
+            ->insertBefore(block->return_node());
+    auto reverse_iter = block->nodes().reverse();
+    for (auto it = reverse_iter.begin(); it != reverse_iter.end();) {
+      Node* n = *it++;
+
+      // constants get copied into the graph
+      if (n->kind() == prim::Constant || n == functional_subgraph_node) {
+        continue;
+      }
+
+      // if `n` is functional, all of its blocks will be merged into the
+      // new functional subgraph, so we only need to recurse if it is not
+      // functional
+      if (!functional_nodes_.count(n)) {
+        for (Block* b : n->blocks()) {
+          // a change in any nested block counts as a change of this block
+          auto block_changed = CreateFunctionalGraphsImpl(b);
+          changed = block_changed || changed;
+        }
+        continue;
+      }
+
+      // adopt an existing FunctionalGraph in place of a still-empty group
+      if (n->kind() == prim::FunctionalGraph &&
+          isEmptyFunctionalGraph(functional_subgraph_node)) {
+        functional_subgraph_node->destroy();
+        functional_subgraph_node = n;
+        continue;
+      }
+
+      changed = true;
+      if (aliasDb_->moveBeforeTopologicallyValid(n, functional_subgraph_node)) {
+        SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
+      } else {
+        functional_graph_nodes.emplace_back(functional_subgraph_node);
+        functional_subgraph_node =
+            graph_->createWithSubgraph(prim::FunctionalGraph)->insertAfter(n);
+        SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
+      }
+    }
+    functional_graph_nodes.emplace_back(functional_subgraph_node);
+
+    for (Node* functional_node : functional_graph_nodes) {
+      if (!inlineIfTooSmall(functional_node)) {
+        ConstantPooling(functional_node->g(attr::Subgraph));
+      }
+    }
+    return changed;
+  }
+
+  // Decides whether `n` is functional (recording it in functional_nodes_ if so)
+  // and notes mutated outputs in mutated_values_; nested blocks are analyzed too.
+  bool AnalyzeFunctionalSubset(Node* n) {
+    // TODO: clarify hasSideEffects, isNondeterministic
+    bool functional = true;
+    // Functional Graphs are not responsible for maintaining aliasing
+    // relationships. If an output of a functional graph escapes scope
+    // or is mutated then we might change semantics of the program if
+    // aliasing relationships are changed.
+    // For now, we don't allow any values which are mutated into the functional
+    // graph, and we don't allow any nodes which have outputs that escape scope.
+    // Possible Future Improvements:
+    // - allow inputs to have mutations so long as there are no mutations in the
+    // graph
+    // - allow functional graphs to have at most one value that can escape scope
+    // - allow outputs which alias the wildcard set but do not "re-escape"
+    for (Value* out : n->outputs()) {
+      const bool mutated = aliasDb_->hasWriters(out);
+      const bool escapes = aliasDb_->escapesScope(out);
+      if (mutated) {
+        mutated_values_.insert(out);
+      }
+      functional = functional && !(mutated || escapes);
+    }
+    for (Block* nested : n->blocks()) {
+      // call first: every nested block is analyzed whatever the verdict so far
+      functional = AnalyzeFunctionalSubset(nested) && functional;
+    }
+    // mutated_values_ already populated with inputs to this node
+    const auto inputs = n->inputs();
+    const bool reads_mutated_value =
+        std::any_of(inputs.begin(), inputs.end(), [&](Value* v) {
+          return mutated_values_.count(v) != 0;
+        });
+    functional = functional && !reads_mutated_value;
+    if (functional) {
+      functional_nodes_.insert(n);
+    }
+    return functional;
+  }
+
+  void AnalyzeFunctionalSubset(at::ArrayRef<Block*> blocks) {
+    for (Block* nested : blocks) {
+      AnalyzeFunctionalSubset(nested); // per-block verdict is not needed here
+    }
+  }
+
+  // Analyzes every node of `block`; returns true only if all are functional.
+  bool AnalyzeFunctionalSubset(Block* block) {
+    // Block inputs have not been visited by the node-level analysis yet, so
+    // record the mutated ones here.
+    for (Value* input : block->inputs()) {
+      if (aliasDb_->hasWriters(input)) {
+        mutated_values_.insert(input);
+      }
+    }
+    // If a block output is not functional, then the corresponding output of
+    // the node that contains the block will not be functional either, so the
+    // block outputs do not need to be analyzed here.
+    bool all_functional = true;
+    for (Node* node : block->nodes()) {
+      // call first so that every node is analyzed regardless of the verdict
+      all_functional = AnalyzeFunctionalSubset(node) && all_functional;
+    }
+    return all_functional;
+  }
+
+  std::unordered_set<Node*> functional_nodes_;
+  std::unordered_set<Value*> mutated_values_;
+  std::shared_ptr<Graph> graph_;
+  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
+  size_t minSubgraphSize_ = 6;
+};
+
+// Unmerges every prim::FunctionalGraph node found in `block` or nested blocks.
+void InlineFunctionalGraphs(Block* block) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+    Node* node = *it++;
+    for (Block* nested : node->blocks()) {
+      InlineFunctionalGraphs(nested);
+    }
+    if (node->kind() == prim::FunctionalGraph) {
+      SubgraphUtils::unmergeSubgraph(node);
+    }
+  }
+}
+
+} // namespace
+
+// Groups the functional nodes of `graph` into prim::FunctionalGraph nodes.
+void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
+  // Pool constants first so that they are hoisted before slicing.
+  ConstantPooling(graph);
+  FunctionalGraphSlicer(graph).run();
+  // Creating subgraphs and unmerging them again leaves excess constants.
+  ConstantPooling(graph);
+}
+
+void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
+  InlineFunctionalGraphs(graph->block()); // handles nested blocks as well
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/create_functional_graphs.h b/torch/csrc/jit/passes/create_functional_graphs.h
new file mode 100644
index 0000000..6084e2f
--- /dev/null
+++ b/torch/csrc/jit/passes/create_functional_graphs.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+TORCH_API void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph);
+
+TORCH_API void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp
index 3ccaa25..329e240 100644
--- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp
+++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp
@@ -45,8 +45,59 @@
   subgraphNode->destroy();
 }
 
+// Gathers into `closed_over_values` every value that `input_node`, or any
+// node nested in its blocks, uses but that is neither defined inside
+// `input_node` (such values accumulate in `new_values`) nor already a key of
+// `inputsMap`.
+void collectNestedUses(
+    std::unordered_set<Value*>& closed_over_values,
+    std::unordered_set<Value*>& new_values,
+    std::unordered_map<Value*, Value*>& inputsMap,
+    Node* input_node) {
+  auto recordIfClosedOver = [&](Value* v) {
+    if (inputsMap.count(v) == 0 && new_values.count(v) == 0) {
+      closed_over_values.insert(v);
+    }
+  };
+  for (Value* input : input_node->inputs()) {
+    recordIfClosedOver(input);
+  }
+  if (input_node->kind() == prim::If || input_node->kind() == prim::Loop) {
+    for (Block* block : input_node->blocks()) {
+      // Block inputs are introduced by the node itself, so uses of them are not
+      // closed over. NOTE(review): only Loop bodies are expected to have any.
+      for (Value* v : block->inputs()) {
+        new_values.insert(v);
+      }
+      for (Node* node : block->nodes()) {
+        collectNestedUses(closed_over_values, new_values, inputsMap, node);
+      }
+      // A block may return an outer value directly; that holds for Loop
+      // bodies as much as for If branches, so check the outputs of both.
+      for (Value* v : block->outputs()) {
+        recordIfClosedOver(v);
+      }
+    }
+  } else if (input_node->blocks().size() != 0) {
+    TORCH_INTERNAL_ASSERT(false, input_node, " kind not handled yet");
+  }
+  // From here on the outputs of this node count as defined inside.
+  for (Value* output : input_node->outputs()) {
+    new_values.insert(output);
+  }
+}
+
+// Values used within `toMerge` that are not defined there nor in `inputsMap`.
+std::unordered_set<Value*> closedOverValues(
+    Node* toMerge,
+    std::unordered_map<Value*, Value*>& inputsMap) {
+  std::unordered_set<Value*> closed_over, defined_inside;
+  collectNestedUses(closed_over, defined_inside, inputsMap, toMerge);
+  return closed_over;
+}
+
 void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
-  AT_ASSERT(hasSubgraph(subgraphNode));
+  AT_ASSERT(hasSubgraph(subgraphNode) && toMerge != subgraphNode);
   if (hasSubgraph(toMerge)) {
     return mergeSubgraph(subgraphNode, toMerge);
   }
@@ -65,7 +116,9 @@
 
   // Add n's inputs to the group's input list if we don't already have them
   WithInsertPoint guard(*subgraph->nodes().begin());
-  for (auto input : toMerge->inputs()) {
+  std::unordered_set<Value*> closedValues =
+      closedOverValues(toMerge, inputsMap);
+  for (auto input : closedValues) {
     if (inputsMap.count(input) == 0) {
       // Clone constants inside the subgraph instead of referencing them, to
       // enable more optimizations
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index ed0c940..45f3daf 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -15,6 +15,7 @@
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
+#include <torch/csrc/jit/passes/create_functional_graphs.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/decompose_ops.h>
 #include <torch/csrc/jit/passes/erase_number_types.h>
@@ -259,6 +260,12 @@
           [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
       .def("_jit_pass_constant_pooling", ConstantPooling)
       .def(
+          "_jit_pass_create_functional_graphs",
+          [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
+      .def(
+          "_jit_pass_inline_functional_graphs",
+          [](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })
+      .def(
           "_jit_pass_peephole",
           [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
             return PeepholeOptimize(g, addmm_fusion_enabled);
diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp
index 4bb80ef..f326d2b 100644
--- a/torch/csrc/jit/runtime/operator.cpp
+++ b/torch/csrc/jit/runtime/operator.cpp
@@ -172,7 +172,8 @@
       prim::AutogradAnyNonZero, // temporarily inserted by autograd
       prim::AutogradAdd, // temporarily inserted by autograd
       prim::ConstantChunk, // optimization pass adds it
-      prim::DifferentiableGraph, // optimization pass adds it
+      prim::DifferentiableGraph, // optimization pass adds it,
+      prim::FunctionalGraph, // optimization pass adds it,
       prim::BroadcastSizes, // optimization pass (fuser) adds it
       prim::ChunkSizes, // optimization pass (fuser) adds it
       prim::Drop, // used in interpreter only
@@ -211,6 +212,7 @@
       prim::FusionGroup,
       prim::CudaFusionGroup,
       prim::DifferentiableGraph,
+      prim::FunctionalGraph,
       prim::Constant,
       prim::Uninitialized,
       prim::DictConstruct,