Remove flag to toggle CPU fusion in the presence of parallelism (#63514)
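Summary: Removes the `texprParallelCPUEnabled` / `setTexprParallelCPUEnabled` flag and its Python bindings (`_jit_texpr_parallel_cpu_enabled`, `_jit_set_texpr_parallel_cpu_enabled`). The CPU device check in tensorexpr_fuser.cpp no longer requires `at::get_num_threads() == 1` or the flag; it now returns `canFuseOnCPU()` directly. Tests that saved, set, and restored the flag are updated accordingly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/63514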
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D30417127
Pulled By: bertmaher
fbshipit-source-id: b77d7c68364f2af73570740540f3b1152313016e
diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp
index 8dd6164..91fb4c2 100644
--- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp
+++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp
@@ -15,19 +15,15 @@
using namespace torch::jit::tensorexpr;
struct WithCPUFuser {
- WithCPUFuser(bool val = true)
- : cpuFuserEnabled(canFuseOnCPU()), parallel(texprParallelCPUEnabled()) {
+ WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
overrideCanFuseOnCPU(val);
- setTexprParallelCPUEnabled(true);
}
~WithCPUFuser() {
overrideCanFuseOnCPU(cpuFuserEnabled);
- setTexprParallelCPUEnabled(parallel);
}
bool cpuFuserEnabled;
- bool parallel;
};
TEST(TEFuserPass, FuserPass_1) {
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index aa8be05..b9ed9d0 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -29,8 +29,6 @@
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
def tearDown(self):
torch._C._jit_set_profiling_executor(self.prev_exec)
@@ -42,7 +40,6 @@
torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def test_tensor_type_not_determined_by_inputs(self):
@torch.jit.script
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 64c26b7..614226f 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -85,10 +85,6 @@
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- # TODO: CPU fuser currently is disabled when multithreading.
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
-
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
self.int_dtypes = [
torch.int8,
@@ -116,7 +112,6 @@
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index 6353113..47c7e68 100644
--- a/test/test_tensorexpr.py
+++ b/test/test_tensorexpr.py
@@ -24,9 +24,6 @@
torch._C._debug_set_fusion_group_inlining(False)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
- # TODO: CPU fuser currently is disabled when multithreading.
- self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
- torch._C._jit_set_texpr_parallel_cpu_enabled(True)
self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
@@ -39,7 +36,6 @@
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
- torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index d4add03..52bf453 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -1,6 +1,5 @@
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
-#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/record_function.h>
#include <c10/util/FunctionRef.h>
@@ -250,15 +249,6 @@
} // namespace tensorexpr
static bool texpr_fuser_enabled_ = true;
-static bool texpr_parallel_cpu_enabled = false;
-
-bool texprParallelCPUEnabled() {
- return texpr_parallel_cpu_enabled;
-}
-
-void setTexprParallelCPUEnabled(bool val) {
- texpr_parallel_cpu_enabled = val;
-}
void setTensorExprFuserEnabled(bool val) {
texpr_fuser_enabled_ = val;
@@ -898,14 +888,7 @@
return false;
}
if (device->is_cpu()) {
- // CPU fusion is only supported for single-thread.
- if (!canFuseOnCPU()) {
- return false;
- }
- if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) {
- return true;
- }
- return false;
+ return canFuseOnCPU();
} else if (device->is_cuda()) {
return canFuseOnGPU();
} else if (device->is_xpu()) {
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h
index 3f6538b..254aebd 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.h
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.h
@@ -24,8 +24,6 @@
TORCH_API bool tensorExprFuserEnabled();
TORCH_API bool setTexprReductionsEnabled(bool value);
TORCH_API bool texprReductionsEnabled();
-TORCH_API bool texprParallelCPUEnabled();
-TORCH_API void setTexprParallelCPUEnabled(bool val);
TORCH_API void RemoveProfileNodesAndSpecializeTypes(
std::shared_ptr<Graph>& graph);
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 5fca575..992e60e 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -711,8 +711,6 @@
.def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
.def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
.def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
- .def("_jit_set_texpr_parallel_cpu_enabled", &setTexprParallelCPUEnabled)
- .def("_jit_texpr_parallel_cpu_enabled", &texprParallelCPUEnabled)
.def(
"_jit_set_te_generate_block_code",
[](bool gen_block_code) {