[torchscript] Handle prim::device and prim::dtype (#127466)

- Support prim::device and prim::dtype during TorchScript-to-export conversion
- Add unit tests covering CPU and CUDA devices and several dtypes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127466
Approved by: https://github.com/SherlockNoMad
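
For context, attribute accesses like `x.device` and `x.dtype` in a scripted module lower to prim::device and prim::dtype nodes in TorchScript IR. A minimal sketch to observe this (the Module below is illustrative, not part of this PR):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # x.device / x.dtype lower to prim::device / prim::dtype nodes
        return torch.ones(2, 3, device=x.device, dtype=x.dtype)

# The printed graph should contain prim::device(%x) and prim::dtype(%x) nodes.
print(torch.jit.script(M()).graph)
```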
diff --git a/test/export/test_converter.py b/test/export/test_converter.py
index b6d0e54..64cea8c 100644
--- a/test/export/test_converter.py
+++ b/test/export/test_converter.py
@@ -1,5 +1,7 @@
 # Owner(s): ["oncall: export"]
 
+import unittest
+
 import torch
 
 import torch.utils._pytree as pytree
@@ -9,6 +11,8 @@
 
 from torch.testing._internal.common_utils import run_tests
 
+requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
+
 
 class TestConverter(TestCase):
     def _check_equal_ts_ep_converter(self, mod, inp):
@@ -64,6 +68,46 @@
         self._check_equal_ts_ep_converter(MOutputTuple(), inp)
         self._check_equal_ts_ep_converter(MOutputDict(), inp)
 
+    def test_prim_device(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                device = x.device
+                return torch.ones(2, 3, device=device)
+
+        inp = (torch.rand(3, 4),)
+        self._check_equal_ts_ep_converter(Module(), inp)
+
+    @requires_cuda
+    def test_prim_device_cuda(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                device = x.device
+                return torch.ones(2, 3, device=device)
+
+        inp = (torch.rand((3, 4), device="cuda:0"),)
+        self._check_equal_ts_ep_converter(Module(), inp)
+
+    def test_prim_dtype(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                dtype = x.dtype
+                return torch.ones(2, 3, dtype=dtype)
+
+        for dtype in [
+            torch.float32,
+            torch.double,
+        ]:
+            inp = (torch.rand((3, 4), dtype=dtype),)
+            self._check_equal_ts_ep_converter(Module(), inp)
+
+        for dtype in [
+            torch.uint8,
+            torch.int8,
+            torch.int32,
+        ]:
+            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
+            self._check_equal_ts_ep_converter(Module(), inp)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_export/converter.py b/torch/_export/converter.py
index 459f534..7e68129 100644
--- a/torch/_export/converter.py
+++ b/torch/_export/converter.py
@@ -5,6 +5,7 @@
 
 from torch.export.exported_program import ExportedProgram
 from torch.export.graph_signature import (
+    ConstantArgument,
     InputKind,
     InputSpec,
     OutputKind,
@@ -201,6 +202,24 @@
 
         self.constant_map[name] = value
 
+    def convert_prim_device(self, node: torch._C.Node):
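+        # The input's static TensorType records the device, so fold it into
+        # constant_map instead of emitting a runtime node.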
+        input_type = node.input().type()
+        if input_type.isSubtypeOf(torch._C.TensorType.get()):
+            device = input_type.device()  # type: ignore[attr-defined]
+            output_name = node.output().debugName()
+            self.constant_map[output_name] = device
+        else:
+            raise ValueError(f"Unsupported JitType ({input_type}) when getting device")
+
+    def convert_prim_dtype(self, node: torch._C.Node):
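+        # Likewise, fold the dtype recorded on the input's TensorType into
+        # constant_map so downstream nodes consume it as a constant.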
+        dtype = node.input().type().dtype()
+        output_name = node.output().debugName()
+        self.constant_map[output_name] = dtype
+
     def convert_prim_GetAttr(self, node: torch._C.Node):
         def get_attr(name: str):
             if name in self.attribute_map:
@@ -350,6 +365,10 @@
         elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}:
             # Tuple is just a non-mutable List, so we can handle them together.
             self.convert_prim_ListConstruct(node)
+        elif node_kind == "prim::device":
+            self.convert_prim_device(node)
+        elif node_kind == "prim::dtype":
+            self.convert_prim_dtype(node)
         elif node_kind == "prim::DictConstruct":
             self.convert_prim_DictConstruct(node)
         # elif node_kind == "aten::Int":
@@ -369,17 +388,29 @@
             output_name = graph_output.debugName()
             if output_name in self.name_to_node:
                 args.append(self.name_to_node[output_name])
+                self.output_specs.append(
+                    OutputSpec(
+                        OutputKind.USER_OUTPUT,
+                        arg=TensorArgument(name=output_name),
+                        target=output_name,
+                    )
+                )
+            elif output_name in self.constant_map:
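+                # Outputs coming from folded prim nodes (e.g. prim::device)
+                # live in constant_map and get a ConstantArgument spec.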
+                args.append(self.constant_map[output_name])
+                self.output_specs.append(
+                    OutputSpec(
+                        OutputKind.USER_OUTPUT,
+                        arg=ConstantArgument(
+                            name=output_name, value=self.constant_map[output_name]
+                        ),
+                        target=output_name,
+                    )
+                )
             else:
                 raise ValueError(f"Output {output_name} not found")
 
-            self.output_specs.append(
-                OutputSpec(
-                    OutputKind.USER_OUTPUT,
-                    arg=TensorArgument(name=output_name),
-                    target=output_name,
-                )
-            )
-
         self.fx_graph.output(
             args[0]
         )  # Get rid of an extra list wrapped around final output.
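
For reference, an end-to-end sketch of the converter path these tests exercise. It assumes the TS2EPConverter entry point in torch/_export/converter.py and its convert() method; the exact API may differ across versions:

```python
import torch
from torch._export.converter import TS2EPConverter  # assumed entry point

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ones(2, 3, device=x.device, dtype=x.dtype)

inp = (torch.rand(3, 4),)
ts_model = torch.jit.script(M())
ep = TS2EPConverter(ts_model, inp).convert()  # returns an ExportedProgram
torch.testing.assert_close(ep.module()(*inp), ts_model(*inp))
```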