[JIT] Retry - Support scripting torch.is_autocast_enabled() (#82394)
This adds `aten::is_autocast_enabled` and `aten::is_autocast_cpu_enabled` ops to the jit runtime so that
the autocast-enabled checks can be scripted and called from within jit.
Differential Revision: [D38294040](https://our.internmc.facebook.com/intern/diff/D38294040)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82394
Approved by: https://github.com/eellison
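A minimal usage sketch, mirroring the CPU test added below (the CUDA query `torch.is_autocast_enabled()` works the same way); the values and assertions are illustrative only:

```python
import torch

def fn(x):
    # Take a different branch depending on whether CPU autocast is active.
    if torch.is_autocast_cpu_enabled():
        return x.relu()
    else:
        return x.sin()

fn_s = torch.jit.script(fn)

x = torch.rand(4, 4) - 0.5
# Outside any autocast context the scripted and eager functions agree ...
assert torch.equal(fn_s(x), fn(x))
# ... and they also agree when eager CPU autocast is enabled.
with torch.cpu.amp.autocast(enabled=True):
    assert torch.equal(fn_s(x), fn(x))
```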
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index abcedbb..dc5860e 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -222,6 +222,8 @@
_(cuda, _current_device) \
_(cuda, synchronize) \
_(aten, has_torch_function) \
+ _(aten, is_autocast_enabled) \
+ _(aten, is_autocast_cpu_enabled) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py
index ae0a9cd..5d555e7 100644
--- a/test/test_jit_autocast.py
+++ b/test/test_jit_autocast.py
@@ -2,7 +2,7 @@
import torch
from torch.cuda.amp import autocast
-from typing import Optional
+from typing import Optional, Tuple
import unittest
from test_jit import JitTestCase
@@ -819,5 +819,101 @@
continue
test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
+ def test_script_autocast_cpu(self):
+ def fn(x):
+ if torch.is_autocast_cpu_enabled():
+ return x.relu()
+ else:
+ return x.sin()
+
+ fn_s = torch.jit.script(fn)
+
+ x = torch.rand((4, 4)) - 0.5
+ with torch.cpu.amp.autocast():
+ self.assertEqual(fn_s(x), fn(x))
+
+ with torch.cpu.amp.autocast(enabled=True):
+ self.assertEqual(fn_s(x), fn(x))
+
+        self.assertTrue(any("is_autocast_cpu_enabled" in n.kind() for n in fn_s.graph.nodes()))
+
+ @unittest.skipIf(not TEST_CUDA, "No cuda")
+ def test_script_autocast_cuda(self):
+ def fn(x):
+ if torch.is_autocast_enabled():
+ return x.relu()
+ else:
+ return x.sin()
+
+ fn_s = torch.jit.script(fn)
+
+ x = torch.rand((4, 4)) - 0.5
+        # CPU autocast should not flip torch.is_autocast_enabled(), which queries the CUDA autocast state.
+        with torch.cpu.amp.autocast():
+ self.assertEqual(fn_s(x), fn(x))
+
+ with torch.cuda.amp.autocast(enabled=True):
+ self.assertEqual(fn_s(x), fn(x))
+
+        self.assertTrue(any("is_autocast_enabled" in n.kind() for n in fn_s.graph.nodes()))
+
+
+ def test_scripted_aliasing(self):
+ # torch.is_autocast_enabled should not be able to move inside of the autocast context.
+ def fn(x):
+ if torch.is_autocast_enabled():
+ y = True
+ else:
+ y = False
+ with torch.cuda.amp.autocast(enabled=True):
+ z = x.relu()
+ return y, z
+
+ fn_s = torch.jit.script(fn)
+ graph = fn_s.graph
+
+ aliasdb = graph.alias_db()
+
+ is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
+ enter_nodes = graph.findAllNodes("prim::Enter")
+
+ self.assertEqual(len(is_enabled_nodes), 1)
+ self.assertEqual(len(enter_nodes), 1)
+
+ self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))
+
+
+ def test_script_autocast_enable_and_check(self):
+ def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
+ b1 = torch.is_autocast_cpu_enabled()
+ v1 = torch.mm(x, y)
+ with torch.cpu.amp.autocast(enabled=True):
+ b2 = torch.is_autocast_cpu_enabled()
+ v2 = torch.mm(x, y)
+ with torch.cpu.amp.autocast(enabled=False):
+ b3 = torch.is_autocast_cpu_enabled()
+ v3 = torch.mm(x, y)
+ return (v1, b1, v2, b2, v3, b3)
+
+ # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
+ def check_fn_results(arr):
+ [v1, b1, v2, b2, v3, b3] = arr
+ self.assertTrue((v1.dtype == torch.float) != b1)
+ self.assertTrue((v2.dtype == torch.float) != b2)
+ self.assertTrue((v3.dtype == torch.float) != b3)
+
+ x = torch.rand((2, 2), dtype=torch.float)
+ y = torch.rand((2, 2), dtype=torch.float)
+
+ fn_s = torch.jit.script(fn)
+
+ with torch.cpu.amp.autocast(enabled=False):
+ check_fn_results(fn(x, y))
+ check_fn_results(fn_s(x, y))
+
+ with torch.cpu.amp.autocast(enabled=True):
+ check_fn_results(fn(x, y))
+ check_fn_results(fn_s(x, y))
+
+
if __name__ == "__main__":
run_tests()
diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp
index 07c1651..4ac023b 100644
--- a/torch/csrc/jit/passes/autocast.cpp
+++ b/torch/csrc/jit/passes/autocast.cpp
@@ -216,6 +216,39 @@
}
}
+// Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
+// determine whether autocasting is enabled. With JIT-scripted functions, we
+// actually need to return true if either eager autocast or jit autocast is enabled.
+//
+// In the case where JIT autocast is enabled, we replace
+// %x : bool = aten::is_autocast_enabled()
+// with a constant "True".
+//
+// More context on eager vs JIT autocasting:
+//
+// Autocasting actually has two settings: eager autocasting, and JIT
+// autocasting. Eager autocasting is the thread-local setting that turns on
+// the relevant bit in the dispatcher settings. JIT autocasting is the pass
+// implemented in this file, which makes changes to the graph to insert casting
+// ops in order to achieve the same behavior as eager autocasting.
+//
+// If eager autocasting is enabled at the time when a JIT-scripted function is
+// invoked, then autocasting will occur regardless of what the JIT-autocasting
+// settings are.
+void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
+ if (!is_jit_enabled) {
+ return;
+ }
+
+ auto graph = node->owningGraph();
+
+ WithInsertPoint insert_point(node);
+
+ Value* true_constant = graph->insertConstant(IValue(true));
+ node->output()->replaceAllUsesWith(true_constant);
+ node->destroy();
+}
+
// [Note: implicit type promotion in Autocast]
//
// Casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp, with
@@ -319,6 +352,14 @@
}
break;
+ case aten::is_autocast_enabled:
+ updateAutocastEnabledCheck(node, current_state().gpu_enabled);
+ break;
+
+ case aten::is_autocast_cpu_enabled:
+ updateAutocastEnabledCheck(node, current_state().cpu_enabled);
+ break;
+
// CastPolicy::fp16 (cast all inputs to float16)
case aten::_convolution:
case aten::conv1d:
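For illustration, a minimal sketch of how the query surfaces in a scripted graph (based on the tests above; folding to a constant only happens once the autocast pass runs over a graph whose autocast scope is enabled):

```python
import torch

def fn(x):
    # Query the autocast state from inside a scripted autocast scope.
    with torch.cpu.amp.autocast(enabled=True):
        flag = torch.is_autocast_cpu_enabled()
        y = torch.mm(x, x)
    return flag, y

fn_s = torch.jit.script(fn)

# The query compiles to a dedicated aten op. When the autocast pass later
# processes this graph, updateAutocastEnabledCheck can replace the op's
# output with the constant True, since it sits inside an enabled scope.
assert len(fn_s.graph.findAllNodes("aten::is_autocast_cpu_enabled")) == 1
```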
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp
index 2fbe319..68690a2 100644
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -203,7 +203,17 @@
.def(
"has_writers",
[&](AliasDb& db, Value* v1) { return db.hasWriters(v1); })
- .def("__str__", &AliasDb::toString);
+ .def("__str__", &AliasDb::toString)
+ .def(
+ "move_after_topologically_valid",
+ [](AliasDb& db, Node* n, Node* movePoint) {
+ return db.moveAfterTopologicallyValid(n, movePoint);
+ })
+ .def(
+ "move_before_topologically_valid",
+ [](AliasDb& db, Node* n, Node* movePoint) {
+ return db.moveBeforeTopologicallyValid(n, movePoint);
+ });
#define GS(name) def(#name, &Graph ::name)
py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
.def(py::init<>())
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 7ca9157..d360a78 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1,3 +1,4 @@
+#include <ATen/autocast_mode.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
@@ -661,6 +662,28 @@
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
+ TORCH_SELECTIVE_SCHEMA("aten::is_autocast_enabled() -> bool"),
+ [](Stack& stack) {
+#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
+ bool enabled = false;
+#else
+ bool enabled = at::autocast::is_enabled();
+#endif
+ push(stack, enabled);
+ },
+ aliasAnalysisConservative()),
+ OperatorGeneratorArgs(
+ TORCH_SELECTIVE_SCHEMA("aten::is_autocast_cpu_enabled() -> bool"),
+ [](Stack& stack) {
+#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
+ bool enabled = false;
+#else
+ bool enabled = at::autocast::is_cpu_enabled();
+#endif
+ push(stack, enabled);
+ },
+ aliasAnalysisConservative()),
+ OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
unInitialized,
aliasAnalysisSpecialCase()),
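Outside of any in-graph autocast scope, these registered ops simply report the eager thread-local autocast state (and always report `false` on lite-interpreter/mobile builds). A small parity sketch, mirroring `test_script_autocast_enable_and_check`:

```python
import torch

@torch.jit.script
def autocast_cpu_state() -> bool:
    return torch.is_autocast_cpu_enabled()

# With no autocast scope in the scripted graph, the op queries the eager
# thread-local state, so the scripted and eager results should agree.
assert autocast_cpu_state() == torch.is_autocast_cpu_enabled()

with torch.cpu.amp.autocast(enabled=True):
    assert autocast_cpu_state() is True
    assert torch.is_autocast_cpu_enabled() is True
```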