[nnc] Enable CPU fusion (#63545)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63545

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30417370

Pulled By: bertmaher

fbshipit-source-id: 84ce7a578a3678d5562bab99d1dc00330c4f72d1
diff --git a/torch/csrc/jit/codegen/fuser/interface.cpp b/torch/csrc/jit/codegen/fuser/interface.cpp
index ec67c4b..ef7e9e0 100644
--- a/torch/csrc/jit/codegen/fuser/interface.cpp
+++ b/torch/csrc/jit/codegen/fuser/interface.cpp
@@ -8,15 +8,12 @@
 #include <c10/util/Flags.h>
 #include <stdexcept>
 
-C10_DEFINE_bool(torch_jit_enable_cpu_fusion, false, "enable cpu fusion");
-
 namespace torch {
 namespace jit {
 
 namespace detail {
 
-// Note: CPU fusion is currently disabled due to test flakiness
-#if defined(FBCODE_CAFFE2)
+#ifdef TORCH_ENABLE_LLVM
 bool cpu_fuser_enabled = true;
 #else
 bool cpu_fuser_enabled = false;
@@ -37,8 +34,7 @@
 }
 
 bool canFuseOnCPU() {
-  return fuser::hasFusionBackend(DeviceType::CPU) &&
-      (detail::cpu_fuser_enabled || FLAGS_torch_jit_enable_cpu_fusion);
+  return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled;
 }
 
 bool canFuseOnGPU() {
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index f7dd466..653f9fe 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -183,7 +183,7 @@
       return !strict_fuser_check;
     }
     if ((*device).is_cpu()) {
-      return canFuseOnCPU();
+      return canFuseOnCPULegacy();
     } else if ((*device).is_cuda()) {
       return canFuseOnGPU();
     } else if ((*device).is_xpu()) {
@@ -1244,6 +1244,16 @@
 
 } // anonymous namespace
 
+static bool cpu_fuser_enabled_legacy = false;
+
+bool canFuseOnCPULegacy() {
+  return cpu_fuser_enabled_legacy;
+}
+
+void overrideCanFuseOnCPULegacy(bool value) {
+  cpu_fuser_enabled_legacy = value;
+}
+
 void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) {
   AliasDb db(graph);
   GraphFuser(&db, graph->block(), strict_fuser_check).run();
diff --git a/torch/csrc/jit/passes/graph_fuser.h b/torch/csrc/jit/passes/graph_fuser.h
index 0cdcc2e..d710e5a 100644
--- a/torch/csrc/jit/passes/graph_fuser.h
+++ b/torch/csrc/jit/passes/graph_fuser.h
@@ -5,6 +5,9 @@
 namespace torch {
 namespace jit {
 
+TORCH_API bool canFuseOnCPULegacy();
+TORCH_API void overideCanFuseOnCPULegacy(bool value);
+
 // NB: Be sure to run DCE before fusion, because dead instructions
 // can prevent fusion opportunities from being exploited.
 // On Windows will noop, NYI
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 992e60e..f5da7b3 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -589,6 +589,8 @@
       .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
       .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
       .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
+      .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
+      .def("_jit_override_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
       .def(
           "_jit_differentiate",
           [](Graph& g) {