[torchscript] Handle prim::device and prim::dtype (#127466)
- Support prim::device and prim::dtype when converting TorchScript IR to an ExportedProgram
- Add unit tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127466
Approved by: https://github.com/SherlockNoMad
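
For context, a minimal usage sketch of the behavior this enables. The TS2EPConverter entry point lives in torch/_export/converter.py; its exact call signature as written here is an assumption, not part of this PR:

    # Sketch only: exercises the new prim::device / prim::dtype handling.
    import torch
    from torch._export.converter import TS2EPConverter  # assumed entry point

    class M(torch.nn.Module):
        def forward(self, x):
            # x.device / x.dtype lower to prim::device / prim::dtype in TorchScript IR.
            return torch.ones(2, 3, device=x.device, dtype=x.dtype)

    ts_model = torch.jit.script(M())
    inp = (torch.rand(3, 4),)
    ep = TS2EPConverter(ts_model, inp).convert()  # assumed API; yields an ExportedProgram
    torch.testing.assert_close(ep.module()(*inp), ts_model(*inp))
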
diff --git a/test/export/test_converter.py b/test/export/test_converter.py
index b6d0e54..64cea8c 100644
--- a/test/export/test_converter.py
+++ b/test/export/test_converter.py
@@ -1,5 +1,7 @@
 # Owner(s): ["oncall: export"]
+import unittest
+
 import torch
 import torch.utils._pytree as pytree
@@ -9,6 +11,8 @@
 from torch.testing._internal.common_utils import run_tests
+requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
+
 class TestConverter(TestCase):
     def _check_equal_ts_ep_converter(self, mod, inp):
@@ -64,6 +68,46 @@
         self._check_equal_ts_ep_converter(MOutputTuple(), inp)
         self._check_equal_ts_ep_converter(MOutputDict(), inp)
+    def test_prim_device(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                device = x.device
+                return torch.ones(2, 3, device=device)
+
+        inp = (torch.rand(3, 4),)
+        self._check_equal_ts_ep_converter(Module(), inp)
+
+    @requires_cuda
+    def test_prim_device_cuda(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                device = x.device
+                return torch.ones(2, 3, device=device)
+
+        inp = (torch.rand((3, 4), device="cuda:0"),)
+        self._check_equal_ts_ep_converter(Module(), inp)
+
+    def test_prim_dtype(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                dtype = x.dtype
+                return torch.ones(2, 3, dtype=dtype)
+
+        for dtype in [
+            torch.float32,
+            torch.double,
+        ]:
+            inp = (torch.rand((3, 4), dtype=dtype),)
+            self._check_equal_ts_ep_converter(Module(), inp)
+
+        for dtype in [
+            torch.uint8,
+            torch.int8,
+            torch.int32,
+        ]:
+            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
+            self._check_equal_ts_ep_converter(Module(), inp)
+
 if __name__ == "__main__":
     run_tests()
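
For reference, the converter changes below fold both ops into constants read off the input's JIT type. Scripting a module that touches x.device or x.dtype is what produces these prim nodes; a quick sketch to see them (exact IR formatting may vary):

    # Sketch: inspect the TorchScript IR for the ops handled by this PR.
    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.ones(2, 3, device=x.device, dtype=x.dtype)

    print(torch.jit.script(M()).graph)
    # Expect nodes along the lines of:
    #   %device : Device = prim::device(%x)
    #   %dtype : int = prim::dtype(%x)
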
diff --git a/torch/_export/converter.py b/torch/_export/converter.py
index 459f534..7e68129 100644
--- a/torch/_export/converter.py
+++ b/torch/_export/converter.py
@@ -5,6 +5,7 @@
 from torch.export.exported_program import ExportedProgram
 from torch.export.graph_signature import (
+    ConstantArgument,
     InputKind,
     InputSpec,
     OutputKind,
@@ -201,6 +202,20 @@
         self.constant_map[name] = value
+    def convert_prim_device(self, node: torch._C.Node):
+        input_type = node.input().type()
+        if input_type.isSubtypeOf(torch._C.TensorType.get()):
+            device = input_type.device()  # type: ignore[attr-defined]
+            output_name = node.output().debugName()
+            self.constant_map[output_name] = device
+        else:
+            raise ValueError(f"Unsupported JitType ({input_type}) when getting device")
+
+    def convert_prim_dtype(self, node: torch._C.Node):
+        dtype = node.input().type().dtype()
+        output_name = node.output().debugName()
+        self.constant_map[output_name] = dtype
+
     def convert_prim_GetAttr(self, node: torch._C.Node):
         def get_attr(name: str):
             if name in self.attribute_map:
@@ -350,6 +365,10 @@
         elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}:
             # Tuple is just a non-mutable List, so we can handle them together.
             self.convert_prim_ListConstruct(node)
+        elif node_kind == "prim::device":
+            self.convert_prim_device(node)
+        elif node_kind == "prim::dtype":
+            self.convert_prim_dtype(node)
         elif node_kind == "prim::DictConstruct":
             self.convert_prim_DictConstruct(node)
         # elif node_kind == "aten::Int":
@@ -369,17 +388,27 @@
             output_name = graph_output.debugName()
             if output_name in self.name_to_node:
                 args.append(self.name_to_node[output_name])
+                self.output_specs.append(
+                    OutputSpec(
+                        OutputKind.USER_OUTPUT,
+                        arg=TensorArgument(name=output_name),
+                        target=output_name,
+                    )
+                )
+            elif output_name in self.constant_map:
+                args.append(self.constant_map[output_name])
+                self.output_specs.append(
+                    OutputSpec(
+                        OutputKind.USER_OUTPUT,
+                        arg=ConstantArgument(
+                            name=output_name, value=self.constant_map[output_name]
+                        ),
+                        target=output_name,
+                    )
+                )
             else:
                 raise ValueError(f"Output {output_name} not found")
-            self.output_specs.append(
-                OutputSpec(
-                    OutputKind.USER_OUTPUT,
-                    arg=TensorArgument(name=output_name),
-                    target=output_name,
-                )
-            )
-
         self.fx_graph.output(
             args[0]
         )  # Get rid of an extra list wrapped around final output.
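
Note on the output-spec rework in the last hunk: once prim::device / prim::dtype results can flow to a graph output, the output may live in constant_map rather than name_to_node, so it is recorded as a ConstantArgument instead of a TensorArgument. A hypothetical illustration (not one of this PR's tests):

    # Sketch: a module whose scripted output is a folded constant, not a tensor.
    import torch

    class ReturnsDtype(torch.nn.Module):
        def forward(self, x):
            return x.dtype  # prim::dtype output, folded into constant_map

    # For a float32 input, the converter's new branch would emit an OutputSpec
    # carrying ConstantArgument(name=..., value=torch.float32).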