Change compiler to use Load/Stores, then transform to SSA (#21101)
Summary:
This changes our compiler so it first emits Loads & Stores, and then transforms the graph to SSA in a follow up pass. When a variable is set, we emit a prim::Store, and when a variable is referenced, we emit a prim::Load.
```
a = 1
print(a)
```
becomes:
```
%a.1 : int = prim::Constant[value=1]()
prim::Store[name="a"](%a.1)
%a : int = prim::Load[name="a"]()
prim::Print(%a)
```
In the follow up pass, convertToSSA, the values are turned into SSA form with the Loads & Stores removed. This change will enable breaks and continues because you can transform the graph with the variable naming information still intact.
There are still some remaining jitter and edge cases issues that I have to look through, but I think is still ready for eview.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21101
Differential Revision: D15723353
Pulled By: eellison
fbshipit-source-id: 3269934d4bc24ddaf3a87fdd20620b0f954d83d0
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 58694fd..6611b05 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -93,6 +93,7 @@
_(aten, __round_to_zero_floordiv)\
_(aten, _unwrap_optional) \
_(prim, fork) \
+ _(prim, forkClosure) \
_(prim, RaiseException) \
_(prim, Function) \
_(prim, CreateObject) \
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index a1c508f..9e491d8 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -391,6 +391,8 @@
${TORCH_SRC_DIR}/csrc/jit/passes/inline_autodiff_subgraphs.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/insert_guards.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/inliner.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/passes/lift_closures.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/passes/inline_forked_closures.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/dead_code_elimination.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/decompose_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize_ops.cpp
@@ -423,6 +425,7 @@
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/script/convert_to_ssa.cpp
${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h
index bd8e092..28423f2 100644
--- a/test/cpp/jit/test_alias_analysis.h
+++ b/test/cpp/jit/test_alias_analysis.h
@@ -911,10 +911,13 @@
void testAliasRegistration() {
{
- auto registry = torch::RegisterOperators()
- .op("foo::rand", torch::RegisterOperators::options()
- .catchAllKernel([](at::Tensor) -> at::Tensor { return at::rand({2, 2}); })
- .aliasAnalysis(AliasAnalysisKind::DEFAULT));
+ auto registry = torch::RegisterOperators().op(
+ "foo::rand",
+ torch::RegisterOperators::options()
+ .catchAllKernel([](at::Tensor) -> at::Tensor {
+ return at::rand({2, 2});
+ })
+ .aliasAnalysis(AliasAnalysisKind::DEFAULT));
const auto rand_op = Symbol::fromQualString("foo::rand");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
@@ -924,10 +927,11 @@
ASSERT_TRUE(aliasDb.mayAlias(a, b));
}
{
- auto registry = torch::RegisterOperators()
- .op("foo::pure", torch::RegisterOperators::options()
- .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
- .aliasAnalysis(AliasAnalysisKind::PURE));
+ auto registry = torch::RegisterOperators().op(
+ "foo::pure",
+ torch::RegisterOperators::options()
+ .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
+ .aliasAnalysis(AliasAnalysisKind::PURE));
const auto rand_op = Symbol::fromQualString("foo::pure");
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
diff --git a/test/test_jit.py b/test/test_jit.py
index ee4a68d..58a4915 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -839,21 +839,21 @@
m(x1, y1)
# Check what we collected
- self.assertTrue('x' in value_stats and 'y' in value_stats)
- self.assertTrue('p' in value_stats and 'z' in value_stats)
+ self.assertTrue('x.1' in value_stats and 'y.1' in value_stats)
+ self.assertTrue('p.1' in value_stats and 'z.1' in value_stats)
self.assertEqual(len(value_stats), 5)
- self.assertEqual(len(value_stats['p']), 1)
- self.assertEqual(len(value_stats['z']), 1)
- self.assertEqual(value_stats['p'][0], x1 + y1)
- self.assertEqual(value_stats['z'][0], x1 - y1)
+ self.assertEqual(len(value_stats['p.1']), 1)
+ self.assertEqual(len(value_stats['z.1']), 1)
+ self.assertEqual(value_stats['p.1'][0], x1 + y1)
+ self.assertEqual(value_stats['z.1'][0], x1 - y1)
# Run one more time and check the updated statistics
m(x2, y2)
- self.assertTrue('x' in value_stats and 'y' in value_stats)
- self.assertEqual(len(value_stats['p']), 2)
- self.assertEqual(len(value_stats['z']), 2)
- self.assertEqual(value_stats['p'][1], x2 + y2)
- self.assertEqual(value_stats['z'][1], x2 - y2)
+ self.assertTrue('x.1' in value_stats and 'y.1' in value_stats)
+ self.assertEqual(len(value_stats['p.1']), 2)
+ self.assertEqual(len(value_stats['z.1']), 2)
+ self.assertEqual(value_stats['p.1'][1], x2 + y2)
+ self.assertEqual(value_stats['z.1'][1], x2 - y2)
def test_insert_quantdequant_consecutive_qnodes_script(self):
input_data = torch.ones([1, 1, 5, 5])
@@ -14317,8 +14317,8 @@
def func(x):
return torch.ops.aten.relu(x)
self.assertExpectedInline(canonical(func.graph), '''\
-graph(%x : Tensor):
- %1 : Tensor = aten::relu(%x)
+graph(%x.1 : Tensor):
+ %1 : Tensor = aten::relu(%x.1)
return (%1)
''')
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 3b71242..6b902ef 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -86,6 +86,8 @@
"torch/csrc/jit/passes/guard_elimination.cpp",
"torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp",
"torch/csrc/jit/passes/inliner.cpp",
+ "torch/csrc/jit/passes/lift_closures.cpp",
+ "torch/csrc/jit/passes/inline_forked_closures.cpp",
"torch/csrc/jit/passes/inplace_check.cpp",
"torch/csrc/jit/passes/insert_guards.cpp",
"torch/csrc/jit/passes/loop_unrolling.cpp",
@@ -110,6 +112,7 @@
"torch/csrc/jit/script/edit_distance.cpp",
"torch/csrc/jit/script/logging.cpp",
"torch/csrc/jit/script/final_returns.cpp",
+ "torch/csrc/jit/script/convert_to_ssa.cpp",
"torch/csrc/jit/script/script_type_parser.cpp",
"torch/csrc/jit/script/sugared_value.cpp",
"torch/csrc/jit/script/schema_matching.cpp",
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 7802fba..23dbeee 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -1429,6 +1429,19 @@
return n;
}
+Node* Graph::createStore(const std::string& name, Value* v) {
+ auto n = create(prim::Store, {v}, /*num_outputs*/ 0);
+ n->s_(attr::name, name);
+ return n;
+}
+
+Node* Graph::createLoad(const std::string& name, const TypePtr& type) {
+ auto n = create(prim::Load, {}, /*num_outputs*/ 1);
+ n->s_(attr::name, name);
+ n->output()->setType(type);
+ return n;
+}
+
Value* Graph::insertFunctionCall(
std::shared_ptr<script::Function> callee,
script::MatchedSchema& matched) {
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 21f5bec..7547e34 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1097,6 +1097,8 @@
TORCH_API Value* insertGetAttr(Value* obj, const std::string& field) {
return insertNode(createGetAttr(obj, field))->output();
}
+ TORCH_API Node* createStore(const std::string& name, Value* v);
+ TORCH_API Node* createLoad(const std::string& name, const TypePtr& type);
TORCH_API Value* insertFunctionCall(
std::shared_ptr<script::Function> callee,
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index fe50c3b..cd662c9 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -18,6 +18,7 @@
std::unordered_set<Symbol> skip_list = {
prim::If,
prim::Loop,
+ prim::Function,
prim::Constant,
prim::AutogradZero,
prim::unchecked_unwrap_optional, // TODO remove
@@ -63,7 +64,6 @@
auto graph = n->owningGraph();
WithInsertPoint guard(n);
for (size_t i = 0; i < outputs.size(); ++i) {
-
auto new_output = tryInsertConstant(*graph, outputs[i]);
if (new_output) {
if (outputs[i].isNone()) {
diff --git a/torch/csrc/jit/passes/inline_forked_closures.cpp b/torch/csrc/jit/passes/inline_forked_closures.cpp
new file mode 100644
index 0000000..d5af1f7
--- /dev/null
+++ b/torch/csrc/jit/passes/inline_forked_closures.cpp
@@ -0,0 +1,84 @@
+#include <torch/csrc/jit/passes/inline_forked_closures.h>
+#include <torch/csrc/jit/script/compiler.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// Closure nodes are emitted as a tuple of (function %, context tuple %)
+// Inside the closure the closure is then unpacked so that all closed over
+// values are set. A function closing over a and b would look like:
+// def foo(context):
+// a, b = context
+//
+// To fork the closure, we need to set each value in the context tuple
+// as an explicit input to the fork node, and then within the closure
+// subgraph, replace the context unpacking value with the new graph input.
+// fork(foo) ->
+// def foo(a, b):
+void inlineForkedClosure(Node* fork_closure) {
+ Node* function_context_node = fork_closure->input()->node();
+
+ if (function_context_node->inputs().size() != 2 ||
+ function_context_node->inputs().at(0)->node()->kind() != prim::Function ||
+ function_context_node->inputs().at(1)->node()->kind() !=
+ prim::TupleConstruct) {
+ throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value";
+ }
+
+ Node* function = function_context_node->inputs().at(0)->node();
+ Node* context = function_context_node->inputs().at(1)->node();
+ auto fork_graph = function->g(attr::Subgraph)->copy();
+ auto g = fork_closure->owningGraph();
+ Node* fork_node = g->create(prim::fork, 1)
+ ->insertAfter(fork_closure)
+ ->setSourceRange(fork_closure->sourceRange());
+
+ if (fork_graph->inputs().size() != 1 ||
+ !fork_graph->inputs().at(0)->type()->cast<TupleType>()) {
+ throw ErrorReport(fork_node->sourceRange())
+ << "Cannot fork lambda with parameters";
+ }
+ auto fork_graph_context = fork_graph->inputs().at(0);
+ AT_ASSERT(fork_graph_context->uses().size() == 1);
+ auto fork_graph_unpack = fork_graph_context->uses().at(0).user;
+
+ for (size_t i = 0; i < context->inputs().size(); ++i) {
+ auto cont_input = context->inputs().at(i);
+ fork_node->addInput(cont_input);
+ auto inp = fork_graph->insertInput(i)->copyMetadata(cont_input);
+ fork_graph_unpack->outputs().at(i)->replaceAllUsesWith(inp);
+ }
+ fork_graph_unpack->destroy();
+ fork_graph->eraseInput(fork_graph->inputs().size() - 1);
+ fork_node->output()->copyMetadata(fork_closure->output());
+ fork_closure->output()->replaceAllUsesWith(fork_node->output());
+ fork_closure->destroy();
+ fork_node->g_(attr::Subgraph, fork_graph);
+ runCleanupPasses(fork_graph, /*convert_to_ssa */ false);
+}
+
+void inlineForkedClosures(Block* block) {
+ for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+ Node* n = *it;
+ it++;
+ switch (n->kind()) {
+ case prim::forkClosure: {
+ inlineForkedClosure(n);
+ } break;
+ default: {
+ for (Block* b : n->blocks()) {
+ inlineForkedClosures(b);
+ }
+ } break;
+ }
+ }
+}
+
+void inlineForkedClosures(std::shared_ptr<Graph>& to_clean) {
+ inlineForkedClosures(to_clean->block());
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/inline_forked_closures.h b/torch/csrc/jit/passes/inline_forked_closures.h
new file mode 100644
index 0000000..134a9fd
--- /dev/null
+++ b/torch/csrc/jit/passes/inline_forked_closures.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+TORCH_API void inlineForkedClosures(std::shared_ptr<Graph>& to_clean);
+
+}
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/lift_closures.cpp b/torch/csrc/jit/passes/lift_closures.cpp
new file mode 100644
index 0000000..94f4960
--- /dev/null
+++ b/torch/csrc/jit/passes/lift_closures.cpp
@@ -0,0 +1,80 @@
+#include <torch/csrc/jit/passes/lift_closures.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/compiler.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// Closures are initially emitted as prim::Function nodes with a single block.
+// Here, we convert the block to a subgraph, adding all closed over variables
+// as a context tuple input to the closure node.
+// At this point the closure has already undergone conversion to SSA,
+// so closed over variables will just be value * that are not set in the
+// closure block.
+// Within the closure subgraph, the context tuple is unpacked and the unpacked
+// values are used for closed over values.
+void liftClosure(Node* closure) {
+ auto block = closure->blocks().at(0);
+ auto subgraph = std::make_shared<Graph>();
+ // closures/forks can be nested, so use closure owning graph
+ auto g = closure->owningGraph();
+ Node* pack_context =
+ g->create(prim::TupleConstruct, {}, 1)->insertAfter(closure);
+ Value* context = subgraph->addInput("context");
+ // cannot use createTupleUnpack because the type is not known yet
+ Node* unpack_context =
+ subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
+
+ std::unordered_map<Value*, Value*> captures;
+ auto env = [&](Value* v) -> Value* {
+ auto it = captures.find(v);
+ if (it != captures.end()) {
+ return it->second;
+ }
+ pack_context->addInput(v);
+ Value* r = unpack_context->addOutput()->copyMetadata(v);
+ captures[v] = r;
+ return r;
+ };
+ subgraph->block()->cloneFrom(block, env);
+ auto context_type = TupleType::create(
+ fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
+ context->setType(context_type);
+ pack_context->output()->setType(context_type);
+ auto closure_tuple =
+ g->create(prim::TupleConstruct, {}, 1)->insertAfter(pack_context);
+ closure->output()->replaceAllUsesWith(closure_tuple->output());
+ closure_tuple->addInput(closure->output());
+ closure_tuple->addInput(pack_context->output());
+ closure_tuple->output()->setType(
+ TupleType::create({closure->output()->type(), context_type}));
+ closure->eraseBlock(0);
+ closure->g_(attr::Subgraph, std::move(subgraph));
+ runCleanupPasses(closure->g(attr::Subgraph), /*convert_to_ssa*/ false);
+}
+
+void liftClosures(Block* block) {
+ for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+ Node* n = *it;
+ it++;
+ switch (n->kind()) {
+ case prim::Function: {
+ liftClosure(n);
+ } break;
+ default: {
+ for (Block* b : n->blocks()) {
+ liftClosures(b);
+ }
+ }
+ }
+ }
+}
+
+void liftClosures(const std::shared_ptr<Graph>& to_clean) {
+ liftClosures(to_clean->block());
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/lift_closures.h b/torch/csrc/jit/passes/lift_closures.h
new file mode 100644
index 0000000..fa139b3
--- /dev/null
+++ b/torch/csrc/jit/passes/lift_closures.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+TORCH_API void liftClosures(const std::shared_ptr<Graph>& graph);
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 3b1215e..ff440c7 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -6,8 +6,12 @@
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/inline_forked_closures.h>
#include <torch/csrc/jit/passes/inliner.h>
+#include <torch/csrc/jit/passes/lift_closures.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/script/convert_to_ssa.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/script/parser.h>
@@ -28,6 +32,7 @@
using FunctionTable = std::unordered_map<std::string, Function&>;
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
+using TypeTable = std::unordered_map<std::string, TypePtr>;
using AttributeMap = std::unordered_map<std::string, Const>;
using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
@@ -153,49 +158,24 @@
return std::make_shared<MagicMethod>(name, base);
}
-// we consider _N where N is a number, to be a non-meaningful name
-// and do not record it as a unique name. This allows python printing to
-// be able to export and import more consistently named graphs
-static bool meaningfulName(const std::string& name) {
- if (name.size() == 0)
- return false;
- if (name[0] == '$')
- return false;
- if (name[0] != '_')
- return true;
- for (size_t i = 1; i < name.size(); ++i) {
- if (!isdigit(name[i]))
- return true;
- }
- return false;
-}
-
// Auxiliary data structure for desugaring variable binding into our always
// explicitly scoped language as we descend down nested control structures in
// the frontend (which themselves don't introduce scopes)
//
-// The algorithm is roughly as follows:
-// 1) While emitting a block within a control operator, add inputs and outputs
-// from the block for each value referenced (both "reads" and "writes").
-// This sets the value up as a candidate loop carried dependency.
-// 2) When we reach the end of the block, examine all the values in the current
-// scope's value map. If the name also resides in an outer scope with a
-// different Value*, this is a true loop-carried dependency. If not, this
-// value was not assigned to. Replace all references to the block input
-// with the Value* pointed to in the tightest enclosing scope. Then delete
-// that block input and output.
-// 3) When we emit the actual control operator, take all of the loop-carried
-// dependency values as inputs and return them as outputs from the control
-// op
+// The Environment keeps track of two tables, one for values which are not first
+// class and a type table for values which are. When a first class value
+// is set in the environment, we emit a prim::Store which sets the
+// name of the variable to approriate type, and when a first-class value is
+// referenced we emit a prim::Load that generates a value of the appropriate
+// type.
//
-// Note that an alternative implementation could only add the loop-carried dep
-// inputs and outputs when we see a value that is mutated. This, however
-// requires replacing all references to that value *within the current
-// block* with a new input. That is to say: we need to traverse the pre-
-// decessor nodes and replace inputs that reference that value with the
-// newly-created input. This could be made less expensive with a change to
-// the IR API, but for now we choose to pessimisitically create inputs and
-// delete unnecessary ones later with replaceAllusesWith().
+// a = 1
+// print(a)
+// becomes:
+// = prim::Store[name="a"](%a.1)
+// %a : int = prim::Load[name="a"]()
+// prim::Print(%a)
+
struct Environment {
Environment(
Function& method,
@@ -209,7 +189,6 @@
Function& method;
ResolverPtr resolver;
- std::vector<std::string> captured_inputs;
std::unordered_map<std::string, std::function<std::string()>> error_messages;
Block* b;
@@ -239,11 +218,30 @@
}
}
+ SugaredValuePtr insertLoad(const std::string& name, const TypePtr& type) {
+ auto g = b->owningGraph();
+ auto load = g->insertNode(g->createLoad(name, type));
+ if (meaningfulName(name)) {
+ load->output()->setUniqueName(name);
+ }
+ return std::make_shared<SimpleValue>(load->output());
+ }
+
+ void insertStore(const std::string& name, const SourceRange& loc, Value* v) {
+ auto g = b->owningGraph();
+ auto store = g->insertNode(g->createStore(name, v))->setSourceRange(loc);
+ type_table[name] = store->input()->type();
+ }
+
SugaredValuePtr findInThisFrame(const std::string& name) {
auto it = value_table.find(name);
if (it != value_table.end()) {
return it->second;
}
+ auto it2 = type_table.find(name);
+ if (it2 != type_table.end()) {
+ return insertLoad(name, it2->second);
+ }
return nullptr;
}
@@ -251,6 +249,10 @@
return next ? next->findInAnyFrame(name) : nullptr;
}
+ void setType(const std::string& name, TypePtr type) {
+ type_table[name] = std::move(type);
+ }
+
SugaredValuePtr findInAnyFrame(const std::string& name) {
for (auto runner = this; runner; runner = runner->next.get()) {
if (auto r = runner->findInThisFrame(name)) {
@@ -260,64 +262,9 @@
return nullptr;
}
- Value* getValueInThisFrame(const SourceRange& loc, const std::string& name) {
- return value_table.at(name)->asValue(loc, method);
- }
-
- SugaredValuePtr createCapturedInput(Value* orig, const std::string& name) {
- // insert the captured input alphabetically in the capture list.
- // this ensures consistency of the order of loop-carried dependencies
- // even when the use in the loop is in a different order
- size_t insert_pos = 0;
- while (insert_pos < captured_inputs.size() &&
- name > captured_inputs[insert_pos]) {
- insert_pos++;
- }
- captured_inputs.insert(captured_inputs.begin() + insert_pos, name);
-
- // Create the input
- const size_t loop_carried_block_inputs_offset = 1;
- Value* new_input =
- b->insertInput(loop_carried_block_inputs_offset + insert_pos)
- ->setType(orig->type());
-
- // Associate this name with this value
- auto sv = std::make_shared<SimpleValue>(new_input);
- value_table[name] = sv;
-
- return sv;
- }
-
- SugaredValuePtr createCapturedInputIfNeeded(
- const SourceRange& loc,
- const std::string& ident) {
- auto in_frame = findInThisFrame(ident);
- if (in_frame) {
- return in_frame;
- }
-
- // recursively handles the case where parent blocks are also loops
- auto from_parent =
- next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;
-
- // recursively create the captured input if it is the loop block
- if (from_parent && getBlockOwningKind() == prim::Loop) {
- if (Value* simple_val = asSimple(from_parent))
- from_parent = createCapturedInput(simple_val, ident);
- }
- return from_parent;
- }
-
Block* block() {
return b;
}
- Symbol getBlockOwningKind() {
- Symbol owning_kind = Symbol();
- if (b->owningNode()) {
- owning_kind = b->owningNode()->kind();
- }
- return owning_kind;
- }
void setVar(const SourceRange& loc, const std::string& name, Value* value) {
setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
@@ -375,9 +322,11 @@
throw ErrorReport(loc) << errMsg.str();
}
}
- if (as_simple_value)
- createCapturedInputIfNeeded(loc, name);
- value_table[name] = std::move(value);
+ if (as_simple_value) {
+ insertStore(name, loc, std::move(as_simple_value));
+ } else {
+ value_table[name] = std::move(value);
+ }
}
SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
@@ -391,7 +340,7 @@
const std::string& ident,
const SourceRange& range,
bool required = true) {
- auto retval = createCapturedInputIfNeeded(range, ident);
+ auto retval = findInAnyFrame(ident);
if (!retval) {
static std::unordered_map<std::string, SugaredValuePtr> globals = {
@@ -464,40 +413,16 @@
return getSugaredVar(ident, range)->asValue(range, method);
}
- // Given that after emitting statements in a block, we've added block inputs
- // for all value references and assignments, delete inputs for which there was
- // no assignment, only references.
- void deleteExtraInputs(const SourceRange& loc) {
- // note: skip i == 0, it is the loop trip count for inputs
- // and the loop condition for outputs.
- // captured_inputs is indexed by i - 1 since it only contains loop
- // carried dependencies
- // inputs: loop_counter, lcd0, lcd1, ...
- // outputs: loop_condition, lcd0, lcd1, ...
- // captured_inputs: lcd0, lcd1, ...
- AT_ASSERT(b->inputs().size() == b->outputs().size());
- AT_ASSERT(b->inputs().size() == captured_inputs.size() + 1);
- for (size_t i = b->inputs().size() - 1; i > 0; i--) {
- // nothing changed along this loop
- if (b->inputs()[i] == b->outputs()[i]) {
- auto name = captured_inputs[i - 1];
- Value* orig = findInParentFrame(name)->asValue(loc, method);
- b->inputs()[i]->replaceAllUsesWith(orig);
- b->eraseInput(i);
- b->eraseOutput(i);
- captured_inputs.erase(captured_inputs.begin() + i - 1);
- }
- }
- }
std::vector<std::string> definedVariables() {
std::vector<std::string> result;
- for (auto& kv : value_table) {
+ for (auto& kv : type_table) {
result.push_back(kv.first);
}
return result;
}
private:
+ TypeTable type_table;
ValueTable value_table;
};
@@ -594,17 +519,6 @@
return old_frame;
}
- void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
- // remove any uses of tuples that we inserted that are not needed
- if (!script::getFirstClassMode()) {
- Inline(*to_clean);
- }
- LowerSimpleTuples(to_clean);
- ConstantPooling(to_clean);
- // For jitter
- CanonicalizeOutputs(to_clean);
- }
-
FunctionSchema emitDef(const Def& def, const Self& self, Block* block) {
auto schema = extractSchemaFromDef(def, self);
// TODO need guards on init returning none
@@ -787,11 +701,13 @@
if (meaningfulName(name)) {
new_input->setUniqueName(name);
}
- environment_stack->setVar((*it).ident().range(), name, new_input);
-
// Record the type for the schema and set the Type on the Value*
arguments.push_back(schema.arguments().at(arg_annotation_idx++));
new_input->setType(arguments.back().type());
+
+ // NB: set type of new_input before setVar call so the Store is
+ // typed appropriately
+ environment_stack->setVar((*it).ident().range(), name, new_input);
}
return arguments;
}
@@ -811,39 +727,11 @@
void emitStatements(const List<Stmt>& statements) {
return emitStatements(statements.begin(), statements.end());
}
- std::pair<std::shared_ptr<Graph>, Value*> lambdaLift(Block* block) {
- auto subgraph = std::make_shared<Graph>();
- // note: type is set later on pack_context and context when we know it
- Node* pack_context =
- graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
- Value* context = subgraph->addInput("context");
- // cannot use createTupleUnpack because the type is not known yet
- Node* unpack_context =
- subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
- std::unordered_map<Value*, Value*> captures;
- auto env = [&](Value* v) -> Value* {
- auto it = captures.find(v);
- if (it != captures.end()) {
- return it->second;
- }
- pack_context->addInput(v);
- Value* r = unpack_context->addOutput()->copyMetadata(v);
- captures[v] = r;
- return r;
- };
- subgraph->block()->cloneFrom(block, env);
- auto context_type = TupleType::create(
- fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
- pack_context->output()->setType(context_type);
- context->setType(context_type);
- return std::make_pair(std::move(subgraph), pack_context->output());
- }
// XXX - right now closures are used _only_ for defining gradients internally
// There are several unfinished aspects that make them unusable generally
// 1. We do not have a type, ivalue, operator to represent prim::Function, so
// closure_node has type None
- // and any graphs that contain it cannot be run
// 2. There is no export logic for it yet, so it cannot be
// exported/python_printed
// 3. There is nothing preventing the assignment of already existing variables
@@ -851,32 +739,33 @@
// the changes to those variables will just get forgotten.
// 4. There is no parsing support in frontend.py, this is intentional since it
// prevents people from accidentally using this feature.
- void emitClosure(const Def& def) {
+ std::shared_ptr<ClosureValue> emitClosure(
+ const std::function<void(Block*)>& emit_body) {
Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
- closure_node->output()->setType(
- NoneType::get()); // it is not a real thing yet, so just say the type is
- // none.
+ // it is not a real thing yet, so just say the type is None
+ closure_node->output()->setType(NoneType::get());
Block* block = closure_node->addBlock();
{
WithInsertPoint guard(block);
pushFrame(block, /*starts_def=*/true);
+ emit_body(block);
+ popFrame(/*ends_def=*/true);
+ }
+ return std::make_shared<ClosureValue>(closure_node->output());
+ }
+
+ void emitClosure(const Def& def) {
+ // invoked once the closure block is set as the enviroment
+ auto emit_body = [&](Block* closure_block) {
emitDef(
def,
nullptr,
- block); // ignore schema return, we just wont use it for now since we
- // never create a Method for the closure
- popFrame(/*ends_def=*/true);
- }
- std::shared_ptr<Graph> subgraph;
- Value* context;
- std::tie(subgraph, context) = lambdaLift(block);
- runCleanupPasses(subgraph);
- closure_node->eraseBlock(0);
- closure_node->g_(attr::Subgraph, std::move(subgraph));
- auto tup =
- graph->insertNode(graph->createTuple({closure_node->output(), context}))
- ->output();
- environment_stack->setVar(def.name().range(), def.name().name(), tup);
+ closure_block); // ignore schema return, we just wont use it for now
+ // since we never create a Method for the closure
+ };
+ auto closure_value = emitClosure(emit_body);
+ environment_stack->setSugaredVar(
+ def.name().range(), def.name().name(), closure_value);
}
void emitReturn(const Return& stmt) {
@@ -1208,33 +1097,51 @@
// ordered set, because we want deterministic graph output
std::set<std::string> mutated_variables;
+ // When we access either the true or false environment,
+ // we need to set the insertion point so the prim::Load is inserted
+ // into the right block.
+ // if var is only defined in one branch save error in case it's used later
for (auto& v : save_true->definedVariables()) {
- if (save_false->findInAnyFrame(v)) {
- mutated_variables.insert(v);
- } else {
- ErrorReport error(stmt);
- environment_stack->setVariableTypeError(v, [=]() -> std::string {
- error << v << " is not defined in the false branch";
- return error.what();
- });
+ {
+ WithInsertPoint insert(false_block);
+ if (save_false->findInAnyFrame(v)) {
+ mutated_variables.insert(v);
+ } else {
+ ErrorReport error(stmt);
+ environment_stack->setVariableTypeError(v, [=]() -> std::string {
+ error << v << " is not defined in the false branch";
+ return error.what();
+ });
+ }
}
}
for (auto& v : save_false->definedVariables()) {
- if (save_true->findInAnyFrame(v)) {
- mutated_variables.insert(v);
- } else {
- ErrorReport error(stmt);
- environment_stack->setVariableTypeError(v, [=]() -> std::string {
- error << v << " is not defined in the true branch";
- return error.what();
- });
+ {
+ WithInsertPoint insert(true_block);
+ if (save_true->findInAnyFrame(v)) {
+ mutated_variables.insert(v);
+ } else {
+ ErrorReport error(stmt);
+ environment_stack->setVariableTypeError(v, [=]() -> std::string {
+ error << v << " is not defined in the true branch";
+ return error.what();
+ });
+ }
}
}
// Register outputs in each block
for (const auto& x : mutated_variables) {
- auto tv = save_true->getVar(x, stmt.range());
- auto fv = save_false->getVar(x, stmt.range());
+ Value* tv;
+ Value* fv;
+ {
+ WithInsertPoint insert(true_block);
+ tv = save_true->getVar(x, stmt.range());
+ }
+ {
+ WithInsertPoint insert(false_block);
+ fv = save_false->getVar(x, stmt.range());
+ }
auto unified = unifyTypes(tv->type(), fv->type());
// attempt to unify the types. we allow variables to be set to different
@@ -1264,10 +1171,7 @@
continue;
}
}
- true_block->registerOutput(tv);
- false_block->registerOutput(fv);
- environment_stack->setVar(
- stmt.range(), x, n->addOutput()->setType(*unified));
+ environment_stack->setType(x, *unified);
}
}
@@ -1354,10 +1258,6 @@
current_element_assigner,
c10::optional<Expr> cond,
Value* max_trip_count_val = nullptr) {
- Value* cond_val = nullptr;
- Node* n = graph->insertNode(create(prim::Loop, range, 0));
- WithInsertPoint guard(n);
-
if (!max_trip_count_val) {
max_trip_count_val = materializeConstant(
std::numeric_limits<int64_t>::max(),
@@ -1366,8 +1266,12 @@
integral_constants);
}
- cond_val = (cond) ? emitCond(cond.value())
- : graph->insertConstant(true, nullptr, range);
+ Value* cond_val = (cond) ? emitCond(cond.value())
+ : graph->insertConstant(true, nullptr, range);
+
+ Node* n = graph->insertNode(create(prim::Loop, range, 0));
+ WithInsertPoint guard(n);
+
n->addInput(max_trip_count_val);
n->addInput(cond_val);
auto* body_block = n->addBlock();
@@ -1386,33 +1290,13 @@
emitStatements(body);
- // Also emit the conditional
- cond_val = (cond) ? emitCond(cond.value())
- : graph->insertConstant(true, nullptr, range);
- body_block->registerOutput(cond_val);
- auto body_frame = popFrame();
- auto outer_frame = environment_stack;
+ Value* block_condition = (cond)
+ ? emitCond(cond.value())
+ : graph->insertConstant(true, nullptr, range);
- // Add block outputs to correspond to each captured input
- // some of these will be removed.
- for (const auto& x : body_frame->captured_inputs) {
- auto fv = body_frame->getValueInThisFrame(range, x);
- body_block->registerOutput(fv);
- }
+ body_block->registerOutput(block_condition);
- // Remove inputs for values that did not mutate within the
- // block
- body_frame->deleteExtraInputs(range);
-
- // register node inputs/outputs for the true loop carried deps,
- for (size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
- auto x = body_frame->captured_inputs[i];
- n->addInput(outer_frame->getVar(x, range));
- // body_block->inputs(): loop_counter, lcd0, lcd1, ...
- // captured_inputs: lcd0, lcd1, ...
- auto typ = body_block->inputs()[i + 1]->type();
- outer_frame->setVar(range, x, n->addOutput()->setType(typ));
- }
+ popFrame();
}
}
@@ -2479,31 +2363,42 @@
return graph->insertConstant(stack[0], nullptr, tree->range());
}
- // This function extract a new graph from its original subgraph
std::shared_ptr<SugaredValue> emitForkExpr(
SourceRange loc,
const std::shared_ptr<SugaredValue>& forked,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes) {
- // Build the fork node without inputs
- auto fork_node = method.graph()
- ->insertNode(method.graph()->create(prim::fork, 1))
- ->setSourceRange(loc);
- auto body_block = fork_node->addBlock();
+ auto g = method.graph();
+ Node* fork_node;
+ TypePtr out_type;
- // Build a template of the graph to be executed
- Value* node_output;
+ fork_node = g->insertNode(method.graph()->create(prim::forkClosure, 1))
+ ->setSourceRange(loc);
+
+ // We create a fork by emitting a closure and setting the closure output
+ // into the fork input. If a closure doesn't already exist, we create one.
{
- WithInsertPoint guard(body_block);
- auto fn_sugared_output = forked->call(loc, method, inputs, attributes, 1);
- auto fn_simple_output = fn_sugared_output->asValue(loc, method);
- body_block->registerOutput(fn_simple_output);
- node_output = fork_node->output()->setType(
- FutureType::create(fn_simple_output->type()));
+ WithInsertPoint insert(fork_node);
+ if (ClosureValue* sv = dynamic_cast<ClosureValue*>(forked.get())) {
+ Value* closure_output = sv->asValue(loc, method);
+ Block* closure_block = closure_output->node()->blocks().at(0);
+ TORCH_INTERNAL_ASSERT(closure_block->outputs().size() == 1);
+ out_type = closure_block->outputs().at(0)->type();
+ fork_node->addInput(closure_output);
+ } else {
+ auto emit_closure_body = [&](Block* closure_block) {
+ auto fn_sugared_output =
+ forked->call(loc, method, inputs, attributes, 1);
+ auto fn_simple_output = fn_sugared_output->asValue(loc, method);
+ closure_block->registerOutput(fn_simple_output);
+ out_type = fn_simple_output->type();
+ };
+ auto closure_value = emitClosure(emit_closure_body);
+ fork_node->addInput(closure_value->asValue(loc, method));
+ }
}
- // Lambda lift block(0) into attr::Subgraph
- lambdaLiftFork(fork_node);
- runCleanupPasses(fork_node->g(attr::Subgraph));
+ Value* node_output =
+ fork_node->output()->setType(FutureType::create(out_type));
return std::make_shared<SimpleValue>(node_output);
}
@@ -2944,19 +2839,19 @@
auto tuple_typ = tuple_val->type()->cast<TupleType>();
auto elems = tuple_typ->elements();
TypePtr output_type;
+ if (idx_val->type() != IntType::get()) {
+ throw ErrorReport(loc) << "tuple index must be an integer";
+ }
auto idx = toIValue(idx_val);
if (!idx) {
if (elems.size() == 0 ||
!convertibleToList(tuple_typ, ListType::create(elems[0]))) {
throw ErrorReport(loc)
<< "Cannot index into a " << tuple_typ->python_str()
- << " with a non-constant index because we cannot resolve the output type";
+ << " with a non-integer literal because we cannot resolve the output type";
}
output_type = elems[0];
} else {
- if (!idx->isInt()) {
- throw ErrorReport(loc) << "tuple index must be an integer";
- }
auto adj_index = getAdjTupleIndex(
loc, tuple_typ, idx->toInt(), /*allow_out_of_bounds*/ false);
output_type = elems[adj_index];
@@ -2977,7 +2872,7 @@
}
int64_t getSliceInd(Value* idx_val, const SourceRange& loc) {
- at::optional<IValue> ivalue = toIValue(idx_val);
+ auto ivalue = toIValue(idx_val);
if (ivalue && ivalue->isInt()) {
return ivalue->to<int64_t>();
} else {
@@ -3166,6 +3061,45 @@
define(definitions, resolvers, self);
}
+void runCleanupPasses(std::shared_ptr<Graph>& to_clean, bool convert_ssa) {
+ // the graph including closures is converted to ssa in the first pass,
+ // so subsequent cleanups do not need reconvert it
+ if (convert_ssa) {
+ ConvertToSSA(to_clean);
+ }
+ // NB ORDERING: SSA conversion has to occur before
+ // lifting of closures and forks, this way closures are converted
+ // to SSA while part of their original graph, and closures are ready to
+ // be inlined into forked closures
+ liftClosures(to_clean);
+ inlineForkedClosures(to_clean);
+ if (!script::getFirstClassMode()) {
+ Inline(*to_clean);
+ }
+ // remove any uses of tuples that we inserted that are not needed
+ LowerSimpleTuples(to_clean);
+ ConstantPooling(to_clean);
+ // For jitter
+ CanonicalizeOutputs(to_clean);
+}
+
+// we consider _N where N is a number, to be a non-meaningful name
+// and do not record it as a unique name. This allows python printing to
+// be able to export and import more consistently named graphs
+bool meaningfulName(const std::string& name) {
+ if (name.size() == 0)
+ return false;
+ if (name[0] == '$')
+ return false;
+ if (name[0] != '_')
+ return true;
+ for (size_t i = 1; i < name.size(); ++i) {
+ if (!isdigit(name[i]))
+ return true;
+ }
+ return false;
+}
+
void lambdaLiftFork(Node* fork_node) {
// Fork a new graph from its orignal owning graph
auto forked_graph = std::make_shared<Graph>();
@@ -3182,13 +3116,13 @@
}
return uncaptures_map[v];
};
-
forked_graph->block()->cloneFrom(body_block, env);
// Separate the subgraph and clean up the orignal one
fork_node->g_(attr::Subgraph, forked_graph);
fork_node->eraseBlock(0);
}
+
} // namespace script
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index a832b78..e53fd47 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -14,6 +14,11 @@
namespace jit {
namespace script {
+TORCH_API void runCleanupPasses(
+ std::shared_ptr<Graph>& to_clean,
+ bool convert_ssa = true);
+
+TORCH_API bool meaningfulName(const std::string& name);
TORCH_API void lambdaLiftFork(Node* fork_node);
} // namespace script
diff --git a/torch/csrc/jit/script/convert_to_ssa.cpp b/torch/csrc/jit/script/convert_to_ssa.cpp
new file mode 100644
index 0000000..03f42a5
--- /dev/null
+++ b/torch/csrc/jit/script/convert_to_ssa.cpp
@@ -0,0 +1,237 @@
+#include <torch/csrc/jit/script/convert_to_ssa.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/passes/inline_forked_closures.h>
+#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/mini_environment.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// At the beginning of the pass the Graph has already undergone type checking,
+// and writes or reads to a variable are emitted as Loads and Stores in the
+// graph. a = 1 print(a) is represented as:
+//
+// %a.1 : int = prim::Constant[value=1]()
+// prim::Store[name="a"](%a.1)
+// %a : int = prim::Load[name="a"]()
+// prim::Print(%a)
+//
+// First, this pass recursively adds the Loads & Stores to control flow nodes
+// Then the graph is converted to SSA form.
+
+using ValueEnvironment = MiniEnvironment<Value*>;
+using TypeEnvironment = MiniEnvironment<TypePtr>;
+
+// Adds Loads & Stores to Loops & Ifs
+struct ControlFlowLoadStores {
+ static void addBlockInput(
+ Block* b,
+ const TypePtr& type,
+ const std::string& name) {
+ auto g = b->owningGraph();
+ g->createStore(name, b->addInput(name)->setType(type))
+ ->insertAfter(b->param_node());
+ }
+
+ static void addBlockOutput(
+ Block* b,
+ const TypePtr& type,
+ const std::string& name) {
+ WithInsertPoint insert(b);
+ auto g = b->owningGraph();
+ auto block_output = g->insertNode(g->createLoad(name, type))->output();
+ b->registerOutput(block_output);
+ }
+
+ static void addNodeOutput(
+ Node* n,
+ const TypePtr& type,
+ const std::string& name) {
+ auto out = n->addOutput()->setType(type);
+ if (meaningfulName(name)) {
+ out->setUniqueName(name);
+ }
+ auto g = n->owningGraph();
+ g->createStore(name, out)->insertAfter(n);
+ }
+
+ static void addNodeInput(
+ Node* n,
+ const TypePtr& type,
+ const std::string& name) {
+ auto g = n->owningGraph();
+ auto inp = g->createLoad(name, type)->insertBefore(n)->output();
+ n->addInput(inp);
+ }
+
+ void addIfLoadStores(Node* n) {
+ auto true_block = n->blocks().at(0);
+ auto false_block = n->blocks().at(1);
+
+ auto true_vars = addControlFlowLoadStores(true_block);
+ auto false_vars = addControlFlowLoadStores(false_block);
+ std::set<std::string> mutated_variables;
+
+ for (auto& v : true_vars->definedVariables()) {
+ if (false_vars->findInAnyFrame(v)) {
+ mutated_variables.insert(v);
+ }
+ }
+ for (auto& v : false_vars->definedVariables()) {
+ if (true_vars->findInAnyFrame(v)) {
+ mutated_variables.insert(v);
+ }
+ }
+
+ // Following the same logic as emitIfElseBlocks in compiler.cpp,
+ // we emit a node output if the variable is defined in each block
+ // and the types of each block can be unified
+
+ for (const auto& x : mutated_variables) {
+ auto true_type = true_vars->findInAnyFrame(x);
+ auto false_type = false_vars->findInAnyFrame(x);
+ auto unified = unifyTypes(true_type, false_type);
+ if (!unified) {
+ continue;
+ }
+
+ addBlockOutput(true_block, true_type, x);
+ addBlockOutput(false_block, false_type, x);
+ addNodeOutput(n, *unified, x);
+ }
+ }
+
+ // loop_carried_outputs* = Loop(max_trip_count, start_condition,
+ // loop_carried_inputs*)
+ // block0(loop_counter, loop_carried_block*) {
+ // <body>
+ // -> (continue_condition, loop_carried_block_outputs*)
+ // }
+ // all loop_carried_... lists are the same length and represent the value of
+ // loop-carried variables whose definitions are updated as the loop executes
+ // in a way that ensure single static assignment.
+ void addLoopLoadStores(Node* n) {
+ auto body_block = n->blocks().at(0);
+ auto loop_vars = addControlFlowLoadStores(body_block);
+ for (const auto& name : loop_vars->definedVariables()) {
+ // we require that the variable is defined outside the loop to be emitted,
+ // and we do not refine the type of the parent variable since the loop may
+ // not be entered.
+ auto parent_type = environment_stack->findInAnyFrame(name);
+ if (!parent_type) {
+ continue;
+ }
+
+ // Insert a store at the beginning of the loop block, so that all
+ // loads of the variable will use the loop carried value
+ addNodeInput(n, parent_type, name);
+ addBlockInput(body_block, parent_type, name);
+ addBlockOutput(body_block, parent_type, name);
+ addNodeOutput(n, parent_type, name);
+ }
+ }
+
+ std::shared_ptr<TypeEnvironment> addControlFlowLoadStores(Block* block) {
+ pushFrame(block);
+ for (Node* n : block->nodes()) {
+ switch (n->kind()) {
+ case prim::If: {
+ addIfLoadStores(n);
+ } break;
+ case prim::Loop: {
+ addLoopLoadStores(n);
+ } break;
+ case prim::Function: {
+ for (auto b : n->blocks()) {
+ addControlFlowLoadStores(b);
+ }
+ } break;
+ case prim::Store: {
+ environment_stack->setVar(n->s(attr::name), n->input()->type());
+ } break;
+ }
+ }
+ return popFrame();
+ }
+
+ void pushFrame(Block* b) {
+ environment_stack = std::make_shared<TypeEnvironment>(b, environment_stack);
+ }
+
+ std::shared_ptr<TypeEnvironment> popFrame() {
+ auto old_frame = environment_stack;
+ environment_stack = environment_stack->next;
+ return old_frame;
+ }
+
+ void run(std::shared_ptr<Graph>& graph) {
+ addControlFlowLoadStores(graph->block());
+ }
+
+ std::shared_ptr<TypeEnvironment> environment_stack = nullptr;
+};
+
+// Given a graph where outputs have been added to control flow nodes, and
+// loads and stores are represented in the graph, converts the graph to SSA
+struct SSATransformer {
+ void convertBlockToSSA(Block* block) {
+ pushFrame(block);
+ for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+ auto n = *it;
+ it++;
+ switch (n->kind()) {
+ case prim::If:
+ case prim::Loop:
+ case prim::Function: {
+ for (auto b : n->blocks()) {
+ convertBlockToSSA(b);
+ }
+ } break;
+ case prim::Store: {
+ environment_stack->setVar(n->s(attr::name), n->input());
+ n->destroy();
+ } break;
+ case prim::Load: {
+ auto name = n->s(attr::name);
+ auto var = environment_stack->findInAnyFrame(name);
+ TORCH_INTERNAL_ASSERT(
+ var, "Typechecking should ensure the variable name is set");
+ n->output()->replaceAllUsesWith(var);
+ n->destroy();
+ } break;
+ }
+ }
+ popFrame();
+ }
+
+ void pushFrame(Block* b) {
+ environment_stack =
+ std::make_shared<ValueEnvironment>(b, environment_stack);
+ }
+
+ std::shared_ptr<ValueEnvironment> popFrame() {
+ auto old_frame = environment_stack;
+ environment_stack = environment_stack->next;
+ return old_frame;
+ }
+
+ void run(std::shared_ptr<Graph>& graph) {
+ convertBlockToSSA(graph->block());
+ }
+
+ std::shared_ptr<ValueEnvironment> environment_stack = nullptr;
+};
+
+// Converting to SSA works in two parts. First we add outputs to control flow
+// nodes, then we stitch together Loads & Stores into SSA form.
+void ConvertToSSA(std::shared_ptr<Graph>& graph) {
+ ControlFlowLoadStores ctrl;
+ ctrl.run(graph);
+ SSATransformer ssa;
+ ssa.run(graph);
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/convert_to_ssa.h b/torch/csrc/jit/script/convert_to_ssa.h
new file mode 100644
index 0000000..e287cee
--- /dev/null
+++ b/torch/csrc/jit/script/convert_to_ssa.h
@@ -0,0 +1,18 @@
+#pragma once
+#include <functional>
+#include <memory>
+#include <string>
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// Convert a graph with Loads & Stores into SSA form
+TORCH_API void ConvertToSSA(std::shared_ptr<Graph>& graph);
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/mini_environment.h b/torch/csrc/jit/script/mini_environment.h
new file mode 100644
index 0000000..f3c2cfe
--- /dev/null
+++ b/torch/csrc/jit/script/mini_environment.h
@@ -0,0 +1,55 @@
+#include <ATen/core/jit_type.h>
+#include <torch/csrc/jit/ir.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// Simple data structure for containing a type T in nested control blocks
+// Should only be used after initial compilation where type checking and
+// loads and stores are emitted
+
+template <typename T>
+struct MiniEnvironment {
+ MiniEnvironment(Block* b, std::shared_ptr<MiniEnvironment> next = nullptr)
+ : next(std::move(next)) {}
+
+ std::shared_ptr<MiniEnvironment<T>> next;
+
+ T findInThisFrame(const std::string& name) {
+ auto it = table.find(name);
+ if (it != table.end()) {
+ return it->second;
+ }
+ return nullptr;
+ }
+
+ T findInAnyFrame(const std::string& name) {
+ for (auto runner = this; runner; runner = runner->next.get()) {
+ if (auto r = runner->findInThisFrame(name)) {
+ return r;
+ }
+ }
+ return nullptr;
+ }
+
+ void setVar(const std::string& name, T value) {
+ table[name] = value;
+ }
+
+ std::vector<std::string> definedVariables() {
+ std::vector<std::string> result;
+ for (auto& kv : table) {
+ result.push_back(kv.first);
+ }
+ std::sort(result.begin(), result.end());
+ return result;
+ }
+
+ private:
+ std::unordered_map<std::string, T> table;
+};
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/sugared_value.h b/torch/csrc/jit/script/sugared_value.h
index 28e14af..2200e52 100644
--- a/torch/csrc/jit/script/sugared_value.h
+++ b/torch/csrc/jit/script/sugared_value.h
@@ -241,6 +241,19 @@
std::shared_ptr<Function> callee_;
};
+struct TORCH_API ClosureValue : public SugaredValue {
+ ClosureValue(Value* value) : value_(value) {
+ TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Function);
+ }
+ std::string kind() const override {
+ return "closure";
+ }
+ Value* asValue(const SourceRange& range, Function& m) override {
+ return value_;
+ }
+ Value* value_;
+};
+
// defines how a method obtained from a module behaves in script
struct MethodValue : public SugaredValue {
MethodValue(Value* self, std::string method_name)