Revert "Default XLA to use swap_tensors path in nn.Module._apply (#126814)" (#128170)
Reverting because of https://github.com/pytorch/pytorch/issues/128165 :(
This reverts commit a7b1dd82ff3063894fc665ab0c424815231c10e6.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128170
Approved by: https://github.com/drisspg, https://github.com/albanD
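Context for downstream users: with this revert, `nn.Module._apply` only takes the swap_tensors path when the global `torch.__future__` flag is enabled or when the applied parameter is a traceable wrapper subclass; the automatic XLA special case is gone. A minimal sketch of the explicit opt-in (using a plain `nn.Linear` purely for illustration):

    import torch
    import torch.nn as nn

    # Opt in globally instead of relying on the (now reverted) XLA special case.
    torch.__future__.set_swap_module_params_on_conversion(True)

    m = nn.Linear(4, 4)
    # Parameters are now converted via torch.utils.swap_tensors instead of
    # being reassigned through `param.data`.
    m.to(torch.bfloat16)

This mirrors the remaining triggers in the module.py hunk below: `should_use_swap_tensors` (the flag above) or `is_traceable_wrapper_subclass(param_applied)`.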
diff --git a/test/test_nn.py b/test/test_nn.py
index 6bcb401..6dfac4f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8184,9 +8184,9 @@
     @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128)
     def test_conv_empty_input(self, device, dtype):
         def help(input, conv, memory_format):
-            ref_out = conv(input).detach()
+            ref_out = conv(input)
             conv_cl = conv.to(memory_format=memory_format)
-            out_cl = conv_cl(input).detach()
+            out_cl = conv_cl(input)
             self.assertEqual(ref_out, out_cl)
             input_cl = input.to(memory_format=memory_format)
             out_cl2 = conv(input_cl)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 3d683cb..ffd429c 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -794,13 +794,6 @@
 
         should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
 
-        def compute_should_use_swap_tensors(tensor, tensor_applied):
-            return (should_use_swap_tensors
-                    # subclasses may have multiple child tensors so we need to use swap_tensors
-                    or is_traceable_wrapper_subclass(tensor_applied)
-                    or tensor.device.type == 'xla'
-                    or tensor_applied.device.type == 'xla')
-
         for key, param in self._parameters.items():
             if param is None:
                 continue
@@ -811,7 +804,8 @@
                 param_applied = fn(param)
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)
 
-            p_should_use_swap_tensors = compute_should_use_swap_tensors(param, param_applied)
+            # subclasses may have multiple child tensors so we need to use swap_tensors
+            p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
 
             param_grad = param.grad
             if p_should_use_swap_tensors: