Do not wrap output with input device inside _to_copy (#119868)
Fixing https://github.com/pytorch/pytorch/issues/118790
This diff reverts a small part of the code that was introduced in https://github.com/pytorch/pytorch/pull/104689
The PR above added a comment that "In case of dtype promotion, fake tensor converted into tensor"
but it's not always the case that a dtype conversion causes a fake tensor to become a plain tensor.
When such a conversion does not happen, we get the following error:
```
Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to
a python object of type FakeTensor
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119868
Approved by: https://github.com/ezyang, https://github.com/thiagocrepaldi
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 968a626..6202605 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -2092,6 +2092,14 @@
func()
+ def test_to(self):
+ @torch.compile(backend="eager")
+ def fn():
+ t = torch.ones(2)
+ y = t.to("meta")
+
+ fn()
+
def test_elipsis(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(a, ind, val):
diff --git a/test/dynamo_expected_failures/FakeTensorConverterTest.test_memoized_conversion_from_meta b/test/dynamo_expected_failures/FakeTensorConverterTest.test_memoized_conversion_from_meta
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/FakeTensorConverterTest.test_memoized_conversion_from_meta
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestPythonDispatch.test_subclass_autograd_device_check b/test/dynamo_expected_failures/TestPythonDispatch.test_subclass_autograd_device_check
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestPythonDispatch.test_subclass_autograd_device_check
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_deepcopy b/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_deepcopy
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_deepcopy
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_getstate_overwrite b/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_getstate_overwrite
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_getstate_overwrite
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_wrapper_serialization b/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_wrapper_serialization
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSubclassSerialization.test_tensor_subclass_wrapper_serialization
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestTorch.test_upsample_nearest1d_meta b/test/dynamo_expected_failures/TestTorch.test_upsample_nearest1d_meta
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestTorch.test_upsample_nearest1d_meta
+++ /dev/null
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 3a94322..8c89897 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1860,19 +1860,6 @@
return None
-def wrap_output_with_input_device_(x, common_device):
- # wrap meta tensor
- if common_device is not None and x.device.type == "meta":
- from torch._subclasses.fake_tensor import FakeTensorMode
-
- fake_mode = FakeTensorMode()
- fake_mode.in_kernel_invocation = True
- converter = fake_mode.fake_tensor_converter
- return converter.from_meta_and_device(fake_mode, x, common_device)
-
- return x
-
-
@register_decomposition(aten._to_copy)
@out_wrapper()
def _to_copy(
@@ -1891,19 +1878,18 @@
return x.clone()
dtype_converted = False
common_device = device_hint(x)
+
if device is not None and device != x.device:
# avoid conversions on cpu
if dtype is not None and device.type == "cpu":
x = torch._prims.convert_element_type(x, dtype)
dtype_converted = True
x = torch._prims.device_put(x, device)
+
if dtype is not None and not dtype_converted:
x = torch._prims.convert_element_type(x, dtype)
dtype_converted = True
- # In case of dtype promotion, faketensor converted into tensor.
- # Need to convert into faketensor if input was a faketensor.
- if dtype_converted:
- x = wrap_output_with_input_device_(x, common_device)
+
if memory_format is not None: # no ref/prim for memory format
return torch.clone(x, memory_format=memory_format)
return x