Revert "Default XLA to use swap_tensors path in nn.Module._apply (#126814)" (#128170)

Reverting due to https://github.com/pytorch/pytorch/issues/128165 :(

This reverts commit a7b1dd82ff3063894fc665ab0c424815231c10e6.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128170
Approved by: https://github.com/drisspg, https://github.com/albanD
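
A minimal sketch (not part of this PR) of the behavior after the revert: parameters only take the swap_tensors path in `nn.Module._apply` when the global `torch.__future__` flag is set, or when the applied tensor is a traceable wrapper subclass, as the diff below shows; XLA tensors no longer opt in automatically.

```python
# Sketch only: opting in to swap_tensors for module conversions via the
# global flag that _apply reads (torch.__future__.get_swap_module_params_on_conversion).
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)
print(torch.__future__.get_swap_module_params_on_conversion())  # True

m = nn.Linear(4, 4)
m.double()  # conversion goes through nn.Module._apply and swaps parameter tensors
```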
diff --git a/test/test_nn.py b/test/test_nn.py
index 6bcb401..6dfac4f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8184,9 +8184,9 @@
     @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128)
     def test_conv_empty_input(self, device, dtype):
         def help(input, conv, memory_format):
-            ref_out = conv(input).detach()
+            ref_out = conv(input)
             conv_cl = conv.to(memory_format=memory_format)
-            out_cl = conv_cl(input).detach()
+            out_cl = conv_cl(input)
             self.assertEqual(ref_out, out_cl)
             input_cl = input.to(memory_format=memory_format)
             out_cl2 = conv(input_cl)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 3d683cb..ffd429c 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -794,13 +794,6 @@
 
         should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
 
-        def compute_should_use_swap_tensors(tensor, tensor_applied):
-            return (should_use_swap_tensors
-                    # subclasses may have multiple child tensors so we need to use swap_tensors
-                    or is_traceable_wrapper_subclass(tensor_applied)
-                    or tensor.device.type == 'xla'
-                    or tensor_applied.device.type == 'xla')
-
         for key, param in self._parameters.items():
             if param is None:
                 continue
@@ -811,7 +804,8 @@
                 param_applied = fn(param)
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)
 
-            p_should_use_swap_tensors = compute_should_use_swap_tensors(param, param_applied)
+            # subclasses may have multiple child tensors so we need to use swap_tensors
+            p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
 
             param_grad = param.grad
             if p_should_use_swap_tensors: