Remove flag to toggle CPU fusion in the presence of parallelism (#63514)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63514

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30417127

Pulled By: bertmaher

fbshipit-source-id: b77d7c68364f2af73570740540f3b1152313016e
diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp
index 8dd6164..91fb4c2 100644
--- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp
+++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp
@@ -15,19 +15,15 @@
 using namespace torch::jit::tensorexpr;
 
 struct WithCPUFuser {
-  WithCPUFuser(bool val = true)
-      : cpuFuserEnabled(canFuseOnCPU()), parallel(texprParallelCPUEnabled()) {
+  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
     overrideCanFuseOnCPU(val);
-    setTexprParallelCPUEnabled(true);
   }
 
   ~WithCPUFuser() {
     overrideCanFuseOnCPU(cpuFuserEnabled);
-    setTexprParallelCPUEnabled(parallel);
   }
 
   bool cpuFuserEnabled;
-  bool parallel;
 };
 
 TEST(TEFuserPass, FuserPass_1) {
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index aa8be05..b9ed9d0 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -29,8 +29,6 @@
         torch._C._debug_set_fusion_group_inlining(False)
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
 
     def tearDown(self):
         torch._C._jit_set_profiling_executor(self.prev_exec)
@@ -42,7 +40,6 @@
         torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
         torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def test_tensor_type_not_determined_by_inputs(self):
         @torch.jit.script
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 64c26b7..614226f 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -85,10 +85,6 @@
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
 
-        # TODO: CPU fuser currently is disabled when multithreading.
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
-
         self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
         self.int_dtypes = [
             torch.int8,
@@ -116,7 +112,6 @@
 
         torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def assertLastGraphAllFused(self):
         self.assertAllFused(torch.jit.last_executed_optimized_graph())
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index 6353113..47c7e68 100644
--- a/test/test_tensorexpr.py
+++ b/test/test_tensorexpr.py
@@ -24,9 +24,6 @@
         torch._C._debug_set_fusion_group_inlining(False)
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
-        # TODO: CPU fuser currently is disabled when multithreading.
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
 
         self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
 
@@ -39,7 +36,6 @@
         torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
         torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def assertLastGraphAllFused(self):
         self.assertAllFused(torch.jit.last_executed_optimized_graph())
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index d4add03..52bf453 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
 
-#include <ATen/Parallel.h>
 #include <ATen/core/interned_strings.h>
 #include <ATen/record_function.h>
 #include <c10/util/FunctionRef.h>
@@ -250,15 +249,6 @@
 } // namespace tensorexpr
 
 static bool texpr_fuser_enabled_ = true;
-static bool texpr_parallel_cpu_enabled = false;
-
-bool texprParallelCPUEnabled() {
-  return texpr_parallel_cpu_enabled;
-}
-
-void setTexprParallelCPUEnabled(bool val) {
-  texpr_parallel_cpu_enabled = val;
-}
 
 void setTensorExprFuserEnabled(bool val) {
   texpr_fuser_enabled_ = val;
@@ -898,14 +888,7 @@
       return false;
     }
     if (device->is_cpu()) {
-      // CPU fusion is only supported for single-thread.
-      if (!canFuseOnCPU()) {
-        return false;
-      }
-      if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) {
-        return true;
-      }
-      return false;
+      return canFuseOnCPU();
     } else if (device->is_cuda()) {
       return canFuseOnGPU();
     } else if (device->is_xpu()) {
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h
index 3f6538b..254aebd 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.h
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.h
@@ -24,8 +24,6 @@
 TORCH_API bool tensorExprFuserEnabled();
 TORCH_API bool setTexprReductionsEnabled(bool value);
 TORCH_API bool texprReductionsEnabled();
-TORCH_API bool texprParallelCPUEnabled();
-TORCH_API void setTexprParallelCPUEnabled(bool val);
 
 TORCH_API void RemoveProfileNodesAndSpecializeTypes(
     std::shared_ptr<Graph>& graph);
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 5fca575..992e60e 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -711,8 +711,6 @@
       .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
       .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
       .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
-      .def("_jit_set_texpr_parallel_cpu_enabled", &setTexprParallelCPUEnabled)
-      .def("_jit_texpr_parallel_cpu_enabled", &texprParallelCPUEnabled)
       .def(
           "_jit_set_te_generate_block_code",
           [](bool gen_block_code) {