return aten::gt to the list of fusable operations, add expected graphs (#11150)
Summary:
Fixes one of the issues reported in #11118: aten::gt was missing from the graph fuser's list of simple-mappable operations, so greater-than comparisons broke fusion groups apart. This change adds it back, introduces an assertAllFused test helper plus fusion assertions for the comparison tests, and makes the CUDA fuser initialize a CUDA context before cuModuleLoadData when no context is current.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11150
Differential Revision: D9861372
Pulled By: apaszke
fbshipit-source-id: 98b196b89e991d3936360b30568360367fd32e8b
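For context, a minimal sketch of the kind of check the new tests perform, assuming a CUDA build with the JIT fuser enabled. The function name masked_add is illustrative, and whether a fusion group actually forms depends on the fuser/executor settings of the build:

```python
import torch

# Illustrative sketch (not part of the patch): after this change a
# greater-than comparison can stay inside a single prim::FusionGroup
# when the surrounding elementwise math is also fusable.
def masked_add(x, y):
    mask = (x > y).type_as(x)  # lowers to aten::gt, now simple-mappable again
    return x * mask + y

if torch.cuda.is_available():
    x = torch.randn(4, 4, device='cuda')
    y = torch.randn(4, 4, device='cuda')
    traced = torch.jit.trace(masked_add, (x, y))
    traced(x, y)  # run once so the fuser can specialize the graph
    kinds = [n.kind() for n in traced.graph_for(x, y).nodes()]
    # Mirrors the new assertAllFused helper: aside from constants, everything
    # should have collapsed into exactly one fusion group.
    assert kinds.count('prim::FusionGroup') == 1, kinds
```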
diff --git a/test/test_jit.py b/test/test_jit.py
index b715045..6612a8f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -365,6 +365,10 @@
return ge
+ def assertAllFused(self, graph):
+ self.assertTrue(all(node.kind() in {'prim::Constant', 'prim::FusionGroup'} for node in graph.nodes()))
+ self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
+
def assertExportImport(self, trace, inputs):
graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
m = torch.jit.ScriptModule()
@@ -766,6 +770,7 @@
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -782,6 +787,24 @@
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
ge = self.checkTrace(f, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
+
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_comparison_eq_ne(self):
+ def f(x, y):
+ mask = (x == 0).type_as(x)
+ z = x * mask + y
+ mask = (x != 0).type_as(x)
+ z = z * mask + y
+ return z
+
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+
+ ge = self.checkTrace(f, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
@staticmethod
def fn_test_relu(x, y):
diff --git a/torch/csrc/jit/fusers/cuda/fused_kernel.cpp b/torch/csrc/jit/fusers/cuda/fused_kernel.cpp
index 9062430..4067e6c 100644
--- a/torch/csrc/jit/fusers/cuda/fused_kernel.cpp
+++ b/torch/csrc/jit/fusers/cuda/fused_kernel.cpp
@@ -65,7 +65,13 @@
TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
ptx.resize(ptx_size);
TORCH_NVRTC_CHECK(nvrtcGetPTX(program, ptx.data()));
-
+ CUcontext pctx = 0;
+ TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
+ if (!pctx) {
+ std::unique_lock<std::mutex> cudaFreeMutexLock(
+ *(THCCachingAllocator_getCudaFreeMutex()));
+ cudaFree(0);
+ }
TORCH_CU_CHECK(cuModuleLoadData(&module, ptx.data()));
TORCH_CU_CHECK(cuModuleGetFunction(&function, module, name.c_str()));
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index e812369..4d69ed5 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -228,6 +228,9 @@
node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
node->matches("aten::le(Tensor self, Scalar other) -> Tensor", /*const=*/attr::other) ||
node->matches("aten::le(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other) ||
+ node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
+ node->matches("aten::gt(Tensor self, Scalar other) -> Tensor", /*const=*/attr::other) ||
+ node->matches("aten::gt(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other) ||
node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
node->matches("aten::ge(Tensor self, Scalar other) -> Tensor", /*const=*/attr::other) ||
node->matches("aten::ge(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other) ||