Improve ONNX Loop export (#20445)
Summary:
~~This is work in progress due to its dependency on multiple pending PRs.~~
- [x] ONNX: Relax constraint on subgraph input/output type & shape check. https://github.com/onnx/onnx/pull/2009
- [x] PyTorch: Add infra to test_pytorch_onnx_caffe2.py to test ScriptModule models. https://github.com/pytorch/pytorch/pull/20256
This PR should partially resolve https://github.com/pytorch/pytorch/issues/17531. However, ideally we shouldn't need to insert Cast (and Reshape) nodes to help the conversion of the loop condition.
- Added a Cast node for condition values before entering the Loop node. The ONNX spec only accepts the Bool type, while in PyTorch, if the condition value is an output of another node, it could potentially have any integral type.
- Tidied up the exported ONNX Loop subgraph input type & shape. According to the ONNX spec, input "M" is exported as a 0-d scalar tensor of type int64, and input "cond" is exported as an incomplete tensor of type Bool without shape information. This is because throughout the iterations, the rank of the condition value is dynamic — either 0-d or 1-d — as long as it holds a single value.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20445
Differential Revision: D15534188
Pulled By: houseroad
fbshipit-source-id: d174e778529def05ee666afeee4b8fb27786e320
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 3e1ba67..bce6b5e 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1747,6 +1747,94 @@
x = torch.randn(1, 2, 3)
self.run_model_test(DropoutModel(), train=False, input=x, batch_size=BATCH_SIZE)
+ def test_while(self):
+ class WhileModel(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ a = 0
+ while a < 4:
+ a += 1
+ return x + a
+
+ model = WhileModel()
+ inputs = torch.zeros(1, 2, 3, dtype=torch.long)
+ outputs = model(inputs)
+ self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
+ example_outputs=(outputs,))
+
+ def test_while_cond(self):
+ class WhileModel(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x, a):
+ b = (a < 4)
+ while b:
+ a += b.to(torch.long)
+ b = (a < 4)
+ return x + a
+
+ model = WhileModel()
+ x = torch.zeros(1, 2, 3, dtype=torch.long)
+ a = torch.tensor([0], dtype=torch.long)
+ outputs = model(x, a)
+ self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE,
+ example_outputs=(outputs,))
+
+ def test_loop(self):
+ class LoopModel(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ for i in range(5):
+ x = x + i
+ return x
+
+ model = LoopModel()
+ inputs = torch.zeros(1, 2, 3, dtype=torch.long)
+ outputs = model(inputs)
+ self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
+ example_outputs=(outputs,))
+
+ def test_dynamic_loop(self):
+ class LoopModel(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ for i in range(x.size(2)):
+ x = x + i
+ return x
+
+ model = LoopModel()
+ inputs = torch.zeros(1, 2, 3, dtype=torch.long)
+ outputs = model(inputs)
+ self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
+ example_outputs=(outputs,))
+
+ def test_nested_loops(self):
+ class NestedLoopsModel(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ for i in range(5):
+ a = 0
+ while a < 4:
+ a += 1
+ for j in range(a):
+ x = x + j
+ x = x + a
+ return x
+
+ model = NestedLoopsModel()
+ inputs = torch.zeros(1, 2, 3, dtype=torch.long)
+ outputs = model(inputs)
+ self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
+ example_outputs=(outputs,))
+
+ def test_select(self):
+ class SelectModel(torch.nn.Module):
+ def forward(self, x):
+ return torch.select(x, 0, 1)
+
+ model = SelectModel()
+ inputs = torch.randn(3, 2, 1)
+ self.run_model_test(model, train=False, input=(inputs, ), batch_size=BATCH_SIZE)
+
# a bit of metaprogramming to set up all the rnn tests
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index a77bd0c..d033079 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -181,6 +181,8 @@
return onnx::TensorProto_DataType_INT32;
case at::kLong:
return onnx::TensorProto_DataType_INT64;
+ case at::kBool:
+ return onnx::TensorProto_DataType_BOOL;
default:
AT_ERROR("unexpected tensor scalar type");
}
@@ -206,19 +208,20 @@
onnx::ValueInfoProto* v,
const Value* n) {
v->set_name(n->uniqueName());
- onnx::TypeProto* t = v->mutable_type();
- onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
-
- onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
+ onnx::TypeProto* t = v->mutable_type();
+ onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
+ onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
const std::vector<std::int64_t>& sizes = node_type->sizes();
for (size_t i = 0; i < sizes.size(); i++) {
shape->add_dim();
shape->mutable_dim(i)->set_dim_value(sizes[i]);
}
tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
- } else {
- tensor_type->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
+ } else if (BoolTypePtr node_type = n->type()->cast<BoolType>()) {
+ onnx::TypeProto* t = v->mutable_type();
+ onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
+ tensor_type->set_elem_type(ATenTypeToOnnxType(at::kBool));
}
}
diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
index b6594a4..5e6641c 100644
--- a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
+++ b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
@@ -3,12 +3,62 @@
namespace torch {
namespace jit {
+namespace onnx{
+using namespace ::c10::onnx;
+}
+
+Node* CreateCastToBoolNode(Value* val, Graph* graph) {
+ Node* cast_node = graph->create(onnx::Cast);
+ cast_node->addInput(val);
+ cast_node->i_(attr::to, /*Bool*/9);
+ return cast_node;
+}
+
+Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
+ // prev: cond_val -> consumer_node
+ // after: cond_val -> cast -> consumer_node
+ // NOTE: The cast is required because operators like PyTorch Greater/Less
+ // return tensor in type torch.uint8. However the type for condition
+ // input in ONNX Loop must be bool.
+ Node* cast_node = CreateCastToBoolNode(cond_val, graph);
+ cast_node->insertBefore(consumer_node);
+
+ consumer_node->replaceInputWith(cond_val, cast_node->output());
+ return cast_node;
+}
+
+bool IsCondCastRequired(Value* cond_val) {
+ const auto& type = cond_val->type();
+ if (type->isSubclass(TypeKind::DimensionedTensorType)) {
+ return type->expect<DimensionedTensorType>()->scalarType() != c10::kBool;
+ }
+ return !type->isSubclass(TypeKind::BoolType);
+}
+
void FixupONNXLoops(Block* block) {
for (auto* node : block->nodes()) {
if (node->kind() == ::c10::onnx::Loop) {
- AT_ASSERT(node->blocks().size() == 1);
- auto* sub_block = node->blocks()[0];
- sub_block->insertInput(1, "cond");
+ auto* loop_node = node;
+ auto* graph = loop_node->owningGraph();
+
+ // add cast to condition input outside the loop.
+ Value* cond_val = loop_node->inputs()[1];
+ if (IsCondCastRequired(cond_val))
+ InsertCastForCond(cond_val, graph, loop_node);
+
+ // Setup Loop input cond and i.
+ TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
+ auto* sub_block = loop_node->blocks()[0];
+ Value* cond = sub_block->insertInput(1, "cond");
+ cond->setType(BoolType::create());
+
+ Value* i = sub_block->inputs()[0];
+ i->setType(CompleteTensorType::fromNumberType(IntType::get()));
+
+ // add cast to condition input inside the loop.
+ Value* next_cond_val = sub_block->outputs()[0];
+ if (IsCondCastRequired(next_cond_val))
+ InsertCastForCond(next_cond_val, graph, sub_block->return_node());
}
for (Block* block : node->blocks()) {
FixupONNXLoops(block);