Add bindings for .cpu() & .cuda() to script (#15904)
Summary:
Adding bindings for .cpu() and .cuda() to script.
It's worth noting that if the device remains unchanged, than the returned tensor aliases the input, but if it does change than they do not alias each other.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15904
Differential Revision: D13632879
Pulled By: eellison
fbshipit-source-id: 024a04f267909674aa1e510562efd9cb081f407c
diff --git a/test/test_jit.py b/test/test_jit.py
index 3cce42f..c6cbe30 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2914,6 +2914,25 @@
self.checkScript(to_device, (torch.ones(3, 4),))
+ def test_tensor_to_cpu(self):
+ def to_cpu(x):
+ return x.cpu()
+
+ x = torch.ones(3, 4)
+ script_fn = torch.jit.script(to_cpu)
+ self.assertEqual(to_cpu(x).device, script_fn(x).device)
+ self.checkScript(to_cpu, (x,))
+
+ @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
+ def test_tensor_to_cuda(self):
+ def to_cuda(x):
+ return x.cuda()
+
+ x = torch.ones(3, 4)
+ script_fn = torch.jit.script(to_cuda)
+ self.assertEqual(to_cuda(x).device, script_fn(x).device)
+ self.checkScript(to_cuda, (x,))
+
def test_generic_list_errors(self):
with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
@torch.jit.script
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index ee2ac40..e468fa94 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -315,6 +315,26 @@
};
}),
Operator(
+ "aten::cpu(Tensor(a) self) -> Tensor(a)",
+ [](const Node* node) -> Operation {
+ return [](Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, a.cpu());
+ return 0;
+ };
+ }),
+ Operator(
+ "aten::cuda(Tensor(a) self) -> Tensor(a)",
+ [](const Node* node) -> Operation {
+ return [](Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, a.cuda());
+ return 0;
+ };
+ }),
+ Operator(
"prim::Undefined() -> Tensor",
[](const Node* node) {
return [](Stack& stack) {