[JIT] Retry - Support scripting torch.is_autocast_enabled() (#82394)

This adds `aten::is_autocast_enabled` and `aten::is_autocast_cpu_enabled` ops
to the JIT runtime so that autocast state checks can be scripted and called
from within JIT code.
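
For illustration, a minimal sketch of the newly scriptable pattern (mirroring the tests below; assumes a CUDA build, and `torch.is_autocast_cpu_enabled()` works analogously with `torch.cpu.amp.autocast`):

```python
import torch

def fn(x):
    # Branch on the autocast state; scripting this previously failed
    # because the op had no JIT registration.
    if torch.is_autocast_enabled():
        return x.relu()
    else:
        return x.sin()

fn_s = torch.jit.script(fn)

x = torch.rand(4, 4)
with torch.cuda.amp.autocast(enabled=True):
    assert torch.equal(fn_s(x), fn(x))  # scripted matches eager
```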

Differential Revision: [D38294040](https://our.internmc.facebook.com/intern/diff/D38294040)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82394
Approved by: https://github.com/eellison
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index abcedbb..dc5860e 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -222,6 +222,8 @@
   _(cuda, _current_device)           \
   _(cuda, synchronize)               \
   _(aten, has_torch_function)        \
+  _(aten, is_autocast_enabled)       \
+  _(aten, is_autocast_cpu_enabled)   \
   FORALL_ATEN_BASE_SYMBOLS(_)        \
   _(onnx, Add)                       \
   _(onnx, Concat)                    \
diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py
index ae0a9cd..5d555e7 100644
--- a/test/test_jit_autocast.py
+++ b/test/test_jit_autocast.py
@@ -2,7 +2,7 @@
 
 import torch
 from torch.cuda.amp import autocast
-from typing import Optional
+from typing import Optional, Tuple
 
 import unittest
 from test_jit import JitTestCase
@@ -819,5 +819,101 @@
                 continue
             test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
 
+    def test_script_autocast_cpu(self):
+        def fn(x):
+            if torch.is_autocast_cpu_enabled():
+                return x.relu()
+            else:
+                return x.sin()
+
+        fn_s = torch.jit.script(fn)
+
+        x = torch.rand((4, 4)) - 0.5
+        with torch.cpu.amp.autocast():
+            self.assertEqual(fn_s(x), fn(x))
+
+        with torch.cpu.amp.autocast(enabled=True):
+            self.assertEqual(fn_s(x), fn(x))
+
+        self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))
+
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_script_autocast_cuda(self):
+        def fn(x):
+            if torch.is_autocast_enabled():
+                return x.relu()
+            else:
+                return x.sin()
+
+        fn_s = torch.jit.script(fn)
+
+        x = torch.rand((4, 4)) - 0.5
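+        # enabling cpu autocast should not trip the cuda is_autocast_enabled() check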
+        with torch.cpu.amp.autocast():
+            self.assertEqual(fn_s(x), fn(x))
+
+        with torch.cuda.amp.autocast(enabled=True):
+            self.assertEqual(fn_s(x), fn(x))
+
+        self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))
+
+
+    def test_scripted_aliasing(self):
+        # torch.is_autocast_enabled() should not be movable into the autocast context.
+        def fn(x):
+            if torch.is_autocast_enabled():
+                y = True
+            else:
+                y = False
+            with torch.cuda.amp.autocast(enabled=True):
+                z = x.relu()
+            return y, z
+
+        fn_s = torch.jit.script(fn)
+        graph = fn_s.graph
+
+        aliasdb = graph.alias_db()
+
+        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
+        enter_nodes = graph.findAllNodes("prim::Enter")
+
+        self.assertEqual(len(is_enabled_nodes), 1)
+        self.assertEqual(len(enter_nodes), 1)
+
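+        # is_autocast_enabled is registered with conservative alias analysis,
+        # so the alias db must refuse to move it past the autocast Enter node.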
+        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))
+
+
+    def test_script_autocast_enable_and_check(self):
+        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
+            b1 = torch.is_autocast_cpu_enabled()
+            v1 = torch.mm(x, y)
+            with torch.cpu.amp.autocast(enabled=True):
+                b2 = torch.is_autocast_cpu_enabled()
+                v2 = torch.mm(x, y)
+                with torch.cpu.amp.autocast(enabled=False):
+                    b3 = torch.is_autocast_cpu_enabled()
+                    v3 = torch.mm(x, y)
+            return (v1, b1, v2, b2, v3, b3)
+
+        # each bx = is_autocast_cpu_enabled() result should be False iff the
+        # corresponding vx = mm(x, y) result has dtype torch.float
+        def check_fn_results(arr):
+            [v1, b1, v2, b2, v3, b3] = arr
+            self.assertTrue((v1.dtype == torch.float) != b1)
+            self.assertTrue((v2.dtype == torch.float) != b2)
+            self.assertTrue((v3.dtype == torch.float) != b3)
+
+        x = torch.rand((2, 2), dtype=torch.float)
+        y = torch.rand((2, 2), dtype=torch.float)
+
+        fn_s = torch.jit.script(fn)
+
+        with torch.cpu.amp.autocast(enabled=False):
+            check_fn_results(fn(x, y))
+            check_fn_results(fn_s(x, y))
+
+        with torch.cpu.amp.autocast(enabled=True):
+            check_fn_results(fn(x, y))
+            check_fn_results(fn_s(x, y))
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp
index 07c1651..4ac023b 100644
--- a/torch/csrc/jit/passes/autocast.cpp
+++ b/torch/csrc/jit/passes/autocast.cpp
@@ -216,6 +216,39 @@
   }
 }
 
+// Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
+// determine whether autocasting is enabled. With JIT-scripted functions, we
+// actually need to return true if eager autocast OR jit autocast is enabled.
+//
+// In the case where JIT autocast is enabled, we replace
+//    %x : bool = aten::is_autocast_enabled()
+// with a constant "True".
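+// Concretely, updateAutocastEnabledCheck below swaps the node's output for a
+// freshly inserted constant-true value and destroys the original node.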
+//
+// More context on eager vs JIT autocasting:
+//
+// Autocasting actually has two settings: eager autocasting, and JIT
+// autocasting. Eager autocasting is the thread-local setting that turns on
+// the relevant bit in the dispatcher settings. JIT autocasting is the pass
+// implemented in this file, which makes changes to the graph to insert casting
+// ops in order to achieve the same behavior as eager autocasting.
+//
+// If eager autocasting is enabled at the time when a JIT-scripted function is
+// invoked, then autocasting will occur regardless of what the JIT-autocasting
+// settings are.
+void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
+  if (!is_jit_enabled) {
+    return;
+  }
+
+  auto graph = node->owningGraph();
+
+  WithInsertPoint insert_point(node);
+
+  Value* true_constant = graph->insertConstant(IValue(true));
+  node->output()->replaceAllUsesWith(true_constant);
+  node->destroy();
+}
+
 // [Note: implicit type promotion in Autocast]
 //
 // Casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp, with
@@ -319,6 +352,14 @@
         }
         break;
 
+      case aten::is_autocast_enabled:
+        updateAutocastEnabledCheck(node, current_state().gpu_enabled);
+        break;
+
+      case aten::is_autocast_cpu_enabled:
+        updateAutocastEnabledCheck(node, current_state().cpu_enabled);
+        break;
+
       // CastPolicy::fp16 (cast all inputs to float16)
       case aten::_convolution:
       case aten::conv1d:
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp
index 2fbe319..68690a2 100644
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -203,7 +203,17 @@
       .def(
           "has_writers",
           [&](AliasDb& db, Value* v1) { return db.hasWriters(v1); })
-      .def("__str__", &AliasDb::toString);
+      .def("__str__", &AliasDb::toString)
+      .def(
+          "move_after_topologically_valid",
+          [](AliasDb& db, Node* n, Node* movePoint) {
+            return db.moveAfterTopologicallyValid(n, movePoint);
+          })
+      .def(
+          "move_before_topologically_valid",
+          [](AliasDb& db, Node* n, Node* movePoint) {
+            return db.moveBeforeTopologicallyValid(n, movePoint);
+          });
 #define GS(name) def(#name, &Graph ::name)
   py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
       .def(py::init<>())
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 7ca9157..d360a78 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1,3 +1,4 @@
+#include <ATen/autocast_mode.h>
 #include <c10/util/Optional.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
@@ -661,6 +662,28 @@
         },
         aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA("aten::is_autocast_enabled() -> bool"),
+        [](Stack& stack) {
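+          // at::autocast is not available in lite interpreter / mobile
+          // builds, so report autocast as disabled there.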
+#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
+          bool enabled = false;
+#else
+          bool enabled = at::autocast::is_enabled();
+#endif
+          push(stack, enabled);
+        },
+        aliasAnalysisConservative()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA("aten::is_autocast_cpu_enabled() -> bool"),
+        [](Stack& stack) {
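+          // Same mobile guard as for aten::is_autocast_enabled above.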
+#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
+          bool enabled = false;
+#else
+          bool enabled = at::autocast::is_cpu_enabled();
+#endif
+          push(stack, enabled);
+        },
+        aliasAnalysisConservative()),
+    OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
         unInitialized,
         aliasAnalysisSpecialCase()),