[export] Device remapping in export (#133660)
Implements the `move_to_device_pass()` function in `torch._export.passes`.
Users must call this pass explicitly to move an exported program from one `torch.device` to another.
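A minimal usage sketch (`M` is a toy module for illustration; the CUDA lines assume a CUDA-enabled build):

```python
import torch
from torch._export.passes.move_to_device_pass import move_to_device_pass

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

ep = torch.export.export(M(), (torch.rand(4),))  # exported on CPU

# `location` accepts a torch.device, a device string, or a
# {source_device: target_device} mapping:
ep = move_to_device_pass(ep, location=torch.device("cuda:0"))
ep = move_to_device_pass(ep, location="cpu")
ep = move_to_device_pass(ep, location={"cpu": "cuda:0"})
```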
Fixes https://github.com/pytorch/pytorch/issues/121761
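For the cross-device scenario in that issue, the intended workflow looks roughly like this (the file name and input shape are illustrative, and the example assumes the program was previously saved with `torch.export.save`):

```python
import torch
from torch._export.passes.move_to_device_pass import move_to_device_pass

# Load a program that was exported and saved on a CUDA machine, then
# remap its parameters, constants, and device kwargs to CPU before running.
ep = torch.export.load("model_cuda.pt2")  # illustrative file name
ep = move_to_device_pass(ep, location={"cuda:0": "cpu"})
out = ep.module()(torch.rand(1, 10, 4))  # runs on CPU
```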
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133660
Approved by: https://github.com/angelayi
diff --git a/test/export/test_passes.py b/test/export/test_passes.py
index e30f634..7a57881 100644
--- a/test/export/test_passes.py
+++ b/test/export/test_passes.py
@@ -18,6 +18,7 @@
_gather_constant_attrs,
)
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
+from torch._export.passes.move_to_device_pass import move_to_device_pass
from torch._export.passes.replace_set_grad_with_hop_pass import (
_is_set_grad_enabled_node,
_is_set_grad_enabled_sub_mod,
@@ -45,6 +46,7 @@
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupport
from torch.library import _scoped_library, impl
+from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
@@ -1179,6 +1181,42 @@
self.assertEqual(out_specs[1].arg.name, "getitem") # tuple return 1
self.assertEqual(out_specs[2].arg.name, "getitem_1") # tuple return 2
+ @unittest.skipIf(not TEST_CUDA, "requires cuda")
+ def test_move_to_device_pass(self):
+ class Model(torch.nn.Module):
+ def __init__(self, size=4, h_dim=10):
+ super().__init__()
+ self.rnn = torch.nn.GRU(size, h_dim, batch_first=True)
+
+ def forward(self, x):
+ _, states = self.rnn(x)
+ return states
+
+ # move the exported program from cpu to cuda:0
+ mod = Model()
+ example_inputs = (torch.rand(1, 10, 4),)
+ ep = export(mod, example_inputs)
+ location = torch.device("cuda:0")
+ ep = move_to_device_pass(ep, location=location)
+ gm = ep.module()
+ test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),)
+ outputs = gm(*test_inputs)
+ self.assertEqual(outputs.device, torch.device("cuda:0"))
+ # move it back to cpu
+ location = "cpu"
+ ep = move_to_device_pass(ep, location=location)
+ gm = ep.module()
+ test_inputs = (torch.rand(1, 10, 4).to("cpu"),)
+ outputs = gm(*test_inputs)
+ self.assertEqual(outputs.device, torch.device("cpu"))
+ # move it to cuda:0 again
+ location = {"cpu": "cuda:0"}
+ ep = move_to_device_pass(ep, location=location)
+ gm = ep.module()
+ test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),)
+ outputs = gm(*test_inputs)
+ self.assertEqual(outputs.device, torch.device("cuda:0"))
+
if __name__ == "__main__":
run_tests()
diff --git a/torch/_export/passes/move_to_device_pass.py b/torch/_export/passes/move_to_device_pass.py
new file mode 100644
index 0000000..d76f877
--- /dev/null
+++ b/torch/_export/passes/move_to_device_pass.py
@@ -0,0 +1,66 @@
+from typing import Dict, Union
+
+import torch
+import torch.utils._pytree as pytree
+from torch.export import ExportedProgram
+
+
+def _get_new_device(
+ curr_device: torch.device,
+ location: Union[torch.device, str, Dict[str, str]],
+) -> str:
+ if isinstance(location, dict):
+        if str(curr_device) in location:
+ return location[str(curr_device)]
+ else:
+ return str(curr_device)
+ else:
+ return str(location)
+
+
+def move_to_device_pass(
+ ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]]
+) -> ExportedProgram:
+ """
+ Move the exported program to the given device.
+
+ Args:
+ ep (ExportedProgram): The exported program to move.
+ location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to.
+ If a string, it is interpreted as a device name.
+            If a dict, it is interpreted as a mapping from
+            an existing device name to the intended one.
+
+ Returns:
+ ExportedProgram: The moved exported program.
+ """
+    # move all state_dict entries
+ for k, v in ep.state_dict.items():
+ if isinstance(v, torch.nn.Parameter):
+ ep._state_dict[k] = torch.nn.Parameter(
+ v.to(_get_new_device(v.device, location))
+ )
+ else:
+ ep._state_dict[k] = v.to(_get_new_device(v.device, location))
+
+ # move all the constants
+ for k, v in ep.constants.items():
+ if isinstance(v, torch.Tensor):
+ ep._constants[k] = v.to(_get_new_device(v.device, location))
+
+ for node in ep.graph.nodes:
+        # remap any node kwargs with a burned-in device
+ if "device" in node.kwargs:
+ kwargs = node.kwargs.copy()
+ kwargs["device"] = _get_new_device(kwargs["device"], location)
+ node.kwargs = kwargs
+ # move all the tensor metadata
+ node.meta["val"] = pytree.tree_map(
+ lambda v: v.to(_get_new_device(v.device, location))
+ if isinstance(v, torch.Tensor)
+ else v,
+ node.meta.get("val"),
+ )
+
+ ep.validate()
+ return ep