Use AT_WARN for warnings in the JIT (#14770)
Summary:
Previously their implementation dispatched to prim::Print, which kept
printing the warnings.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14770
Differential Revision: D13327629
Pulled By: suo
fbshipit-source-id: b9913f533d4530eb7c29146c39981ba7f72b6b68
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index ec0e045..78ad29c 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -73,6 +73,7 @@
_(prim, NoneGenerator) \
_(prim, MMTreeReduce) \
_(prim, MMBatchSide) \
+ _(aten, warn) \
_(aten, floordiv) \
_(aten, __round_to_zero_floordiv)\
_(prim, fork) \
diff --git a/test/expect/TestJit.test_warnings.expect b/test/expect/TestJit.test_warnings.expect
index 60fd27b..c4ab59d 100644
--- a/test/expect/TestJit.test_warnings.expect
+++ b/test/expect/TestJit.test_warnings.expect
@@ -5,7 +5,7 @@
%4 : bool = prim::TensorToBool(%3)
= prim::If(%4)
block0() {
- = prim::Print(%1)
+ = aten::warn(%1, %2)
-> ()
}
block1() {
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
index 72b1501..fe02712 100644
--- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -23,9 +23,9 @@
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
auto node = *it;
- if (node->isNondeterministic() || node->kind() == prim::PythonOp ||
- node->kind() == prim::Print || aliasDb.hasWriters(node) ||
- aliasDb.hasWildcard(node)) {
+ if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
+ node->kind() == aten::warn || node->isNondeterministic() ||
+ aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
// Do NOT have enough information to do CSE on these nodes.
continue;
}
diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp
index bd5d5a3..af6dfef 100644
--- a/torch/csrc/jit/passes/dead_code_elimination.cpp
+++ b/torch/csrc/jit/passes/dead_code_elimination.cpp
@@ -181,6 +181,7 @@
if (it != memo_.end())
return it->second;
bool has_side_effects = node->kind() == prim::Print ||
+ node->kind() == aten::warn ||
node->kind() == prim::RaiseException ||
node->kind() == prim::PythonOp ||
std::any_of(node->blocks().begin(),
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 7ed54c0..301ef80 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -330,6 +330,16 @@
};
}),
Operator(
+ FunctionSchema("aten::warn", {Argument("message", StringType::get()), Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)}, {}),
+ [](const Node* node) {
+ return [](Stack& stack) {
+ drop(stack, 1);
+ AT_WARN(pop(stack).toStringRef());
+ return 0;
+ };
+ }),
+
+ Operator(
"prim::RaiseException(str msg) -> ()",
[](const Node* node) -> Operation {
return [](Stack& stack) {
diff --git a/torch/csrc/jit/script/builtin_functions.cpp b/torch/csrc/jit/script/builtin_functions.cpp
index ca42e68..a6b35ee 100644
--- a/torch/csrc/jit/script/builtin_functions.cpp
+++ b/torch/csrc/jit/script/builtin_functions.cpp
@@ -28,16 +28,6 @@
return torch.reciprocal(b) * a
)SCRIPT");
-auto python_builtins_source = R"SCRIPT(
-def warn(string: str):
- print(string)
-)SCRIPT";
-
-auto python_builtins_source_overloads = R"SCRIPT(
-def warn(string: str, stacklevel: int):
- print(string)
-)SCRIPT";
-
auto _ntuple_ops = CodeTemplate(
R"SCRIPT(
def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
@@ -84,8 +74,6 @@
env.s("Scalar", scalar);
loadSource(scalar_operators_source.format(env));
}
- loadSource(python_builtins_source);
- loadSource(python_builtins_source_overloads);
using str_pair = std::pair<std::string, std::string>;
const std::vector<str_pair> name_len = {