Add support to compare devices (#53045)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53045
Test Plan:
=====
python test/test_jit.py -k test_device_not_equal
Reviewed By: pbelevich
Differential Revision: D26737964
Pulled By: nikithamalgifb
fbshipit-source-id: 2205aa1f214a86282602168c364dca1363d2f7dd
diff --git a/test/test_jit.py b/test/test_jit.py
index 2c9653c..677b4bf 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1688,6 +1688,18 @@
self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ def test_device_not_equal(self):
+
+ def compare_device(x: torch.device):
+ return x != torch.device("cuda:0")
+
+ def compare_two_device(x: torch.device, y: torch.device):
+ return x != y
+
+ self.checkScript(compare_device, (torch.device("cuda:0"),))
+ self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), ))
+
def test_tuple_specialization(self):
@torch.jit.script
def f(t, s):
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 39981e1..a548c22 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -637,6 +637,14 @@
},
aliasAnalysisFromSchema()),
OperatorGenerator(
+ TORCH_SELECTIVE_SCHEMA("aten::ne.device(Device a, Device b) -> bool"),
+ [](Stack* stack) {
+ auto a = pop(stack).toDevice();
+ auto b = pop(stack).toDevice();
+ push(stack, a != b);
+ },
+ aliasAnalysisFromSchema()),
+ OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("aten::eq.bool(bool a, bool b) -> bool"),
[](Stack* stack) {
auto a = pop(stack);