Add `torch.__future__._overwrite_module_params_on_conversion` global flag, and check it in `nn.Module._apply()` (#21613)

Summary:
https://github.com/pytorch/pytorch/pull/17072 breaks `model.to(xla_device)`: moving `model` to an XLA device changes its parameters' TensorImpl type, and the current implementation of `nn.Module._apply()` (which backs `nn.Module.to()`) doesn't support changing a parameter's TensorImpl type:
```python
# https://github.com/pytorch/pytorch/blob/6dc445e1a84dc5d093d640de54f038f021d13227/torch/nn/modules/module.py#L192-L208
def _apply(self, fn):
    ...
    for param in self._parameters.values():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)  # NOTE: this doesn't allow changing `param.data`'s TensorImpl type
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)  # NOTE: this doesn't allow changing `param._grad.data`'s TensorImpl type
    ...
```
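
For illustration, a minimal sketch of the failure mode (a sparse tensor stands in for an XLA tensor here, since both are backed by a different TensorImpl subclass than a dense CPU tensor; the exact error message may differ):
```python
import torch
import torch.nn as nn

m = nn.Linear(20, 10)
sparse_weight = torch.sparse_coo_tensor(
    torch.zeros([2, 1], dtype=torch.long), torch.ones([1]), torch.Size([10, 20]))

# The old `_apply()` path does `param.data = fn(param.data)`, which goes through
# `Variable.set_data()` and, as of #17072, rejects a TensorImpl type change:
m.weight.data = sparse_weight
# RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable`
# and `tensor` have different types of TensorImpl.
```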

To fix this problem, we add a `torch.__future__._overwrite_module_params_on_conversion` global flag (with `set_overwrite_module_params_on_conversion()` / `get_overwrite_module_params_on_conversion()` accessors) and check it in `nn.Module._apply()`. If the conversion function changes a parameter's TensorImpl type, `_apply()` now always assigns a brand-new `Parameter` (and a new gradient tensor) to the module, which supports converting the parameters to any TensorImpl type. Otherwise the flag decides: with the default `False`, `_apply()` keeps the current in-place `param.data = ...` behavior; with `True`, it overwrites the module's parameters and gradients with new tensors as well.

Overwriting parameters on conversion instead of mutating them in-place is a BC-breaking change that we want to become the default in a future release; the `torch.__future__` flag lets users opt into that behavior today. In-place conversion functions still bump the original parameters' (and gradients') version counters, so stale references are correctly invalidated in any autograd graph they participate in.
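
A quick usage sketch of the flag, mirroring the new `test_overwrite_module_params_on_conversion` test below (the `assert`s describe the intended behavior rather than being part of the API):
```python
import torch
import torch.nn as nn

# Default (False): conversions keep mutating the existing parameter tensors
# in-place via `.data =`, so old references still see the converted values.
torch.__future__.set_overwrite_module_params_on_conversion(False)
m = nn.Linear(20, 10).float()
w = m.weight
m.double()
assert w is m.weight  # same Parameter object, now double

# Opt in to the future behavior: conversions assign brand-new Parameters,
# so previously held references are left untouched (and stay float here).
torch.__future__.set_overwrite_module_params_on_conversion(True)
m = nn.Linear(20, 10).float()
w = m.weight
m.double()
assert w is not m.weight and w.dtype == torch.float32

torch.__future__.set_overwrite_module_params_on_conversion(False)
```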

This unblocks adding XLA to our CI test suite, which also allows XLA to catch up with other changes in our codebase, notably the c10 dispatcher.

[xla ci]

cc. resistor ailzhang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21613

Differential Revision: D15895387

Pulled By: yf225

fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6
diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp
index 1a3f7bd..f468ff1 100644
--- a/aten/src/ATen/native/TypeProperties.cpp
+++ b/aten/src/ATen/native/TypeProperties.cpp
@@ -38,6 +38,11 @@
   return self.is_quantized();
 }
 
+// True if `self` has the same derived type of TensorImpl as `other`.
+bool _has_same_tensorimpl_type(const Tensor& self, const Tensor& other) {
+  return typeid(*(self.unsafeGetTensorImpl())) == typeid(*(other.unsafeGetTensorImpl()));
+}
+
 Tensor type_as(const Tensor& self, const Tensor& other) {
   return self.toType(other.type());
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 74fd2e9..3293c65 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2036,6 +2036,9 @@
 - func: type_as(Tensor self, Tensor other) -> Tensor
   variants: method
 
+- func: _has_same_tensorimpl_type(Tensor self, Tensor other) -> bool
+  variants: function
+
 - func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
   variants: function
   dispatch:
diff --git a/test/test_nn.py b/test/test_nn.py
index f26d97d..5e91c85 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1602,6 +1602,144 @@
         with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
             pgm.backward(torch.randn(10, 20))
 
+    def test_overwrite_module_params_on_conversion(self):
+        torch.__future__.set_overwrite_module_params_on_conversion(False)
+
+        # Test that if the conversion function passed to `module._apply()`
+        # changes the TensorImpl type of `module`'s parameters, the `module`'s
+        # parameters are always overwritten, regardless of the value of
+        # `torch.__future__.get_overwrite_module_params_on_conversion()`.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20)
+        weight_ref = m.weight
+        weight_grad_ref = m.weight.grad
+        m = m._apply(lambda t: torch.sparse_coo_tensor(torch.zeros([2, 1]), torch.ones([1]), torch.Size([10, 20])))
+        self.assertNotEqual(weight_ref.layout, m.weight.layout)
+        self.assertNotEqual(weight_grad_ref.layout, m.weight.grad.layout)
+
+        # Test that under the current default settings
+        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
+        # a view to a module's parameters is not pointing to the same storage as
+        # its base variable after converting the module to a different dtype.
+        m = nn.Linear(20, 10).float()
+        mw = m.weight[:]
+        m.double()
+        mw[0][0] = 5
+        with self.assertRaisesRegex(RuntimeError, "Expected object of scalar type Float but got scalar type Double"):
+            mw[0][0] == mw._base[0][0]
+
+        torch.__future__.set_overwrite_module_params_on_conversion(True)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # a view to a module's parameters is still pointing to the same storage as
+        # its base variable after converting the module to a different dtype.
+        m = nn.Linear(20, 10).float()
+        mw = m.weight[:]
+        m.double()
+        mw[0][0] = 5
+        self.assertTrue(mw[0][0] == mw._base[0][0])
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # `float_module.double()` doesn't preserve previous references to
+        # `float_module`'s parameters or gradients.
+        m = nn.Linear(20, 10).float()
+        m.weight.grad = torch.randn(10, 20).float()
+        weight_ref = m.weight
+        weight_grad_ref = m.weight.grad
+        m.double()
+        self.assertNotEqual(weight_ref.dtype, m.weight.dtype)
+        self.assertNotEqual(weight_grad_ref.dtype, m.weight.grad.dtype)
+
+        def add_one_inplace(t):
+            return t.add_(1.0)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an in-place operation to a module would bump the module's
+        # original parameters' version counter.
+        m = nn.Linear(20, 10)
+        pvm = m.weight.mul(m.weight)
+        weight_ref = m.weight
+        m_weight_version_saved = weight_ref._version
+        m = m._apply(add_one_inplace)
+        # Test that the in-place operation bumps the original parameter's version counter
+        self.assertGreater(weight_ref._version, m_weight_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pvm.backward(torch.randn(10, 20))
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an in-place operation to a module would bump the module's
+        # original parameters' gradients' version counter.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20).requires_grad_()
+        pgm = m.weight.grad.mul(m.weight.grad)
+        weight_grad_ref = m.weight.grad
+        m_weight_grad_version_saved = weight_grad_ref._version
+        m = m._apply(add_one_inplace)
+        self.assertGreater(weight_grad_ref._version, m_weight_grad_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pgm.backward(torch.randn(10, 20))
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an out-of-place operation to a module doesn't bump
+        # the module's original parameters' version counter.
+        m = nn.Linear(20, 10)
+        weight_ref = m.weight
+        m_weight_version_saved = weight_ref._version
+        m = m._apply(lambda t: torch.randn(t.shape))
+        self.assertEqual(weight_ref._version, m_weight_version_saved)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an out-of-place operation to a module doesn't bump
+        # the module's original parameters' gradients' version counter.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20).requires_grad_()
+        weight_grad_ref = m.weight.grad
+        m_weight_grad_version_saved = weight_grad_ref._version
+        m = m._apply(lambda t: torch.randn(t.shape))
+        self.assertEqual(weight_grad_ref._version, m_weight_grad_version_saved)
+
+        torch.__future__.set_overwrite_module_params_on_conversion(False)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_overwrite_module_params_on_conversion_cpu_cuda(self):
+        torch.__future__.set_overwrite_module_params_on_conversion(False)
+
+        # Test that under the current default settings
+        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
+        # a view to a module's parameters is not pointing to the same storage as
+        # its base variable after converting the module to a different device.
+        m = nn.Linear(20, 10)
+        mw = m.weight[:]
+        m.to('cuda')
+        with torch.no_grad():
+            # Without using `torch.no_grad()`, this will leak CUDA memory.
+            # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875)
+            mw[0][0] = 5
+        with self.assertRaisesRegex(RuntimeError, "Expected object of backend CPU but got backend CUDA"):
+            mw[0][0] == mw._base[0][0]
+
+        torch.__future__.set_overwrite_module_params_on_conversion(True)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # a view to a module's parameters is still pointing to the same storage as
+        # its base variable after converting the module to a different device.
+        m = nn.Linear(20, 10)
+        mw = m.weight[:]
+        m.to('cuda')
+        mw[0][0] = 5
+        self.assertTrue(mw[0][0] == mw._base[0][0])
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # `cpu_module.to("cuda")` doesn't preserve previous references to
+        # `cpu_module`'s parameters or gradients.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20)
+        weight_ref = m.weight
+        weight_grad_ref = m.weight.grad
+        m.to('cuda')
+        self.assertNotEqual(weight_ref.device, m.weight.device)
+        self.assertNotEqual(weight_grad_ref.device, m.weight.grad.device)
+
     def test_type(self):
         l = nn.Linear(10, 20)
         net = nn.Module()
diff --git a/torch/__future__.py b/torch/__future__.py
new file mode 100644
index 0000000..789ec65
--- /dev/null
+++ b/torch/__future__.py
@@ -0,0 +1,19 @@
+"""
+This global flag controls whether to assign new tensors to the parameters
+instead of changing the existing parameters in-place when converting an `nn.Module`
+using the following methods:
+1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
+2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
+3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
+4. `module._apply(fn)` (for generic functions applied to `module`)
+
+Default: False
+"""
+_overwrite_module_params_on_conversion = False
+
+def set_overwrite_module_params_on_conversion(value):
+    global _overwrite_module_params_on_conversion
+    _overwrite_module_params_on_conversion = value
+
+def get_overwrite_module_params_on_conversion():
+    return _overwrite_module_params_on_conversion
diff --git a/torch/__init__.py b/torch/__init__.py
index 663b7e2..543ecae 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -310,6 +310,7 @@
 import torch.backends.mkl
 import torch.backends.openmp
 import torch.__config__
+import torch.__future__
 
 _C._init_names(list(torch._storage_classes))
 
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 2624321..44f1f11 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -88,7 +88,7 @@
   // from `new_data` to `var`. It requires that `new_data` has the same derived
   // type of TensorImpl as `var`.
   TORCH_CHECK(
-    typeid(*(this->unsafeGetTensorImpl())) == typeid(*(new_data.unsafeGetTensorImpl())),
+    _has_same_tensorimpl_type(*this, new_data),
     "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have different types of TensorImpl.");
 
   // Resets gradient accumulator if metadata is out of date
diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h
index 052d4c2..bda04e95 100644
--- a/torch/csrc/autograd/variable.h
+++ b/torch/csrc/autograd/variable.h
@@ -248,10 +248,9 @@
       bool keep_graph,
       bool create_graph) const;
 
-  /// Sets the `Tensor` held by this `Variable` to the one supplied.
-  /// It is rarely necessary to call this; it's used, for example, when
-  /// a non-sparse gradient gets added to a sparse gradient, requiring
-  /// the type of the gradient `Variable` to become non-sparse.
+  /// Sets the tensor data held by this `Variable` to be the same as `new_data`.
+  /// It requires that `new_data` has the same derived type of TensorImpl as
+  /// this `Variable`, by checking `_has_same_tensorimpl_type(this, new_data)`.
   void set_data(const at::Tensor &new_data);
 
   /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index d94b938..ad30ecd 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -193,15 +193,44 @@
         for module in self.children():
             module._apply(fn)
 
-        for param in self._parameters.values():
+        def compute_should_use_set_data(tensor, tensor_applied):
+            if torch._has_same_tensorimpl_type(tensor, tensor_applied):
+                # If the new tensor has the same TensorImpl type as the existing tensor,
+                # the current behavior is to change the tensor in-place using `.data =`,
+                # and the future behavior is to overwrite the existing tensor. However,
+                # changing the current behavior is a BC-breaking change, and we want it
+                # to happen in future releases. So for now we introduce the
+                # `torch.__future__.get_overwrite_module_params_on_conversion()`
+                # global flag to let the user control whether they want the future
+                # behavior of overwriting the existing tensor or not.
+                return not torch.__future__.get_overwrite_module_params_on_conversion()
+            else:
+                return False
+
+        for key, param in self._parameters.items():
             if param is not None:
+                # Tensors stored in modules are graph leaves, and we don't want to
+                # track autograd history of `param_applied`, so we have to use
+                # `with torch.no_grad():`
                 with torch.no_grad():
                     param_applied = fn(param)
-                param.data = param_applied
-                if param._grad is not None:
+                should_use_set_data = compute_should_use_set_data(param, param_applied)
+                if should_use_set_data:
+                    param.data = param_applied
+                else:
+                    assert isinstance(param, Parameter)
+                    assert param.is_leaf
+                    self._parameters[key] = Parameter(param_applied, param.requires_grad)
+
+                if param.grad is not None:
                     with torch.no_grad():
-                        grad_applied = fn(param._grad)
-                    param._grad.data = grad_applied
+                        grad_applied = fn(param.grad)
+                    should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
+                    if should_use_set_data:
+                        param.grad.data = grad_applied
+                    else:
+                        assert param.grad.is_leaf
+                        self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
 
         for key, buf in self._buffers.items():
             if buf is not None: