Add support to compare devices (#53045)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53045

Test Plan:
=====
python test/test_jit.py -k test_device_not_equal

Reviewed By: pbelevich

Differential Revision: D26737964

Pulled By: nikithamalgifb

fbshipit-source-id: 2205aa1f214a86282602168c364dca1363d2f7dd
diff --git a/test/test_jit.py b/test/test_jit.py
index 2c9653c..677b4bf 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1688,6 +1688,18 @@
         self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
         self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
 
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    def test_device_not_equal(self):
+
+        def compare_device(x: torch.device):
+            return x != torch.device("cuda:0")
+
+        def compare_two_device(x: torch.device, y: torch.device):
+            return x != y
+
+        self.checkScript(compare_device, (torch.device("cuda:0"),))
+        self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), ))
+
     def test_tuple_specialization(self):
         @torch.jit.script
         def f(t, s):
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 39981e1..a548c22 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -637,6 +637,14 @@
          },
          aliasAnalysisFromSchema()),
      OperatorGenerator(
+         TORCH_SELECTIVE_SCHEMA("aten::ne.device(Device a, Device b) -> bool"),
+         [](Stack* stack) {
+           auto a = pop(stack).toDevice();
+           auto b = pop(stack).toDevice();
+           push(stack, a != b);
+         },
+         aliasAnalysisFromSchema()),
+     OperatorGenerator(
          TORCH_SELECTIVE_SCHEMA("aten::eq.bool(bool a, bool b) -> bool"),
          [](Stack* stack) {
            auto a = pop(stack);