add topology index check in Graph::lint() (#13037)
Summary:
just a sanity check to make sure everything is in order
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13037
Differential Revision: D10854563
Pulled By: michaelsuo
fbshipit-source-id: 409303c4cbf058b75e24bf2213b49e9d79cb862e
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index b4e77a9..32f9bbe 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -389,6 +389,14 @@
n->lint();
}
void check_block(const Block *b) {
+ // Check topological ordering
+ JIT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
+ auto curNode = *b->nodes().begin();
+ while (curNode != b->return_node()) {
+ JIT_ASSERT(curNode->isBefore(curNode->next()));
+ curNode = curNode->next();
+ }
+
for (auto input : b->inputs()) {
check_value(input);
JIT_ASSERT(input->node()->kind_ == prim::Param);
@@ -835,14 +843,14 @@
return outputs_.at(i);
}
-bool Node::isBefore(Node * n) const {
+bool Node::isBefore(const Node * n) const {
if (this == n) {
return false;
}
return !isAfter(n);
}
-bool Node::isAfter(Node * n) const {
+bool Node::isAfter(const Node * n) const {
JIT_ASSERT(this->owningBlock() == n->owningBlock());
return this->topo_position_ > n->topo_position_;
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index d96ff17..8c02deb 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -422,10 +422,10 @@
}
// Is 'this' before 'n' in the topological order?
- TORCH_API bool isBefore(Node * n) const;
+ TORCH_API bool isBefore(const Node * n) const;
// Is 'this' after 'n' in the topological order?
- TORCH_API bool isAfter(Node * n) const;
+ TORCH_API bool isAfter(const Node * n) const;
// Insert unattached 'this' node before 'n' in the topological order.
// Returns this (for chaining).