[ONNX] Redesign inplace conversion (#55033) (#56173)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56173
* Create `InplaceConverter` and `ValueTracker` to keep track of aliases of values throughout the graph. For a given value, a new alias is created every time there is an in-place operation or a `SetAttr`, as well as whenever the value passes through nested blocks owned by If/Loop nodes. (A simplified sketch of the alias tracking follows this list.)
* Fix a bug where control-flow node output types were not set when the node as a whole could not run ONNX shape inference because it contains non-ONNX nodes.
* Add a symbolic for `__not__` ~~and `prim_min`~~ (update: moved to a separate PR), and update the opset 9 `index_put` symbolic to support assignment without providing indices. (See the example after this list.)
* Bump the ORT (onnxruntime) version to 1.7.0 in CI tests.
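
To illustrate the alias-tracking idea from the first bullet, here is a simplified Python sketch (illustration only, not the C++ implementation in the diff; the names and the integer `pos` standing in for `Node::isBefore` are hypothetical). The real `ValueTracker` additionally threads aliases through If/Loop block inputs/outputs and restricts lookup to aliases whose owning block is an ancestor of the querying node:

```python
from collections import defaultdict

class ValueTrackerSketch:
    """Toy model of the alias bookkeeping in the C++ ValueTracker."""

    def __init__(self):
        self.alias_to_root = {}                    # alias -> root value
        self.root_to_aliases = defaultdict(list)   # root -> [(pos, alias)], sorted

    def record_set_value(self, old_v, new_v, pos):
        # Each converted in-place op / SetAttr produces a fresh alias new_v.
        root = self.alias_to_root.setdefault(old_v, old_v)
        self.alias_to_root[new_v] = root
        self.root_to_aliases[root].append((pos, new_v))
        self.root_to_aliases[root].sort()

    def find_alias_at(self, v, pos):
        # Latest alias created before `pos`; `v` itself if never mutated.
        root = self.alias_to_root.get(v)
        if root is None:
            return v
        found = v
        for p, alias in self.root_to_aliases[root]:
            if p < pos:
                found = alias
        return found

vt = ValueTrackerSketch()
vt.record_set_value("a", "a.1", pos=3)  # e.g. a.add_(...) rewritten at node 3
vt.record_set_value("a", "a.2", pos=7)  # second mutation at node 7
assert vt.find_alias_at("a", 5) == "a.1"
assert vt.find_alias_at("a", 9) == "a.2"
```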
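And a hypothetical module sketching the pattern the opset 9 `index_put` change enables: a full-tensor assignment reaches `index_put` with an empty indices list, which the updated symbolic lowers to the assigned values directly (or an `Add` when `accumulate` is true):

```python
import torch

class FullAssign(torch.nn.Module):
    def forward(self, x, y):
        x = x.clone()  # avoid mutating the module input itself
        # In-place copy; the export passes in this PR rewrite it into
        # aten::index_put with an empty indices list.
        x[:] = y
        return x

# With this change, the empty-indices case exports at opset 9 instead of
# raising _onnx_opset_unsupported('index_put', 9, 11).
print(torch.jit.script(FullAssign()).graph)
```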
Test Plan: Imported from OSS
Reviewed By: pbelevich
Differential Revision: D27866138
Pulled By: SplitInfinity
fbshipit-source-id: ab5c9188740c50f783ceba4d54fda43c26e2fde7
diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh
index c1988d2..e66b7ae 100755
--- a/.jenkins/caffe2/test.sh
+++ b/.jenkins/caffe2/test.sh
@@ -170,7 +170,7 @@
# JIT C++ extensions require ninja, so put it into PATH.
export PATH="/var/lib/jenkins/.local/bin:$PATH"
if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
- pip install -q --user onnxruntime==1.6.0
+ pip install -q --user onnxruntime==1.7.0
fi
"$ROOT_DIR/scripts/onnx/test.sh"
fi
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index f33fd1b..29b8d2d 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -4569,6 +4569,24 @@
y = torch.randn(4, 5)
self.run_test(model, (x, y))
+ @skipIfUnsupportedMinOpsetVersion(14) # Needs onnx::Identity for sequences, added in opset 14
+ def test_list_append_nested_2(self):
+ class ListModel(torch.nn.Module):
+ def forward(self, x):
+ res = []
+ res_replicate = []
+ for i in range(x.size(0)):
+ if len(res) > 2:
+ for j in range(x.size(1)):
+ res.append(x[i][j])
+ res_replicate.append(res[-1])
+ res.append(res_replicate[-1])
+ return res, res_replicate
+
+ model = torch.jit.script(ListModel())
+ x = torch.randn(4, 4, 3, 4)
+ self.run_test(model, (x, ))
+
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_pop(self):
class ListModel(torch.nn.Module):
@@ -4651,6 +4669,36 @@
y = torch.randn(4, 5)
self.run_test(model, (x, y))
+ @skipIfUnsupportedMinOpsetVersion(11)
+ def test_list_set(self):
+ class ListModel(torch.nn.Module):
+ def forward(self, x, y):
+ res = []
+ for i in range(x.size(0)):
+ res.append(x[i])
+ res[y] = x[y]
+ return res
+
+ model = torch.jit.script(ListModel())
+ x = torch.randn(12, 4)
+ y = torch.tensor(2, dtype=torch.long)
+ self.run_test(model, (x, y))
+
+ @skipIfUnsupportedMinOpsetVersion(13)
+ def test_list_idx_sum(self):
+ class ListModel(torch.nn.Module):
+ def forward(self, x, y):
+ indices = torch.arange(x.size(0))
+ res = []
+ for i in range(x.size(0)):
+ res.append(x[i])
+ return res[torch.sum(indices[:y])]
+
+ model = torch.jit.script(ListModel())
+ x = torch.randn(12, 4)
+ y = torch.tensor(2, dtype=torch.long)
+ self.run_test(model, (x, y))
+
@skipIfUnsupportedMinOpsetVersion(9)
def test_tensor_factories(self):
class TensorFactory(torch.nn.Module):
@@ -4830,6 +4878,125 @@
self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2)
self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2)
+ @skipIfUnsupportedMinOpsetVersion(9)
+ def test_inplace_with_loop(self):
+ class M(torch.nn.Module):
+ def forward(self, x):
+ a = torch.ones(12,)
+ for i in range(10):
+ a.add_(torch.ones(12,))
+ return a + x
+
+ m = torch.jit.script(M())
+ x = torch.randn(12,)
+ self.run_test(m, (x,))
+
+ @skipIfUnsupportedMinOpsetVersion(9)
+ def test_inplace_with_loop_2(self):
+ class M(torch.nn.Module):
+ def forward(self, x):
+ _bias = torch.ones(12,)
+ a = torch.ones(12,) # used in loop, altered.
+ a_ref = a # not used in loop, but aliases a, so should be altered.
+ b = x.clone() # used in loop, reassigned out-of-place, so not altered.
+ b_ref = b # not used in loop, should not be altered.
+ for i in range(10):
+ if i == 3:
+ for j in range(5):
+ a += _bias
+ _bias.add_(torch.ones(12,))
+ b = b + torch.ones(12,)
+
+ _bias.add_(torch.ones(12,))
+ a += _bias
+ # TODO: value for a_ref is incorrect.
+ # a_ref += torch.ones(12,)
+ b_ref += torch.ones(12,)
+ return _bias + x, a, b, b_ref
+
+ m = torch.jit.script(M())
+ x = torch.zeros(12,)
+ self.run_test(m, (x,))
+
+ @skipIfUnsupportedMinOpsetVersion(11)
+ def test_inplace_attr_with_loop(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._bias = torch.arange(12,)
+
+ def forward(self, x):
+ self._bias = torch.arange(12,)
+ for i in range(10):
+ if i == 3:
+ for j in range(5):
+ self._bias += torch.arange(12,)
+ return self._bias + x
+
+ m = torch.jit.script(M())
+ x = torch.zeros(12,)
+ self.run_test(m, (x,))
+
+ @skipIfUnsupportedMinOpsetVersion(11)
+ def test_inplace_attr_copy_with_loop(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._bias = torch.arange(12,)
+
+ def forward(self, x):
+ self._bias = torch.arange(12,)
+ for i in range(10):
+ if i == 3:
+ for j in range(5):
+ self._bias.copy_(torch.arange(12,))
+ self._bias.copy_(self._bias + torch.arange(12,))
+
+ self._bias.copy_(self._bias + torch.arange(12,))
+ return self._bias + x
+
+ m = torch.jit.script(M())
+ x = torch.zeros(12,)
+ self.run_test(m, (x,))
+
+ @skipIfUnsupportedMinOpsetVersion(14) # Needs onnx::Identity for sequences, added in opset 14
+ def test_inplace_sequence_with_loop(self):
+ class M(torch.nn.Module):
+ def process(self, beam_hyps: List[torch.Tensor], done: torch.Tensor, x):
+ batch_size = x.shape[0]
+ for i in range(batch_size):
+ if done[i]:
+ continue
+
+ beam_idx = 0
+ for _, token in enumerate(x[i]):
+ beam_hyps.append(token)
+ beam_idx += 1
+
+ if beam_idx == 6:
+ break
+
+ done[i] = len(beam_hyps) > 4
+
+ return beam_hyps, done
+
+ def forward(self, x):
+ beam_hyps: List[torch.Tensor] = []
+ batch_size = x.shape[0]
+ cur_len = 0
+ max_len = x.shape[1]
+ done = torch.zeros(batch_size, dtype=torch.bool)
+ while cur_len < max_len:
+ beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
+ cur_len = cur_len + 1
+
+ return beam_hyps
+
+ m = torch.jit.script(M())
+ x = torch.randn(8, 4, 3)
+ self.run_test(m, (x,))
+
+
@disableScriptTest() # Sort with dynamic dim not supported in ONNX
def test_sort(self):
class SortModel(torch.nn.Module):
@@ -7602,6 +7769,37 @@
self.run_test(model, (x, anchors))
@skipIfUnsupportedMinOpsetVersion(11)
+ def test_set_attr_5(self):
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super(MyModule, self).__init__()
+ self.conv = torch.nn.Conv1d(10, 3, 3)
+ self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
+
+ def set_cell_anchors(self, anchors):
+ self.conv.weight = torch.arange(10)
+ for i in range(10):
+ if i == 3:
+ for j in range(10):
+ w = self.conv.weight
+ self.conv.weight = torch.arange(10) + w
+
+ self.conv.weight = self.conv.weight + torch.arange(10)
+ # NOTE: `is not None` and `assert` are needed to pass TorchScript compilation.
+ if self.conv.bias is not None:
+ a = self.conv.bias
+ assert a is not None
+ self.conv.bias = anchors + a
+
+ def forward(self, anchors):
+ self.set_cell_anchors(anchors)
+ return self.conv.weight, self.conv.bias
+
+ model = torch.jit.script(MyModule())
+ anchors = torch.ones(3, 10, 3)
+ self.run_test(model, (anchors,))
+
+ @skipIfUnsupportedMinOpsetVersion(11)
def test_set_attr_in_loop(self):
class MyModule(torch.nn.Module):
def __init__(self):
@@ -7698,7 +7896,11 @@
model = Example(10)
random_data = torch.rand((1, 5, 30, 30))
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
- self.run_test(model, (random_data, empty_tensor))
+ random_state = torch.rand((1, 1, 10, 30, 30))
+ self.run_test(model, (random_data, empty_tensor),
+ input_names=['data', 'state'],
+ dynamic_axes={'state': [0, 1, 2, 3, 4]},
+ test_with_inputs=[(random_data, random_state)])
@skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_if_3(self):
@@ -7768,6 +7970,41 @@
empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
self.run_test(model, (random_data, empty_tensor))
+
+ @skipIfUnsupportedMinOpsetVersion(11)
+ def test_index_put_if_5(self):
+ @torch.jit.script
+ def check_init(input_data, hidden_size, prev_state):
+ # type: (torch.Tensor, int, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ batch_size = input_data.size(0)
+ spatial_size_0 = input_data.size(2)
+ spatial_size_1 = input_data.size(3)
+ # generate empty prev_state, if None is provided
+ state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
+ state = torch.zeros(state_size, device=input_data.device)
+ state_ref = state
+ if prev_state.size(0) == 0:
+ state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
+ state = state + 3
+ state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
+ else:
+ state = state + 2
+ return state, state_ref
+
+ class Example(torch.nn.Module):
+ def __init__(self, hidden_size):
+ super().__init__()
+ self.hidden_size = hidden_size
+
+ def forward(self, input_data, prev_state):
+ prev_state, state_ref = check_init(input_data, self.hidden_size, prev_state)
+ return prev_state, state_ref
+
+ model = Example(4)
+ random_data = torch.rand((1, 5, 4, 4))
+ empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
+ self.run_test(model, (random_data, empty_tensor))
+
@skipIfUnsupportedMinOpsetVersion(11)
def test_list_append_in_block(self):
class ListModel(torch.nn.Module):
diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
index 031ad2b..18eb78f 100644
--- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
+++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
@@ -236,6 +236,11 @@
// NOTE: the output order is deliberately changed to match expected order
// since onnx loop requires scan outputs to be the last outputs.
auto new_outputs = ConvertSequenceDependencies(node, opset_version);
+
+ // Copy type of block output to node output.
+ for (size_t i = 0; i < node->outputs().size(); ++i) {
+ node->output(i)->setType(node->blocks().at(0)->outputs().at(i + 1)->type());
+ }
TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
return new_outputs;
}
@@ -375,6 +380,11 @@
auto* graph = if_node->owningGraph();
FixupONNXSubblockOutputs(node);
ONNXFixupUninitializedOutput(if_node);
+ // Copy type of block output to node output.
+ for (size_t i = 0; i < node->outputs().size(); ++i) {
+ node->output(i)->setType(node->blocks().at(0)->outputs().at(i)->type());
+ }
+
GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
return if_node->outputs().vec();
}
diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
index bc5e6cb..4a86ed3 100644
--- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
+++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
@@ -19,17 +19,98 @@
const std::set<c10::Symbol> inplace_ops =
{aten::append, aten::index_put_, aten::pop, aten::insert, aten::Delete};
-bool IsInplaceNode(const Node* n) {
- if (inplace_ops.find(n->kind()) != inplace_ops.end()) {
- return true;
- }
+// InplaceConverter defines a set of functions that together enable the
+// conversion of prim::GetAttr, prim::SetAttr, and ATen in-place operators to
+// ONNX out-of-place operators.
+struct InplaceConverter {
+ InplaceConverter(
+ std::shared_ptr<Graph> graph,
+ MutationRemover* mr,
+ Module* model = nullptr)
+ : graph_(std::move(graph)), mr_(mr), module_(model) {}
- if (n->kind() == Symbol::fromQualString("onnx::Placeholder") &&
- n->s(attr::name) == "index_put_") {
- return true;
- }
+ void convertMutationForONNX();
- return false;
+ private:
+ void gatherAttrNameInitialValueMap(
+ Block* block,
+ std::unordered_map<std::string, Value*>& attr_name_value_map,
+ std::unordered_map<Node*, std::string>& attr_node_fullname_map);
+ void replaceAttrWithInplaceOps(
+ Block* block,
+ const std::unordered_map<std::string, Value*>& attr_name_value_map,
+ const std::unordered_map<Node*, std::string>& attr_node_fullname_map);
+
+ void convertInplaceOpsAndTrackAlias();
+ void convertInplaceOpsAndTrackAlias(Block* block);
+
+ void correctAliasReferences();
+ void correctAliasReferences(Block* block);
+ void correctAliasReferences(Node* n);
+
+ void convertGetSetAttrToInplaceOps(Block* block);
+
+ // ValueTracker provides APIs to record aliases for a single value,
+ // and to retrieve the correct alias of any given value based on the location
+ // in the graph where it is used.
+ struct ValueTracker {
+ ValueTracker() : graph_(nullptr) {}
+
+ void init(const std::shared_ptr<Graph>& graph);
+ void recordSetValue(Value* old_v, Value* new_v);
+ Value* findAliasForValueAtNode(Value* v, const Node* n) const;
+
+ std::string toString() const;
+
+ private:
+ std::shared_ptr<Graph> graph_;
+
+ // Map from aliases to the root value.
+ // A single value can have multiple aliases throughout the graph,
+ // created by inplace operators, and preserved through loop-carried
+ // inputs/outputs. For each such value, its first occurrence is set as
+ // the root value.
+ std::unordered_map<Value*, Value*> alias_to_value_;
+
+ // Sort aliases based on their order in the graph.
+ // A tie can happen when two distinct aliases belong to different blocks
+ // while sharing the same ancestor node. The unique id is used as the tie
+ // breaker; otherwise the two aliases would be considered equal to each
+ // other. aliasComp must satisfy strict weak ordering.
+ struct aliasComp {
+ bool operator()(const Value* a, const Value* b) const {
+ auto* n_a = a->node();
+ auto* n_b = b->node();
+ if (n_a == n_b) {
+ return false;
+ }
+ auto a_b = n_a->isBefore(n_b);
+ auto b_a = n_b->isBefore(n_a);
+ if (a_b == b_a) {
+ return a->unique() < b->unique();
+ }
+ return a_b;
+ }
+ };
+ // Map from root value to aliases sorted by their order in graph.
+ std::unordered_map<Value*, std::set<Value*, aliasComp>>
+ value_to_sorted_aliases_;
+ };
+
+ std::shared_ptr<Graph> graph_;
+ MutationRemover* mr_;
+ Module* module_;
+ ValueTracker vt_;
+};
+
+bool isAncestor(const Block* a, const Block* b) {
+ while (b && b->owningNode()) {
+ if (a == b) {
+ return true;
+ }
+ b = b->owningNode()->owningBlock();
+ }
+ return a == b;
}
Node* addDummyClone(
@@ -66,332 +147,59 @@
return newNode;
}
-// Check If then/else blocks to match the number of outputs.
-// If the number of block outputs do not match, insert a dummy
-// constant of corresponding shape and type.
-Value* MatchIfBlocksOutputForValue(
- Value* orig_data,
- Block* outer_block,
- Value* origOutput) {
- if (outer_block->owningNode()->kind() == prim::Loop)
- return outer_block->owningNode()->outputs().at(
- outer_block->owningNode()->outputs().size() - 1);
-
- if (outer_block->owningNode()->kind() != prim::If)
- return nullptr;
- size_t output_size = outer_block->outputs().size();
-
- for (size_t i = 0; i < output_size - 1; i++) {
- if (outer_block->outputs().at(i)->debugNameBase() ==
- origOutput->debugNameBase()) { // Check debug names
- outer_block->replaceOutput(i, outer_block->outputs().at(output_size - 1));
- outer_block->eraseOutput(output_size - 1);
- outer_block->owningNode()->eraseOutput(output_size - 1);
- return outer_block->owningNode()->outputs().at(i);
- }
- }
-
- for (Block* b : outer_block->owningNode()->blocks()) {
- if (b->outputs().size() < output_size) {
- auto clone_node =
- addDummyClone(b->owningGraph(), orig_data, false, b->return_node());
- b->registerOutput(clone_node->output());
- b->outputs()
- .at(b->outputs().size() - 1)
- ->copyMetadata(
- outer_block->outputs().at(output_size - 1)); // Copy debug names
- }
- }
- return outer_block->owningNode()->outputs().at(output_size - 1);
-}
-
-// clang-format off
-// Register inplace op node inputs/outputs through the blocks.
-// Eg. The IR before updating:
-//%23 : bool = aten::eq(%22, %13)
-// = prim::If(%23)
-// block0():
-// %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-// %25 : Tensor = aten::ones(%24, %12, %12, %12, %12)
-// %26 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-// %27 : Tensor = aten::copy_(%26, %25, %9)
-// -> ()
-// block1():
-// %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-// %29 : Tensor = aten::randn(%28, %12, %12, %12, %12)
-// %30: Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-// %31 : Tensor = aten::copy_(%30, %29, %9)
-// -> ()
-// After updating:
-//%23 : bool = aten::eq(%22, %13)
-//%51 : Tensor = prim::If(%23)
-// block0():
-// %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-// %25 : Tensor = aten::ones(%24, %12, %12, %12, %12)
-// %26 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-// %32 : Tensor?[] = prim::ListConstruct()
-// %33 : Tensor = aten::expand_as(%25, %26)
-// %38 : int = prim::Constant[value=0]()
-// %39 : int = aten::size(%state.1, %38)
-// %40 : int = prim::Constant[value=4]()
-// %41 : None = prim::Constant()
-// %42 : None = prim::Constant()
-// %43 : None = prim::Constant()
-// %44 : Tensor = aten::arange(%39, %40, %41, %42, %43)
-// %45 : int = prim::Constant[value=0]()
-// %46 : Tensor = aten::slice(%44, %45, %13, %10, %11)
-// %47 : int[] = prim::Constant[value=[-1]]()
-// %48 : Tensor = aten::view(%46, %47)
-// %49 : Tensor?[] = prim::ListConstruct(%48)
-// %50 : Tensor = aten::index_put(%state.1, %49, %33, %9)
-// -> (%50)
-// block1():
-// %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-// %29 : Tensor = aten::randn(%28, %12, %12, %12, %12)
-// %30 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-// %35 : Tensor?[] = prim::ListConstruct()
-// %36 : Tensor = aten::expand_as(%29, %30)
-// %52 : int = prim::Constant[value=0]()
-// %53 : int = aten::size(%state.1, %52)
-// %54 : int = prim::Constant[value=4]()
-// %55 : None = prim::Constant()
-// %56 : None = prim::Constant()
-// %57 : None = prim::Constant()
-// %58 : Tensor = aten::arange(%53, %54, %55, %56, %57)
-// %59 : int = prim::Constant[value=0]()
-// %60 : Tensor = aten::slice(%58, %59, %13, %10, %11)
-// %61 : int[] = prim::Constant[value=[-1]]()
-// %62 : Tensor = aten::view(%60, %61)
-// %63 : Tensor?[] = prim::ListConstruct(%62)
-// %64 : Tensor = aten::index_put(%state.1, %63, %36, %9)
-// -> (%64)
-// clang-format on
-void RegisterInplaceNodeInIfBlocks(
- Value* orig_data,
- Value* new_data,
- const std::string& output_name) {
- auto outer_block = new_data->node()->owningBlock();
- auto initial_block_node = outer_block->owningNode();
-
- if ((nullptr == initial_block_node) ||
- (initial_block_node->kind() != prim::If)) {
- return;
- }
-
- auto next_block_node = initial_block_node;
- new_data->setDebugName("_output_" + output_name);
- outer_block->registerOutput(new_data);
- // Block has a new output. Add the output for the prim::If node.
- if (next_block_node->outputs().size() < outer_block->outputs().size())
- next_block_node->addOutput()->copyMetadata(new_data);
-
- auto next_block = next_block_node->owningBlock();
- while (nullptr != next_block->owningNode() &&
- next_block != orig_data->node()->owningBlock()) {
- next_block->registerOutput(next_block_node->output(0));
- next_block_node = next_block->owningNode();
- // Block has a new output. Add the output for the prim::If node.
- if (next_block_node->outputs().size() < next_block->outputs().size())
- next_block_node->addOutput()->setType(new_data->type());
- next_block = next_block_node->owningBlock();
- }
- orig_data->replaceAllUsesAfterNodeWith(
- next_block_node,
- next_block_node->outputs().at(next_block_node->outputs().size() - 1));
-}
-
-// clang-format off
-// Register inplace op node inputs/outputs through the blocks.
-// Eg. The IR before updating:
-// = prim::Loop(%10, %27)
-// block0(%stream_idx.1 : int):
-// = prim::Loop(%9, %27)
-// block0(%i.1 : int):
-// %36 : Tensor = aten::select(%bias.1, %26, %stream_idx.1)
-// %41 : Tensor = aten::copy_(%37, %40, %25)
-// -> (%27)
-// -> (%27)
-// After updating:
-// %62 : Tensor = prim::Loop(%10, %27, %bias.2)
-// block0(%stream_idx.1 : int, %bias.3 : Tensor):
-// %61 : Tensor = prim::Loop(%9, %27, %bias.3)
-// block0(%i.1 : int, %bias.1 : Tensor):
-// %36 : Tensor = aten::select(%bias.1, %26, %stream_idx.1)
-// %59 : Tensor?[] = prim::ListConstruct(%55, %58)
-// %60 : Tensor = aten::index_put(%bias.1, %59, %45, %25)
-// -> (%27, %60)
-// -> (%27, %61)
-// clang-format on
-void RegisterInplaceNodeInLoopBlocks(Value* orig_data, Value* new_data) {
- Node* inplace_node = new_data->node();
- Block* outer_block = inplace_node->owningBlock();
- Node* outer_block_node = outer_block->owningNode();
-
- if (nullptr == outer_block_node) {
- return;
- }
-
- if (outer_block_node->kind() != prim::Loop)
- return;
-
- outer_block->registerOutput(new_data);
- std::vector<std::pair<Block*, Node*>> node_list = {
- std::make_pair(outer_block, outer_block_node)};
-
- outer_block_node->addOutput()->setType(new_data->type());
- auto next_block = outer_block_node->owningBlock();
- auto next_node = outer_block_node;
-
- while (nullptr != next_block->owningNode() &&
- next_block != orig_data->node()->owningBlock()) {
- outer_block = next_block;
- outer_block->registerOutput(
- next_node->outputs().at(next_node->outputs().size() - 1));
- next_node = outer_block->owningNode();
- next_node->addOutput()->setType(new_data->type());
- next_block = next_node->owningBlock();
- if (next_node->kind() == prim::Loop) // Do not register input if nested in
- // If block. Register in Loop blocks.
- node_list.emplace_back(std::make_pair(outer_block, next_node));
- }
-
- // Register inplace node inputs through the blocks.
- auto next_data = orig_data;
- while (!node_list.empty()) {
- auto cur_pair = node_list.back();
- // Add input to current node.
- cur_pair.second->addInput(next_data);
- // Add input to current block.
- auto cur_input = cur_pair.first->addInput();
- cur_input->setType(next_data->type());
- next_data = cur_input;
- node_list.pop_back();
- }
-
- // Update inplace node inputs inside the outer most block.
- outer_block_node = outer_block->owningNode();
- auto prev_data =
- outer_block_node->inputs().at(outer_block_node->inputs().size() - 1);
- for (auto node : inplace_node->owningBlock()->nodes()) {
- size_t idx = 0;
- for (auto inputs_ : node->inputs()) {
- if (inputs_ == prev_data) {
- node->replaceInput(idx, next_data);
- break;
- }
- idx++;
- }
- }
-
- orig_data->replaceAllUsesAfterNodeWith(
- next_node->outputs().at(0)->node(),
- next_node->outputs().at(next_node->outputs().size() - 1));
-}
-
-// Register inplace op node inputs/outputs through the blocks.
-void RegisterInplaceNodeInBlocks(Value* orig_data, Value* new_data) {
- Node* inplace_node = new_data->node();
- Block* outer_block = inplace_node->owningBlock();
- Node* outer_block_node = outer_block->owningNode();
-
- if (outer_block_node == nullptr)
- return;
-
- // Check if the value is already registered in the block
- bool registered = false;
- while (IsInplaceNode(orig_data->node())) {
- orig_data = orig_data->node()->inputs().at(0);
- }
- for (auto use : orig_data->uses()) {
- if ((use.user->owningBlock() == outer_block) &&
- (use.user->isAfter(inplace_node))) {
- size_t idx = 0;
- for (auto input_ : use.user->inputs()) {
- if (input_ == orig_data) {
- use.user->replaceInput(idx, new_data);
- registered = true;
- }
- idx++;
- }
- }
- }
- if (registered)
- return;
-
- // Register inplace node outputs through the blocks.
- RegisterInplaceNodeInLoopBlocks(orig_data, new_data);
-
- RegisterInplaceNodeInIfBlocks(orig_data, new_data, orig_data->debugName());
-
- while (nullptr != outer_block->owningNode() &&
- outer_block != orig_data->node()->owningBlock()) {
- MatchIfBlocksOutputForValue(orig_data, outer_block, new_data);
- outer_block = outer_block->owningNode()->owningBlock();
- }
-}
-
-void PrepareIndexPutForONNX(Node* node) {
+std::pair<Value*, Value*> PrepareIndexPutForONNX(Node* node) {
TORCH_INTERNAL_ASSERT(
node->kind() == aten::index_put || node->kind() == aten::index_put_);
auto placeholder_node = EncapsulatePatternIntoSubblock(node).value();
- if (node->kind() == aten::index_put_) {
- auto orig_data = placeholder_node->input();
- auto new_data = placeholder_node->output();
-
- if (nullptr == placeholder_node->owningBlock()->owningNode()) {
- orig_data->replaceAllUsesAfterNodeWith(placeholder_node, new_data);
- return;
- }
- RegisterInplaceNodeInBlocks(orig_data, new_data);
- }
+ node->destroy();
+ return std::make_pair(placeholder_node->input(0), placeholder_node->output());
}
-void PrepareCopyForONNX(Node* node) {
- if (node->kind() == aten::copy_) {
- // aten::copy_ can be viewed as a special case of index_put, where the
- // tensor indices input is empty.
- // Remove aten::copy_, and replace it with index_put.
- // 1. create an empty listConstruct node as indices input for index_put.
- // 2. create index_put node.
+std::pair<Value*, Value*> PrepareCopyForONNX(Node* node) {
+ TORCH_INTERNAL_ASSERT(node->kind() == aten::copy_);
+ // aten::copy_ can be viewed as a special case of index_put, where the
+ // tensor indices input is empty.
+ // Remove aten::copy_, and replace it with index_put.
+ // 1. create an empty listConstruct node as indices input for index_put.
+ // 2. create index_put node.
- // Tracing aten::copy_ broadcasts the rhs values.
- // 3. Apply broadcasting for scripting.
- WithInsertPoint guard(node);
- auto graph = node->owningGraph();
- auto dummy_list =
- graph->insertNode(graph->createList(OptionalType::ofTensor(), {}))
- ->output();
+ // Tracing aten::copy_ broadcasts the rhs values.
+ // 3. Apply broadcasting for scripting.
+ WithInsertPoint guard(node);
+ auto graph = node->owningGraph();
+ auto dummy_list =
+ graph->insertNode(graph->createList(OptionalType::ofTensor(), {}))
+ ->output();
- auto expanded_value =
- graph->insert(aten::expand_as, {node->input(1), node->input(0)});
- expanded_value->node()->setSourceRange(node->sourceRange());
- expanded_value->copyMetadata(node->input(1));
+ auto expanded_value =
+ graph->insert(aten::expand_as, {node->input(1), node->input(0)});
+ expanded_value->node()->setSourceRange(node->sourceRange());
+ expanded_value->copyMetadata(node->input(1));
- auto index_put = graph->insert(
- aten::index_put_,
- {node->input(0), dummy_list, expanded_value, node->input(2)});
- index_put->node()->setSourceRange(node->sourceRange());
- index_put->copyMetadata(node->output());
- node->output()->replaceAllUsesWith(index_put);
+ auto index_put = graph->insert(
+ aten::index_put_,
+ {node->input(0), dummy_list, expanded_value, node->input(2)});
+ index_put->node()->setSourceRange(node->sourceRange());
+ index_put->copyMetadata(node->output());
+ node->output()->replaceAllUsesWith(index_put);
- PrepareIndexPutForONNX(index_put->node());
- }
+ node->destroy();
+
+ return PrepareIndexPutForONNX(index_put->node());
}
-void PrepareInplaceOpsInBlocksForONNX(Node* node) {
+std::pair<Value*, Value*> PrepareInplaceOpsInBlocksForONNX(Node* node) {
if (!node->kind().is_aten())
- return;
+ return {};
auto name = node->schema().name();
bool inplace_op = name.at(name.size() - 1) == '_';
if (!inplace_op)
- return;
+ return {};
auto new_schema = name.substr(0, name.size() - 1);
Node* input_node = node->inputs().at(0)->node();
- if (input_node->kind() != aten::select && input_node->kind() != aten::slice)
- return;
auto graph = node->owningGraph();
auto new_node = graph->create(Symbol::fromQualString(new_schema), 1);
@@ -401,17 +209,26 @@
new_node->output()->setType(node->output()->type());
new_node->insertBefore(node);
new_node->setSourceRange(node->sourceRange());
+ node->replaceAllUsesWith(new_node);
+ node->destroy();
- auto false_val_ = graph->insertConstant(false);
+ if (input_node->kind() == aten::select || input_node->kind() == aten::slice) {
+ // Case of a[i] = x: convert to copy_ and eventually index_put_.
+ WithInsertPoint guard(new_node);
+ auto false_val_ = graph->insertConstant(false);
- auto new_copy = graph->create(aten::copy_, 1);
- new_copy->addInput(input_node->output());
- new_copy->addInput(new_node->output());
- new_copy->addInput(false_val_);
- new_copy->insertBefore(node);
- new_copy->setSourceRange(node->sourceRange());
+ auto new_copy = graph->create(aten::copy_, 1);
+ new_copy->addInput(new_node->inputs().at(0));
+ new_copy->addInput(new_node->output());
+ new_copy->addInput(false_val_);
+ new_copy->insertAfter(new_node);
+ new_copy->setSourceRange(new_node->sourceRange());
- PrepareCopyForONNX(new_copy);
+ return PrepareCopyForONNX(new_copy);
+ } else {
+ // Direct aliasing: the node is a standalone inplace op.
+ return std::make_pair(new_node->input(0), new_node->output());
+ }
}
// aten::pop is inplace. The tensor list input is updated.
@@ -419,54 +236,43 @@
// aten::pop. Then it makes the original aten::pop operator return the updated
// tensor list, and replaces all later uses of that tensor list with this new
// output.
-static void PrepareListPopForONNX(Node* n) {
- if (n->kind() == aten::pop) {
- // %ten : Tensor = aten::pop(%seq, %pos)
- // Convert to
- // %ten : Tensor = aten::__getitem__(%seq, %pos)
- // %new_seq : Tensor[] = aten::pop(%seq, %pos)
- // And replace all uses of %seq afterwards with %new_seq
- Node* getitem_node =
- n->owningGraph()->create(aten::__getitem__, {n->inputs()});
- getitem_node->output()->setType(n->output()->type());
- getitem_node->insertBefore(n);
- n->output()->replaceAllUsesWith(getitem_node->output());
- n->output()->setType(n->inputs().at(0)->type());
+static std::pair<Value*, Value*> PrepareListPopForONNX(Node* n) {
+ TORCH_INTERNAL_ASSERT(n->kind() == aten::pop);
+ // %ten : Tensor = aten::pop(%seq, %pos)
+ // Convert to
+ // %ten : Tensor = aten::__getitem__(%seq, %pos)
+ // %new_seq : Tensor[] = aten::pop(%seq, %pos)
+ // And replace all uses of %seq afterwards with %new_seq
+ Node* getitem_node =
+ n->owningGraph()->create(aten::__getitem__, {n->inputs()});
+ getitem_node->output()->setType(n->output()->type());
+ getitem_node->insertBefore(n);
+ n->output()->replaceAllUsesWith(getitem_node->output());
+ n->output()->setType(n->inputs().at(0)->type());
- if (nullptr == n->owningBlock()->owningNode()) {
- n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
- return;
- }
- RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
- }
+ return std::make_pair(n->input(0), n->output());
}
-static void PrepareListDeleteForONNX(Node* n) {
- if (n->kind() == aten::Delete) {
+static std::pair<Value*, Value*> PrepareListDeleteForONNX(Node* n) {
+ TORCH_INTERNAL_ASSERT(n->kind() == aten::Delete);
+ n->addOutput();
+ n->output()->setType(n->inputs().at(0)->type());
+
+ return std::make_pair(n->input(0), n->output());
+}
+
+static std::pair<Value*, Value*> PrepareListAppendAndInsertForONNX(Node* n) {
+ TORCH_INTERNAL_ASSERT(n->kind() == aten::insert || n->kind() == aten::append);
+ if (n->outputs().size() == 0) {
n->addOutput();
n->output()->setType(n->inputs().at(0)->type());
-
- if (nullptr == n->owningBlock()->owningNode()) {
- n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
- return;
- }
- RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
}
+ return std::make_pair(n->input(0), n->output());
}
-static void PrepareListAppendAndInsertForONNX(Node* n) {
- if (n->kind() == aten::insert || n->kind() == aten::append) {
- if (n->outputs().size() == 0) {
- n->addOutput();
- n->output()->setType(n->inputs().at(0)->type());
- }
-
- if (nullptr == n->owningBlock()->owningNode()) {
- n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
- return;
- }
- RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
- }
+static std::pair<Value*, Value*> PrepareListSetItemForONNX(Node* n) {
+ TORCH_INTERNAL_ASSERT(n->kind() == aten::_set_item);
+ return std::make_pair(n->input(0), n->output());
}
// Remove Mutation pass does not handle mutation on block inputs.
@@ -567,256 +373,466 @@
name);
}
-Value* registerSetAttrInBlocks(
- const std::shared_ptr<Graph>& graph,
- Block* block,
- Node* cloneNode,
- Value* origValue,
- const std::string& output_name) {
- RegisterInplaceNodeInLoopBlocks(origValue, cloneNode->output());
- RegisterInplaceNodeInIfBlocks(origValue, cloneNode->output(), output_name);
-
- Value* output = nullptr;
- while (nullptr != block->owningNode() &&
- block != origValue->node()->owningBlock()) {
- output = MatchIfBlocksOutputForValue(origValue, block, cloneNode->output());
- block = block->owningNode()->owningBlock();
- }
- return output;
+void InplaceConverter::ValueTracker::init(const std::shared_ptr<Graph>& graph) {
+ alias_to_value_ = {};
+ value_to_sorted_aliases_ = {};
+ graph_ = graph;
}
-// clang-format off
-// The trackAndRegisterAttributesInBlocks function tracks any instances
-// of getAttr and setAttr in a sub-block and capture these nodes as inpalce
-// read/write ops. This pass captures the output of setAttr in sub-block outputs
-// so that it gets reflected into the outer block.
-// Also, the pass matched the number of If sub-block outputs
-// if a value is updated in one branch, but no updated on the other branch.
-// For example:
-//= prim::If(%12)
-// block0():
-// %13 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-// %b.1 : Tensor? = prim::GetAttr[name="bias"](%13)
-// ...
-// %18 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-// %19 : Tensor = aten::add(%anchors.1, %b, %6)
-// = prim::SetAttr[name="bias"](%18, %19)
-// -> ()
-// block1():
-// %20 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-// %21 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-// %22 : Tensor = prim::GetAttr[name="weight"](%21)
-// %23 : Tensor = aten::slice(%22, %7, %7, %8, %6)
-// = prim::SetAttr[name="bias"](%20, %23)
-// -> ()
-// After the pass
-//%_output_conv.bias.3 : Tensor = prim::If(%12)
-// block0():
-// ...
-// %18 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-// %19 : Tensor = aten::add(%anchors.1, %b, %6)
-// %_output_conv.bias.2 : Tensor = aten::clone(%19, %26)
-// -> (%_output_conv.bias.2)
-// block1():
-// %20 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-// %23 : Tensor = aten::slice(%conv.weight, %7, %7, %8, %6)
-// %31 : None = prim::Constant()
-// %_output_conv.bias.4 : Tensor = aten::clone(%23, %31)
-// -> (%_output_conv.bias.4)
-// clang-format on
-void trackAndRegisterAttributesInBlocks(
- Node* n,
- const std::shared_ptr<Graph>& graph,
- const Module& module_,
- std::unordered_map<std::string, Value*>& allAttrValues,
- std::unordered_map<std::string, Value*>& setAttrValues,
- std::unordered_map<std::string, Value*>& nextSetAttrValues) {
- if (n->kind() != prim::GetAttr && n->kind() != prim::SetAttr)
- return;
+std::string InplaceConverter::ValueTracker::toString() const {
+ std::stringstream ss;
- auto name = n->s(attr::name);
- auto attrModule = module_;
- Value* paramConst = nullptr;
-
- auto moduleNames =
- findSubModuleAttr(n->inputs().at(0), name, attrModule, graph);
-
- std::string fullName("");
- for (auto& name : moduleNames) {
- fullName += name + '.';
+ // ss << "Current graph: " << graph_->toString() << std::endl;
+ ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values."
+ << std::endl;
+ ss << "value_to_sorted_aliases_: " << std::endl;
+ size_t idx = 0;
+ for (const auto& it : value_to_sorted_aliases_) {
+ ss << "Value[" << idx << "]: " << it.first->debugName() << std::endl;
+ ss << " Mapping to ";
+ for (auto v : it.second) {
+ ss << v->debugName() << " ";
+ }
+ ss << std::endl;
+ idx++;
}
- fullName += name;
- if (allAttrValues.find(fullName) == allAttrValues.end() &&
- attrModule.hasattr(name)) {
- auto attr = attrModule.attr(name);
- auto type = attrModule.type();
- auto slot = *type->findAttributeSlot(name);
+ ss << "alias_to_value_: " << std::endl;
+ for (auto it : alias_to_value_) {
+ ss << " Alias " << it.first->debugName();
+ ss << " map to " << it.second->debugName() << std::endl;
+ }
- // Add model_parameters and model_buffers as model inputs. Order is
- // preserved based on the appearance in the graph.
- if (type->is_parameter(slot) || type->is_buffer(slot) ||
- (attr.isObject() && !attr.toObjectRef().type()->is_module())) {
- if (allAttrValues.find(fullName) == allAttrValues.end()) {
- paramConst = findArgumentAsInputParam(graph, fullName, attr);
- allAttrValues.insert({fullName, paramConst});
- }
- } else if (auto attrVal = tryInsertConstant(*graph, attr)) {
- for (size_t i = 0; i < type->getAttributes().size(); i++) {
- if (type->getAttributeName(i) == name) {
- paramConst = *attrVal;
- allAttrValues.insert({fullName, paramConst});
+ return ss.str();
+}
+
+void InplaceConverter::ValueTracker::recordSetValue(
+ Value* old_v,
+ Value* new_v) {
+ GRAPH_UPDATE(
+ "Calling recordSetValue with old_v: ",
+ old_v->debugName(),
+ " new_v: ",
+ new_v->debugName());
+ GRAPH_UPDATE(this->toString());
+ auto* n = new_v->node();
+ auto* owning_block = n->owningBlock();
+
+ if (alias_to_value_.find(old_v) == alias_to_value_.end()) {
+ alias_to_value_[old_v] = old_v;
+ value_to_sorted_aliases_[old_v] = {old_v};
+ }
+
+ auto root_v = alias_to_value_[old_v];
+ alias_to_value_[new_v] = root_v;
+ auto& sorted_alias = value_to_sorted_aliases_[root_v];
+ sorted_alias.insert(new_v);
+
+ // Check whether new_v is created inside an If or Loop subblock.
+ auto* owning_blocknode = owning_block->owningNode();
+ if (nullptr == owning_blocknode) {
+ return;
+ }
+ auto owning_block_nkind = owning_blocknode->kind();
+ if (owning_block_nkind != prim::Loop && owning_block_nkind != prim::If) {
+ return;
+ }
+
+ bool registered = std::any_of(
+ owning_block->outputs().begin(),
+ owning_block->outputs().end(),
+ [&sorted_alias](Value* out) {
+ return std::any_of(
+ sorted_alias.begin(), sorted_alias.end(), [&out](Value* alias) {
+ return alias == out;
+ });
+ });
+
+ bool from_outer_alias = std::any_of(
+ sorted_alias.begin(),
+ sorted_alias.end(),
+ [&owning_blocknode](Value* alias) {
+ return isAncestor(
+ alias->node()->owningBlock(), owning_blocknode->owningBlock());
+ });
+
+ // The data of this value has been changed.
+ // If this value has an alias from an outer block,
+ // then the update must be reflected back outside,
+ // so it needs to be registered as a subblock output.
+ // This step can be skipped if another alias of this value has already been
+ // registered as a subblock output.
+ if (!registered && from_outer_alias) {
+ if (owning_block_nkind == prim::Loop) {
+ owning_block->registerOutput(new_v);
+ auto new_block_in = owning_block->addInput();
+ new_block_in->setType(new_v->type());
+ sorted_alias.insert(new_block_in);
+ alias_to_value_[new_block_in] = root_v;
+ owning_blocknode->addInput(root_v);
+ } else if (owning_block_nkind == prim::If) {
+ for (auto* if_sub_block : owning_blocknode->blocks()) {
+ if (owning_block == if_sub_block) {
+ if_sub_block->registerOutput(new_v);
+ } else {
+ if_sub_block->registerOutput(root_v);
}
}
- } else {
- GRAPH_DEBUG(
- attr.type()->cast<ClassType>() ? "" : "attribute: ",
- name,
- " is not materializable.");
- return;
}
+ auto* new_blocknode_out = owning_blocknode->addOutput();
+ new_blocknode_out->setType(new_v->type());
+ recordSetValue(root_v, new_blocknode_out);
}
- if (n->kind() == prim::SetAttr) { // Handle SetAttr node
- if (attrModule.hasattr(name)) {
- // If inside a block, keep the output value to register in block
- // output.
- auto block_ = n->owningBlock();
- Node* cloneNode =
- addDummyClone(block_->owningGraph(), n->inputs().at(1), true, n);
- if (block_->owningNode() &&
- (block_->owningNode()->kind() == prim::If ||
- block_->owningNode()->kind() == prim::Loop)) {
- auto attrValue = (setAttrValues.find(fullName) != setAttrValues.end())
- ? setAttrValues[fullName]
- : allAttrValues[fullName];
-
- auto blockOutput = registerSetAttrInBlocks(
- graph, block_, cloneNode, attrValue, fullName);
-
- nextSetAttrValues[fullName] = blockOutput;
- }
- // SetAttr writes a value to an attr. Keep this
- // in the setAttrValues map.
- setAttrValues[fullName] = cloneNode->output();
- }
- } else if (n->kind() == prim::GetAttr) { // Handle GetAttr node
- if (setAttrValues.find(fullName) != setAttrValues.end()) {
- // Attr has been set earlier in the graph.
- // Read its value from setAttrValues map.
- auto set_attr_node_input = setAttrValues[fullName];
- // Clone SetAttr input
- n->output()->replaceAllUsesAfterNodeWith(n, set_attr_node_input);
- } else if (allAttrValues.find(fullName) != allAttrValues.end()) {
- // Attr has not been set earlier in the graph. Replace it with the
- // graph parameter if exists.
- n->output()->replaceAllUsesWith(allAttrValues[fullName]);
- n->removeAllInputs();
- }
- }
+ GRAPH_UPDATE(
+ "After recordSetValue for in: ",
+ old_v->debugName(),
+ ", out: ",
+ new_v->debugName(),
+ ". tracker status:");
+ GRAPH_UPDATE(this->toString());
}
-// clang-format off
-// The registerInplaceOpAsBlockOutputs function tracks inplace op
-// (like aten::copy_ or aten::append) outputs as sub-block output.
-// Also, match the number of If sub-block outputs
-// if a value is updated in one branch, but no updated on the other branch.
-// For example:
-// = prim::If(%30)
-// block0():
-// ...
-// %35 : Tensor = aten::copy_(%state_copy.1, %33, %12)
-// -> ()
-// block1():
-// ...
-// %40 : Tensor = aten::copy_(%state.1, %38, %12)
-// -> ()
-//
-// After the pass
-//%_output_state_copy.1 : Tensor, %_output_state.1 : Tensor = prim::If(%30)
-// block0():
-// %_output_state.2 : Tensor = aten::clone(%state.1, %59)
-// ...
-// %_output_state_copy.3 : Tensor = onnx::Placeholder[name="index_put_"](%state_copy.1)...
-// ...
-// -> (%_output_state_copy.3, %_output_state.2)
-// block1():
-// %50 : None = prim::Constant()
-// %_output_state_copy.2 : Tensor = aten::clone(%state_copy.1, %50)
-// ...
-// %_output_state.3 : Tensor = onnx::Placeholder[name="index_put_"](%state.1)...
-// ...
-// -> (%_output_state_copy.2, %_output_state.3)
-std::unordered_map<std::string, Value*> registerInplaceOpAsBlockOutputs(
- Block* block,
- const std::shared_ptr<Graph>& graph,
- std::unordered_map<std::string, Value*>& allAttrValues,
- std::unordered_map<std::string, Value*>& setAttrValues,
- MutationRemover& mr,
- Module* module_ = nullptr) {
- Node* m = *block->nodes().begin();
- WithInsertPoint guard(m);
- std::unordered_map<std::string, Value*> nextSetAttrValues = {};
+// Based on the current alias records, pass over the graph and correct the
+// alias references for all nodes.
+void InplaceConverter::correctAliasReferences() {
+ correctAliasReferences(graph_->block());
+}
+void InplaceConverter::correctAliasReferences(Block* block) {
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* n = *it;
it++; // node n can be destroyed
- if (nullptr != module_ &&
- (n->kind() == prim::GetAttr || n->kind() == prim::SetAttr)) {
- Module moduleClone = (*module_);
- trackAndRegisterAttributesInBlocks(
- n,
- graph,
- moduleClone,
- allAttrValues,
- setAttrValues,
- nextSetAttrValues);
- } else if (n->kind() == aten::copy_) {
- PrepareCopyForONNX(n);
- } else if (n->kind() == aten::index_put || n->kind() == aten::index_put_) {
- PrepareIndexPutForONNX(n);
- } else if (mr.inplaceOpVariant(n)) {
- PrepareInplaceOpsInBlocksForONNX(n);
- } else if (n->kind() == aten::pop) {
- PrepareListPopForONNX(n);
- } else if (n->kind() == aten::insert || n->kind() == aten::append) {
- PrepareListAppendAndInsertForONNX(n);
- } else if (n->kind() == aten::Delete) {
- PrepareListDeleteForONNX(n);
- } else { // for prim::If and prim::Loop nodes with blocks.
- for (Block* sub_block : n->blocks()) {
- std::unordered_map<std::string, Value*> map_ =
- registerInplaceOpAsBlockOutputs(
- sub_block, graph, allAttrValues, setAttrValues, mr, module_);
- std::unordered_map<std::string, Value*>::iterator mapIt;
- for (mapIt = map_.begin(); mapIt != map_.end(); mapIt++) {
- setAttrValues[mapIt->first] = mapIt->second;
- }
+ correctAliasReferences(n);
+
+ auto nkind = n->kind();
+ if (nkind == prim::If || nkind == prim::Loop) {
+ for (auto* sub_block : n->blocks()) {
+ correctAliasReferences(sub_block);
}
}
}
- return nextSetAttrValues;
+ correctAliasReferences(block->return_node());
}
-// Register Inplace Ops As Block Outputs
-// Inplace operations like aten::copy_ or aten::append that are inside
-// sub-blocks would require the output of the operation to be captured
-// as sub-block output, so that the inplace operation would be visible
-// to the outer block.
-// We also consider setAttr node an inplace op, and handle those
-// similarly by tracking the output as sub-block outputs.
-void RegisterInplaceOpAsBlockOutputs(
- Module* module,
- const std::shared_ptr<Graph>& graph,
- MutationRemover& mr) {
- // A map of names and values of referenced attributes, to avoid duplicates.
- std::unordered_map<std::string, Value*> allAttrValues = {};
- // A map of names and values of set attributes, to track mutations.
- std::unordered_map<std::string, Value*> setAttrValues = {};
+// For every input of Node n, find the correct alias representing that input.
+void InplaceConverter::correctAliasReferences(Node* n) {
+ for (size_t i = 0; i < n->inputs().size(); ++i) {
+ auto* in = n->input(i);
+ auto* alias = vt_.findAliasForValueAtNode(in, n);
- registerInplaceOpAsBlockOutputs(
- graph->block(), graph, allAttrValues, setAttrValues, mr, module);
+ if (alias != in) {
+ n->replaceInput(i, alias);
+ GRAPH_UPDATE(
+ "Replacing ",
+ in->debugName(),
+ " with ",
+ alias->debugName(),
+ " for ",
+ *n);
+ }
+ }
+}
+
+// Find the correct alias representing Value v at Node n.
+Value* InplaceConverter::ValueTracker::findAliasForValueAtNode(
+ Value* v,
+ const Node* n) const {
+ GRAPH_UPDATE("Finding alias for value:", v->debugName(), " at node ", *n);
+ if (alias_to_value_.find(v) == alias_to_value_.end()) {
+ // This value was not affected by any inplace operator.
+ return v;
+ }
+
+ auto* root_v = alias_to_value_.find(v)->second;
+ TORCH_INTERNAL_ASSERT(
+ value_to_sorted_aliases_.find(root_v) != value_to_sorted_aliases_.end());
+ const auto& aliases = value_to_sorted_aliases_.find(root_v)->second;
+
+ // An alias is accessible only if
+ // 1. the alias's owning block is an ancestor of n, and
+ // 2. the alias's owning node is before n.
+ // Return the last alias that satisfies these conditions.
+ Value* found_alias = nullptr;
+ for (auto* alias : aliases) {
+ auto* alias_n = alias->node();
+ if (alias_n->isBefore(n) &&
+ isAncestor(alias_n->owningBlock(), n->owningBlock())) {
+ found_alias = alias;
+ }
+ }
+
+ TORCH_INTERNAL_ASSERT(
+ nullptr != found_alias,
+ "More details: \n",
+ n->sourceRange().str(),
+ "Input ",
+ v->debugName(),
+ " of node ",
+ *n,
+ " was modified by in-place operation, but we cannot find its updated value. ",
+ "Please report a bug to PyTorch, and/or try to avoid using in-place operators on this value.");
+
+ return found_alias;
+}
+
+// Pass over the block and gather the initial value of every attribute.
+// Also cache the full attribute name for every GetAttr/SetAttr node.
+void InplaceConverter::gatherAttrNameInitialValueMap(
+ Block* block,
+ std::unordered_map<std::string, Value*>& attr_name_value_map,
+ std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
+ for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+ Node* n = *it;
+ it++; // node n can be destroyed
+
+ for (auto* sub_block : n->blocks()) {
+ gatherAttrNameInitialValueMap(
+ sub_block, attr_name_value_map, attr_node_fullname_map);
+ }
+
+ if (n->kind() != prim::GetAttr && n->kind() != prim::SetAttr)
+ continue;
+
+ auto name = n->s(attr::name);
+ auto attrModule = *module_;
+ Value* paramConst = nullptr;
+
+ auto moduleNames =
+ findSubModuleAttr(n->inputs().at(0), name, attrModule, graph_);
+
+ std::string fullName("");
+ for (auto& name : moduleNames) {
+ fullName += name + '.';
+ }
+ fullName += name;
+
+ attr_node_fullname_map.insert({n, fullName});
+
+ if (attr_name_value_map.find(fullName) == attr_name_value_map.end() &&
+ attrModule.hasattr(name)) {
+ auto attr = attrModule.attr(name);
+ auto type = attrModule.type();
+ auto slot = *type->findAttributeSlot(name);
+
+ // Add model_parameters and model_buffers as model inputs. Order is
+ // preserved based on the appearance in the graph.
+ WithInsertPoint guard(graph_->nodes().front());
+ if (type->is_parameter(slot) || type->is_buffer(slot) ||
+ (attr.isObject() && !attr.toObjectRef().type()->is_module())) {
+ paramConst = findArgumentAsInputParam(graph_, fullName, attr);
+ attr_name_value_map.insert({fullName, paramConst});
+ } else if (auto attrVal = tryInsertConstant(*graph_, attr)) {
+ // TODO: Extend support for attribute of type List[Tensor] etc.
+ for (size_t i = 0; i < type->getAttributes().size(); i++) {
+ if (type->getAttributeName(i) == name) {
+ paramConst = *attrVal;
+ attr_name_value_map.insert({fullName, paramConst});
+ }
+ }
+ } else {
+ // The attribute is a custom class object, rather than a primitive
+ // type, Tensor, or List/Tuple/Dict of Tensors.
+ GRAPH_DEBUG(
+ attr.type()->cast<ClassType>() ? "" : "attribute: ",
+ name,
+ " is not materializable.");
+ }
+ }
+
+ // Create a dummy initial value if one does not exist for this attribute.
+ if (attr_name_value_map.find(fullName) == attr_name_value_map.end()) {
+ auto* noneNode = graph_->create(prim::Constant);
+ noneNode->output()->setType(NoneType::get());
+ noneNode->insertBefore(graph_->nodes().front());
+ attr_name_value_map.insert({fullName, noneNode->output()});
+ }
+ }
+}
+
+// Replace prim::GetAttr and prim::SetAttr with ATen inplace operators.
+// Example graph:
+// clang-format off
+// Before graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
+// %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
+// ...
+// %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+// = prim::SetAttr[name="_bias"](%1, %10)
+// = prim::Loop(%5, %8)
+// block0(%i.1 : int):
+// %12 : bool = aten::eq(%i.1, %4)
+// = prim::If(%12)
+// block0():
+// = prim::Loop(%3, %8)
+// block0(%j : int):
+// %14 : Tensor = prim::GetAttr[name="_bias"](%1)
+// %15 : Tensor = aten::add_(%14, %2, %9)
+// = prim::SetAttr[name="_bias"](%1, %15)
+// -> (%8)
+// -> ()
+// block1():
+// %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+// = prim::SetAttr[name="_bias"](%1, %16)
+// -> ()
+// -> (%8)
+// %17 : Tensor = prim::GetAttr[name="_bias"](%1)
+// %18 : Tensor = aten::add(%17, %x.1, %9)
+// return (%18)
+//
+// After graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
+// %19 : Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 1 1 [ CPUFloatType{2} ]]()
+// %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
+// ...
+// %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+// %26 : bool = prim::Constant[value=0]()
+// %27 : Tensor?[] = prim::ListConstruct()
+// %28 : Tensor = aten::index_put_(%19, %27, %10, %26)
+// = prim::Loop(%5, %8)
+// block0(%i.1 : int):
+// %12 : bool = aten::eq(%i.1, %4)
+// = prim::If(%12)
+// block0():
+// = prim::Loop(%3, %8)
+// block0(%j : int):
+// %15 : Tensor = aten::add_(%19, %2, %9)
+// %23 : bool = prim::Constant[value=0]()
+// %24 : Tensor?[] = prim::ListConstruct()
+// %25 : Tensor = aten::index_put_(%19, %24, %15, %23)
+// -> (%8)
+// -> ()
+// block1():
+// %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+// %20 : bool = prim::Constant[value=0]()
+// %21 : Tensor?[] = prim::ListConstruct()
+// %22 : Tensor = aten::index_put_(%19, %21, %16, %20)
+// -> ()
+// -> (%8)
+// %18 : Tensor = aten::add(%19, %x.1, %9)
+// return (%18)
+// clang-format on
+void InplaceConverter::replaceAttrWithInplaceOps(
+ Block* block,
+ const std::unordered_map<std::string, Value*>& attr_name_value_map,
+ const std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
+ for (const auto& pair : attr_node_fullname_map) {
+ auto* n = pair.first;
+ auto fullName = pair.second;
+ auto find_init_val = attr_name_value_map.find(fullName);
+ TORCH_INTERNAL_ASSERT(find_init_val != attr_name_value_map.end());
+
+ TORCH_INTERNAL_ASSERT(
+ n->kind() == prim::GetAttr || n->kind() == prim::SetAttr);
+ if (n->kind() == prim::SetAttr) {
+ // Convert SetAttr to inplace op.
+ // Directly convert to index_put_ instead of copy_, since we know expand
+ // is not required for value.
+ WithInsertPoint guard(n);
+ auto false_val_ = graph_->insertConstant(false);
+ auto dummy_list =
+ graph_->insertNode(graph_->createList(OptionalType::ofTensor(), {}))
+ ->output();
+
+ auto* index_put_node = graph_->create(aten::index_put_, 1);
+ index_put_node->addInput(find_init_val->second);
+ index_put_node->addInput(dummy_list);
+ index_put_node->addInput(n->input(1));
+ index_put_node->addInput(false_val_);
+ index_put_node->setSourceRange(n->sourceRange());
+ index_put_node->insertBefore(n);
+ } else if (n->kind() == prim::GetAttr) {
+ // Replace uses of GetAttr with the first seen alias (usually the initial
+ // value) of that particular value. The correct alias at this node will be
+ // discovered and assigned in a later pass.
+ n->output()->replaceAllUsesWith(find_init_val->second);
+ }
+
+ n->destroy();
+ }
+}
+
+void InplaceConverter::convertGetSetAttrToInplaceOps(Block* block) {
+ std::unordered_map<std::string, Value*> attr_name_value_map = {};
+ std::unordered_map<Node*, std::string> attr_node_fullname_map = {};
+ // First pass over the graph: gather all attribute names and their initial
+ // values. Create dummy initial values for attributes if necessary. By the end
+ // of this pass, these dummy initial values should have zero uses and can be
+ // safely removed; otherwise, the model is erroneously using uninitialized
+ // values.
+ gatherAttrNameInitialValueMap(
+ block, attr_name_value_map, attr_node_fullname_map);
+ GRAPH_UPDATE("Graph after gatherAttrNameInitialValueMap", graph_->toString());
+
+ // Second pass over the graph:
+ // replace GetAttr with the first seen alias (usually the initial value),
+ // and replace SetAttr with an inplace op that writes the new value onto the
+ // first seen alias.
+ replaceAttrWithInplaceOps(block, attr_name_value_map, attr_node_fullname_map);
+}
+
+// Convert inplace ops to their out-of-place versions, and record the
+// associated new aliases in the ValueTracker.
+void InplaceConverter::convertInplaceOpsAndTrackAlias(Block* block) {
+ for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+ Node* n = *it;
+ it++; // node n can be destroyed
+
+ auto nkind = n->kind();
+ if (nkind == prim::If || nkind == prim::Loop) {
+ for (Block* sub_block : n->blocks()) {
+ convertInplaceOpsAndTrackAlias(sub_block);
+ }
+ } else {
+ Value *orig_data = nullptr, *new_out = nullptr;
+ if (nkind == aten::copy_) {
+ std::tie(orig_data, new_out) = PrepareCopyForONNX(n);
+ } else if (nkind == aten::index_put || nkind == aten::index_put_) {
+ std::tie(orig_data, new_out) = PrepareIndexPutForONNX(n);
+ if (nkind == aten::index_put) {
+ // Special case: index_put is not inplace, so no alias is recorded.
+ continue;
+ }
+ } else if (nkind == aten::insert || nkind == aten::append) {
+ std::tie(orig_data, new_out) = PrepareListAppendAndInsertForONNX(n);
+ } else if (mr_->inplaceOpVariant(n)) {
+ std::tie(orig_data, new_out) = PrepareInplaceOpsInBlocksForONNX(n);
+ } else if (nkind == aten::pop) {
+ std::tie(orig_data, new_out) = PrepareListPopForONNX(n);
+ } else if (nkind == aten::Delete) {
+ std::tie(orig_data, new_out) = PrepareListDeleteForONNX(n);
+ } else if (nkind == aten::_set_item) {
+ std::tie(orig_data, new_out) = PrepareListSetItemForONNX(n);
+ } else {
+ // Not inplace op.
+ continue;
+ }
+
+ if (nullptr != orig_data && nullptr != new_out) {
+ vt_.recordSetValue(orig_data, new_out);
+ }
+ }
+ }
+}
+
+void InplaceConverter::convertInplaceOpsAndTrackAlias() {
+ convertInplaceOpsAndTrackAlias(graph_->block());
+ GRAPH_UPDATE(
+ "Graph after convertInplaceOpsAndTrackAlias: ", graph_->toString());
+ GRAPH_UPDATE(vt_.toString());
+}
+
+void InplaceConverter::convertMutationForONNX() {
+ // First pass to convert all prim::GetAttr and prim::SetAttr to ATen inplace
+ // operators.
+ convertGetSetAttrToInplaceOps(graph_->block());
+ GRAPH_UPDATE("Graph after convertGetSetAttrToInplaceOps", graph_->toString());
+ vt_.init(graph_);
+ // Second pass to convert all inplace operators to their out-of-place
+ // versions, and record the associated new aliases in the ValueTracker.
+ convertInplaceOpsAndTrackAlias();
+ // Third pass to check and correct the alias references for all nodes.
+ correctAliasReferences();
}
} // namespace
@@ -829,7 +845,8 @@
PrepareForRemoveMutations(mr, graph->block());
RemoveTensorMutation(graph);
RemoveListMutation(graph);
- RegisterInplaceOpAsBlockOutputs(model, graph, mr);
+ InplaceConverter ic(graph, &mr, model);
+ ic.convertMutationForONNX();
}
} // namespace jit
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index 050b952..39282bdf 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -256,6 +256,9 @@
from torch.onnx.symbolic_opset9 import __getitem_ as getitem
return getitem(g, self, i)
+def _set_item(g, tensor_list, i, v):
+ tensor_list = g.op("SequenceErase", tensor_list, i)
+ return g.op("SequenceInsert", tensor_list, v, i)
def append(g, self, tensor):
return g.op("SequenceInsert", self, tensor)
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 08f7309..c41e484 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1115,6 +1115,10 @@
return wrap_with_not
+def __not_(g, self):
+ return g.op("Not", self)
+
+
def eq(g, self, other):
return g.op("Equal", self, other)
@@ -1463,10 +1467,21 @@
def index_put(g, self, indices_list_value, values, accumulate):
- if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
+ if sym_help._is_packed_list(indices_list_value):
indices_list = sym_help._unpack_list(indices_list_value)
+ else:
+ indices_list = [indices_list_value]
+ if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
args = [self] + indices_list + [values, accumulate]
return g.op("ATen", *args, operator_s='index_put')
+
+ accumulate = sym_help._parse_arg(accumulate, 'b')
+
+ if len(indices_list) == 0:
+ if accumulate:
+ return add(g, self, values)
+ else:
+ return values
else:
sym_help._onnx_opset_unsupported('index_put', 9, 11)