centralize side effects ops as node method (#15188)
Summary:
A number of different passes rely on whether a node has side effects. This centralizes the list of side effectful ops in one place.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15188
Differential Revision: D13508438
Pulled By: eellison
fbshipit-source-id: 2143e782b787731ce007b6dcd50cbde30e1b8dd0
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 650d5ef..df038e3 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -686,6 +686,17 @@
return true;
}
+bool Node::hasSideEffects() const {
+ switch (kind_) {
+ case prim::PythonOp:
+ case prim::Print:
+ case prim::RaiseException:
+ case aten::warn:
+ return true;
+ }
+ return false;
+}
+
// Assign this node a topological position, to facilitate fast isBefore() and
// isAfter() queries. Must be called right after a node is inserted into the
// node list.
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 2a6c9cf..71a5361 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -353,6 +353,7 @@
}
TORCH_API bool isNondeterministic() const;
+ TORCH_API bool hasSideEffects () const;
// Graphs
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
index b96cfaa..cac8f6b 100644
--- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -23,8 +23,7 @@
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
auto node = *it;
- if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
- node->kind() == aten::warn || node->isNondeterministic() ||
+ if (node->hasSideEffects() || node->isNondeterministic() ||
aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
// Do NOT have enough information to do CSE on these nodes.
continue;
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index a1a6c1a..2446759 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -16,10 +16,6 @@
std::unordered_set<Symbol> skip_list = {
prim::If,
prim::Loop, //TODO: handle Loop
- prim::Print,
- prim::RaiseException,
- aten::warn,
- prim::PythonOp, //may have side effects
prim::Constant,
prim::Undefined,
prim::NoneGenerator,
@@ -125,7 +121,7 @@
return v->node()->kind() == prim::Constant;
});
bool supported_node = !n->kind().is_onnx() &&
- skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
+ skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && !n->hasSideEffects() &&
!aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
auto run_blocks = [&]() {
if (recurse) {
diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp
index 0167d03..b7d606c 100644
--- a/torch/csrc/jit/passes/dead_code_elimination.cpp
+++ b/torch/csrc/jit/passes/dead_code_elimination.cpp
@@ -245,9 +245,7 @@
auto it = memo_.find(node);
if (it != memo_.end())
return it->second;
- bool has_side_effects = node->kind() == prim::Print ||
- node->kind() == aten::warn || node->kind() == prim::RaiseException ||
- node->kind() == prim::PythonOp ||
+ bool has_side_effects = node->hasSideEffects() ||
std::any_of(node->blocks().begin(),
node->blocks().end(),
[&](Block* b) {
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 2d9677a..85465d5 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -434,10 +434,6 @@
}
return;
}
- case prim::PythonOp:
- case prim::Print:
- case prim::RaiseException:
- case aten::warn:
case prim::Undefined: {
setUnshapedType(node);
return;
@@ -445,6 +441,11 @@
default:
break; // fall-through
}
+
+ if (node->hasSideEffects()) {
+ return;
+ }
+
if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")
|| node->kind() == prim::FusedConcat) {
return PropagateCatShape(node);