return aten::gt to the list of fusable operations, add expected graphs (#11150)

Summary:
Fixes one of the issues reported in #11118: `aten::gt` had dropped out of the graph fuser's list of fusable comparison ops. This change restores the `aten::gt` tensor-tensor and tensor-scalar variants alongside `lt`/`le`/`ge` in `graph_fuser.cpp`, adds an `assertAllFused` helper to `test_jit.py` (used by the existing `>`/`<` test and a new `==`/`!=` test) to check that the traced graph collapses into a single `prim::FusionGroup`, and makes the CUDA fused-kernel loader initialize a CUDA context (via `cudaFree(0)` under the caching allocator's free mutex) when none is current before calling `cuModuleLoadData`.
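
As a minimal sketch of what the restored fusion covers (not part of this patch; the function name, shapes, and the explicit `torch.jit.trace` call are illustrative assumptions mirroring the existing tests):

```python
import torch

def fn_gt(x, y):
    # aten::gt against a scalar, then the usual mask-and-accumulate pattern
    mask = (x > 0).type_as(x)
    return x * mask + y

x = torch.randn(4, 4, dtype=torch.float, device='cuda')
y = torch.randn(4, 4, dtype=torch.float, device='cuda')

traced = torch.jit.trace(fn_gt, (x, y))
traced(x, y)  # run once so the executor can specialize and fuse
# With this change, the specialized graph should contain a single
# prim::FusionGroup (plus prim::Constant nodes), which is exactly
# what the new assertAllFused helper asserts.
print(traced.graph_for(x, y))
```

The new `test_comparison_eq_ne` below exercises the same pattern for `==`/`!=`, and `assertAllFused` is also applied to the existing comparison tests.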
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11150

Differential Revision: D9861372

Pulled By: apaszke

fbshipit-source-id: 98b196b89e991d3936360b30568360367fd32e8b
diff --git a/test/test_jit.py b/test/test_jit.py
index b715045..6612a8f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -365,6 +365,10 @@
 
         return ge
 
+    def assertAllFused(self, graph):
+        self.assertTrue(all(node.kind() in {'prim::Constant', 'prim::FusionGroup'} for node in graph.nodes()))
+        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
+
     def assertExportImport(self, trace, inputs):
         graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
         m = torch.jit.ScriptModule()
@@ -766,6 +770,7 @@
         y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
         ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -782,6 +787,24 @@
         y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
         ge = self.checkTrace(f, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_comparison_eq_ne(self):
+        def f(x, y):
+            mask = (x == 0).type_as(x)
+            z = x * mask + y
+            mask = (x != 0).type_as(x)
+            z = z * mask + y
+            return z
+
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+
+        ge = self.checkTrace(f, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
     @staticmethod
     def fn_test_relu(x, y):
diff --git a/torch/csrc/jit/fusers/cuda/fused_kernel.cpp b/torch/csrc/jit/fusers/cuda/fused_kernel.cpp
index 9062430..4067e6c 100644
--- a/torch/csrc/jit/fusers/cuda/fused_kernel.cpp
+++ b/torch/csrc/jit/fusers/cuda/fused_kernel.cpp
@@ -65,7 +65,13 @@
   TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
   ptx.resize(ptx_size);
   TORCH_NVRTC_CHECK(nvrtcGetPTX(program, ptx.data()));
-
+  CUcontext pctx = 0;
+  TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
+  if (!pctx) {
+     std::unique_lock<std::mutex> cudaFreeMutexLock(
+     *(THCCachingAllocator_getCudaFreeMutex()));
+     cudaFree(0);
+  }
   TORCH_CU_CHECK(cuModuleLoadData(&module, ptx.data()));
   TORCH_CU_CHECK(cuModuleGetFunction(&function, module, name.c_str()));
 
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index e812369..4d69ed5 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -228,6 +228,9 @@
         node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::le(Tensor self, Scalar other) -> Tensor", /*const=*/attr::other) ||
         node->matches("aten::le(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other) ||
+        node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
+        node->matches("aten::gt(Tensor self, Scalar other) -> Tensor", /*const=*/attr::other) ||
+        node->matches("aten::gt(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other) ||
         node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
         node->matches("aten::ge(Tensor self, Scalar other) -> Tensor", /*const=*/attr::other) ||
         node->matches("aten::ge(Scalar other, Tensor self) -> Tensor", /*const=*/attr::other) ||