Add to_empty() suggestion in the error message (#119353)
Fixes #119293, the comprehensive documentation is [here](https://github.com/pytorch/pytorch/blob/0f478d9d610bec80b8a7517103d9390cc26d1d05/docs/source/meta.rst#id11).
Just added the suggestion into the error message so it is more informative to user.
@albanD
Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119353
Approved by: https://github.com/mikaylagawarecki
diff --git a/test/test_nn.py b/test/test_nn.py
index 3327e35..23376f5 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -12097,6 +12097,17 @@
m = MyModule(10, 1, device='meta', dtype=dtype)
m(input)
+ # Test empty meta module error with torch.nn.Module.to().
+ with self.assertRaisesRegex(
+ NotImplementedError,
+ re.escape(
+ "Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() "
+ "instead of torch.nn.Module.to() when moving module from meta to a different "
+ "device."
+ ),
+ ):
+ m.to(device)
+
# Test materializing meta module on a real device.
m.to_empty(device=device)
m(input)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 3cec9f6..e7b838e 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -1169,10 +1169,27 @@
"if a complex module does not work as expected.")
def convert(t):
- if convert_to_format is not None and t.dim() in (4, 5):
- return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
- non_blocking, memory_format=convert_to_format)
- return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
+ try:
+ if convert_to_format is not None and t.dim() in (4, 5):
+ return t.to(
+ device,
+ dtype if t.is_floating_point() or t.is_complex() else None,
+ non_blocking,
+ memory_format=convert_to_format,
+ )
+ return t.to(
+ device,
+ dtype if t.is_floating_point() or t.is_complex() else None,
+ non_blocking,
+ )
+ except NotImplementedError as e:
+ if str(e) == "Cannot copy out of meta tensor; no data!":
+ raise NotImplementedError(
+ f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
+ f"when moving module from meta to a different device."
+ ) from None
+ else:
+ raise
return self._apply(convert)