Improve testing of inplace views (#59891)

Summary:
Partially addresses https://github.com/pytorch/pytorch/issues/49825 by improving the testing
 - Rename some of the old tests that had "inplace_view" in their names, but actually mean "inplace_[update_]on_view" so there is no confusion with the naming
 - Adds some tests in test_view_ops that verify basic behavior
 - Add tests that creation meta is properly handled for no-grad, multi-output, and custom function cases
 - Add test that verifies that in the cross dtype view case, the inplace views won't be accounted in the backward graph on rebase as mentioned in the issue.
 - Update inference mode tests to also check in-place

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59891

Reviewed By: albanD

Differential Revision: D29272546

Pulled By: soulitzer

fbshipit-source-id: b12acf5f0e3f788167ebe268423cdb58481b56f6
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 47fc030..b0b0fcc 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -3407,7 +3407,7 @@
         test_reduction(torch.cumprod, False)
         test_reduction(torch.logcumsumexp, False, takes_dtype=False)
 
-    def test_inplace_view_saved_output(self):
+    def test_inplace_on_view_saved_output(self):
         # Test an in-place operation on a view in which the in-place op saves
         # its output. Previously, this created a reference cycle.
         dealloc = [0]
@@ -3426,7 +3426,7 @@
         test()
         self.assertEqual(dealloc[0], 1)
 
-    def test_inplace_view_leaf_errors(self):
+    def test_inplace_on_view_leaf_errors(self):
         # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
         x = torch.zeros(1, requires_grad=True)
         y = x.view_as(x)
@@ -3436,7 +3436,7 @@
                                     "an in-place operation."):
             y.add_(1)
 
-    def test_inplace_view_backward(self):
+    def test_inplace_on_view_backward(self):
         # Issue #10532: Make sure that this does not raise RuntimeError.
         net = nn.Sequential(
             nn.InstanceNorm2d(2),
@@ -3465,7 +3465,7 @@
         fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
         self.assertEqual(fn.name(), "ThresholdBackwardBackward")
 
-    def test_inplace_view_weak_grad_fn(self):
+    def test_inplace_on_view_weak_grad_fn(self):
         # Issue 23502: Test that b's grad_fn is preserved.
         a = torch.arange(10.0, requires_grad=True)
 
@@ -4919,7 +4919,7 @@
                 res.select(0, 0).copy_(grad)
                 return res, None
 
-        fn_id_to_inplace_view_err_msg = {
+        fn_id_to_inplace_on_view_err_msg = {
             "one_output": ("Output 0 of IdOneOutputBackward is a view and is being "
                            "modified inplace. This view was created inside a custom Function"),
             "two_output": ("Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
@@ -4962,7 +4962,7 @@
                     a = torch.ones(2, dtype=dtype, requires_grad=True)
                     b = torch.ones(2, dtype=dtype, requires_grad=True)
 
-                    err_msg = fn_id_to_inplace_view_err_msg[fn_id]
+                    err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
 
                     if not inplace or not output_is_a_view:
                         gradcheck(fn, (a, b), check_batched_grad=False)
@@ -4990,7 +4990,119 @@
         self._do_test_autograd_simple_views_python(torch.double)
         self._do_test_autograd_simple_views_python(torch.cdouble)
 
-    def test_autograd_complex_views_python(self):
+    def test_autograd_inplace_views_creation_meta(self):
+        # Tests creation_meta properly handled for inplace views
+
+        class Func(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x.view_as(x)
+
+            @staticmethod
+            def backward(ctx, x):
+                return x
+        view_custom = Func.apply
+
+        def run_test(fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2):
+            # This test checks the behavior of inplace-view functions when
+            # the views are created in grad mode or not
+            base = torch.rand(2, 3, requires_grad=requires_grad).clone()
+            # 1. Create a view with `grad_mode=grad_mode_view`
+            with torch.set_grad_enabled(grad_mode_view):
+                if fn_type == "multi_view":
+                    inp = base.unbind()[0]
+                elif fn_type == "custom" :
+                    inp = view_custom(base)
+                else:
+                    inp = base.view_as(base)
+
+            # 2. Perform inplace view with `grad_mode=grad_mode_iview`
+            with torch.set_grad_enabled(grad_mode_iview):
+                if error1 is not None:
+                    with self.assertRaisesRegex(RuntimeError, error1):
+                        fn(inp)
+                    return
+                else:
+                    # If error is None, check that runs without error
+                    fn(inp)
+            # 3. Do inplace on the (new) view
+            if error2 is not None:
+                with self.assertRaisesRegex(RuntimeError, error2):
+                    inp.add_(1)
+            else:
+                # If error is None, check that runs without error
+                inp.add_(1)
+
+        no_grad_err = "A view was created in no_grad mode"
+        multi_view_err = "function that returns multiple views"
+        custom_err = "view was created inside a custom Function"
+
+        def run_tests(fn):
+            for fn_type in ("normal", "multi_view", "custom"):
+                for grad_mode_view in (True, False):
+                    for grad_mode_iview in (True, False):
+                        for requires_grad in (True, False):
+                            error1 = None  # expected error when we do inplace_view on original view
+                            error2 = None  # expected error when we do inplace on the resulting view
+
+                            if requires_grad:
+                                if not grad_mode_view and grad_mode_iview:
+                                    error1 = no_grad_err
+                                if not grad_mode_view and not grad_mode_iview:
+                                    error2 = no_grad_err
+
+                                if fn_type == "multi_view":
+                                    if grad_mode_view and grad_mode_iview:
+                                        error1 = multi_view_err
+                                    if grad_mode_view and not grad_mode_iview:
+                                        error2 = multi_view_err
+
+                                if fn_type == "custom":
+                                    if grad_mode_view and grad_mode_iview:
+                                        error1 = custom_err
+                                    if grad_mode_view and not grad_mode_iview:
+                                        error2 = custom_err
+
+                            run_test(fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2)
+
+        # This list was created by logging gen_inplace_or_view_type.py
+        #   detach_ is excluded for this test because it cannot be applied to
+        #   views and thus does not return a view
+        run_tests(lambda v: v.as_strided_((1, 0), (2, 2)))
+        run_tests(lambda v: v.transpose_(0, 0))
+        run_tests(lambda v: v.t_())
+        run_tests(lambda v: v.squeeze_(0))
+        run_tests(lambda v: v.unsqueeze_(0))
+        run_tests(lambda v: v.swapdims_(0, 0))
+        run_tests(lambda v: v.swapaxes_(0, 0))
+
+    # TODO This is not the correct behavior -
+    # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627
+    def test_autograd_inplace_views_cross_dtype(self):
+        # This test is here to make sure that any change to this behavior is detected
+        # and not silent. The TODOs below mark the places with unexpected behavior.
+        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
+        a = a_orig.clone()
+        b = torch.view_as_real(a)
+        b = b.transpose(0, 1)
+        b += 1
+        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
+        non_inplace_grad = a_orig.grad
+
+        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
+        a = a_orig.clone()
+        b = torch.view_as_real(a)
+        b.transpose_(0, 1)
+        b += 1
+        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
+        inplace_grad = a_orig.grad
+
+        # TODO: this is a bug!
+        # once this is fixed, it should have the transpose removed:
+        # self.assertTrue(torch.allclose(non_inplace_grad, inplace_grad))
+        self.assertEqual(non_inplace_grad.T, inplace_grad)
+
+    def test_autograd_multiple_views_python(self):
         # This is not necessarily the absolute correct behavior, but this is the current
         # one. This test is here to make sure that any change to this behavior is detected
         # and not silent. The TODOs below mark the places with unexpected behavior.
@@ -5032,7 +5144,7 @@
                                     "Output 0 of ComplexViewBackward is a view and is being modified inplace"):
             out += 1
 
-    def test_autograd_inplace_views_python(self):
+    def test_autograd_python_custom_function_inplace(self):
         # This is not necessarily the absolute correct behavior, but this is the current
         # one. This test is here to make sure that any change to this behavior is detected
         # and not silent. The TODOs below mark the places with unexpected behavior.
@@ -8078,7 +8190,7 @@
         # gpu thread ReadyQueue
         out.sum().backward()
 
-    def test_inplace_view_backprop_base(self, device):
+    def test_inplace_on_view_backprop_base(self, device):
         # modify view and back-prop through base
         root = torch.randn(2, 2, device=device, requires_grad=True)
         x = root.clone()
@@ -8087,7 +8199,7 @@
         x.sum().backward()
         self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]])
 
-    def test_inplace_view_backprop_view_of_view(self, device):
+    def test_inplace_on_view_backprop_view_of_view(self, device):
         # modify view and backprop through view-of-view
         root = torch.randn(2, 2, device=device, requires_grad=True)
         x = root.clone()
@@ -8097,7 +8209,7 @@
         v2.sum().backward()
         self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]])
 
-    def test_inplace_view_of_view(self, device):
+    def test_inplace_on_view_of_view(self, device):
         # modify view-of-view and backprop through base
         root = torch.randn(2, 2, device=device, requires_grad=True)
         x = root.clone()
@@ -8107,7 +8219,7 @@
         x.sum().backward()
         self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]])
 
-    def test_inplace_view_then_no_grad(self, device):
+    def test_inplace_on_view_then_no_grad(self, device):
         # Perform an in-place operation on a view of a non-leaf variable.
         a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True)
         b = a * 2
@@ -8120,7 +8232,7 @@
 
         c.sum().backward()
 
-    def test_inplace_view_gradcheck(self, device):
+    def test_inplace_on_view_gradcheck(self, device):
         # gradcheck modifications to views
         a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
         b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
@@ -8135,14 +8247,14 @@
         go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
         gradgradcheck(func, (a, b), (go,))
 
-    def test_inplace_view_multiple_outputs(self, device):
+    def test_inplace_on_view_multiple_outputs(self, device):
         root = torch.arange(9., dtype=torch.double).reshape(3, 3).requires_grad_()
         x = root.clone()
         v1 = x.unbind()
         with self.assertRaises(RuntimeError):
             v1[0].mul_(2)
 
-    def test_inplace_view_of_multiple_output_view(self, device):
+    def test_inplace_on_view_of_multiple_output_view(self, device):
         a = torch.rand(10, dtype=torch.double, device=device, requires_grad=True).clone()
         b = a.unbind(0)
         c = b[0].view_as(b[0])
@@ -8156,7 +8268,7 @@
         with self.assertRaises(RuntimeError):
             c[0].mul_(2)
 
-    def test_inplace_view_makes_base_require_grad(self, device):
+    def test_inplace_on_view_makes_base_require_grad(self, device):
         # in-place modification to view makes base require grad
         a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False)
         b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True)
@@ -8172,7 +8284,7 @@
         go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
         gradgradcheck(func, (a, b), (go,))
 
-    def test_inplace_view_backprop_view(self, device):
+    def test_inplace_on_view_backprop_view(self, device):
         # modify view and backprop through view
         a = torch.tensor([2., 5.], device=device, requires_grad=False)
         b = torch.tensor([3.], device=device, requires_grad=True)
@@ -8181,7 +8293,7 @@
         self.assertEqual(b.grad.tolist(), [5])
         self.assertIsNone(a.grad)
 
-    def test_inplace_view_modify_base(self, device):
+    def test_inplace_on_view_modify_base(self, device):
         # Test that an in-place operation on a base that forced it to require
         # grad also forces any previous views to require grad and backprop
         # correctly
@@ -8199,7 +8311,7 @@
         gradcheck(fn, [r])
         gradgradcheck(fn, [r])
 
-    def test_inplace_view_python(self, device):
+    def test_inplace_on_view_python(self, device):
         # in-place modifications of Python-autograd created view
         a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
         b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
@@ -8225,7 +8337,7 @@
         go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
         gradgradcheck(func, (a, b), (go,))
 
-    def test_inplace_view_non_contig(self, device):
+    def test_inplace_on_view_non_contig(self, device):
         root = torch.ones(2, 3, 2, device=device).select(2, 1).t().requires_grad_(True)
         x = root.clone()
         v1 = x.narrow(0, 0, 1)
@@ -8234,7 +8346,7 @@
         x.sum().backward()
         self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]])
 
-    def test_inplace_view_multi_output_unsafe(self, device):
+    def test_inplace_on_view_multi_output_unsafe(self, device):
         for f in [lambda t: t.unsafe_split(1),
                   lambda t: t.unsafe_split_with_sizes((1, 1, 1)),
                   lambda t: t.unsafe_chunk(3)]:
@@ -8244,7 +8356,7 @@
             s1.mul_(s2)
             s1.sum().backward()
 
-    def test_inplace_view_multi_output_safe(self, device):
+    def test_inplace_on_view_multi_output_safe(self, device):
         for f in [lambda t: t.split(1),
                   lambda t: t.split_with_sizes((1, 1, 1)),
                   lambda t: t.chunk(3)]:
@@ -8482,15 +8594,18 @@
                 self.assertFalse(func_out.requires_grad)
 
     def test_inference_mode_inf_tensor_in_inf_mode_inplace_op(self):
-        with torch.inference_mode():
+        @torch.inference_mode()
+        def run_test(fn):
             for requires_grad in (True, False):
                 c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 
-                # after perform inplace operation, tensor is still
+                # after performing inplace operation, tensor is still
                 # an inference tensor
-                c.add_(2)
+                fn(c)
                 self.assertTrue(torch.is_inference(c))
                 self.assertEqual(c.requires_grad, requires_grad)
+        run_test(lambda x: x.add_(2))
+        run_test(lambda x: x.transpose_(0, 1))
 
     def test_inference_mode_inf_tensor_in_inf_mode_view_op(self):
         with torch.inference_mode():
@@ -8517,18 +8632,21 @@
         self.assertTrue(func_out.is_leaf)
 
     def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self):
-        for requires_grad in (False, True):
-            with torch.inference_mode():
-                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
+        def run_test(fn):
+            for requires_grad in (False, True):
+                with torch.inference_mode():
+                    c = torch.ones(1, 2, 3, requires_grad=requires_grad)
 
-            if requires_grad:
-                # leaf variable that requires grad is being used in an inplace
-                # operation when requires_grad=True
-                pass
-            else:
-                err_msg = "Inplace update to inference tensor outside InferenceMode"
-                with self.assertRaisesRegex(RuntimeError, err_msg):
-                    c.add_(2)
+                if requires_grad:
+                    # leaf variable that requires grad is being used in an inplace
+                    # operation when requires_grad=True
+                    pass
+                else:
+                    err_msg = "Inplace update to inference tensor outside InferenceMode"
+                    with self.assertRaisesRegex(RuntimeError, err_msg):
+                        fn(c)
+        run_test(lambda x: x.add_(2))
+        run_test(lambda x: x.transpose_(0, 1))
 
     def test_inference_mode_inf_tensor_in_normal_mode_view_op(self):
         for requires_grad in (True, False):
@@ -8542,17 +8660,45 @@
             self.assertTrue(out.is_leaf)
 
     def test_normal_tensor_inplace_output_in_inference_mode(self):
-        for requires_grad in (True, False):
-            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
-            a = s.clone()
+        def run_test(fn):
+            for requires_grad in (True, False):
+                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+                a = s.clone()
 
-            with torch.inference_mode():
-                a.add_(2)
+                with torch.inference_mode():
+                    fn(a)
+                    self.assertFalse(torch.is_inference(a))
+                    self.assertEqual(a.requires_grad, requires_grad)
+
+                    # inplace -> inplace
+                    fn(a)
+                    self.assertFalse(torch.is_inference(a))
+                    self.assertEqual(a.requires_grad, requires_grad)
+
+                    # inplace -> inplace -> view
+                    view_out = a.view(-1)
+                    self.assertFalse(torch.is_inference(view_out))
+                    self.assertEqual(view_out.requires_grad, requires_grad)
+        run_test(lambda x: x.add_(2))
+        run_test(lambda x: x.transpose_(0, 1))
+
+    def test_normal_tensor_inplace_output_in_normal_mode(self):
+        def run_test(fn):
+            for requires_grad in (True, False):
+                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+                a = s.clone()
+
+                with torch.inference_mode():
+                    fn(a)
+                    self.assertFalse(torch.is_inference(a))
+                    self.assertEqual(a.requires_grad, requires_grad)
+
+                fn(a)
                 self.assertFalse(torch.is_inference(a))
                 self.assertEqual(a.requires_grad, requires_grad)
 
                 # inplace -> inplace
-                a.add_(2)
+                fn(a)
                 self.assertFalse(torch.is_inference(a))
                 self.assertEqual(a.requires_grad, requires_grad)
 
@@ -8560,30 +8706,8 @@
                 view_out = a.view(-1)
                 self.assertFalse(torch.is_inference(view_out))
                 self.assertEqual(view_out.requires_grad, requires_grad)
-
-    def test_normal_tensor_inplace_output_in_normal_mode(self):
-        for requires_grad in (True, False):
-            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
-            a = s.clone()
-
-            with torch.inference_mode():
-                a.add_(2)
-                self.assertFalse(torch.is_inference(a))
-                self.assertEqual(a.requires_grad, requires_grad)
-
-            a.add_(2)
-            self.assertFalse(torch.is_inference(a))
-            self.assertEqual(a.requires_grad, requires_grad)
-
-            # inplace -> inplace
-            a.add_(2)
-            self.assertFalse(torch.is_inference(a))
-            self.assertEqual(a.requires_grad, requires_grad)
-
-            # inplace -> inplace -> view
-            view_out = a.view(-1)
-            self.assertFalse(torch.is_inference(view_out))
-            self.assertEqual(view_out.requires_grad, requires_grad)
+            run_test(lambda x: x.add_(2))
+            run_test(lambda x: x.transpose_(0, 1))
 
     def test_normal_tensor_view_output_in_inference_mode(self):
         for requires_grad in (True, False):
@@ -8718,37 +8842,43 @@
             self.assertEqual(tmp2.requires_grad, requires_grad)
 
     def test_inference_mode_handle_direct_view_on_rebase(self):
-        for requires_grad in (True, False):
-            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
-            a = s.clone()
+        def run_test(fn):
+            for requires_grad in (True, False):
+                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+                a = s.clone()
 
-            with torch.inference_mode():
-                view_out = a.view(-1)
+                with torch.inference_mode():
+                    view_out = a.view_as(a)
 
-            if requires_grad:
-                err_msg = "A view was created in inference mode and is being modified inplace"
-                with self.assertRaisesRegex(RuntimeError, err_msg):
-                    view_out.add_(2)
-                pass
-            else:
-                view_out.add_(2)
+                if requires_grad:
+                    err_msg = "A view was created in inference mode and is being modified inplace"
+                    with self.assertRaisesRegex(RuntimeError, err_msg):
+                        fn(view_out)
+                    pass
+                else:
+                    fn(view_out)
+        run_test(lambda x: x.add_(2))
+        run_test(lambda x: x.transpose_(0, 1))
 
     def test_inference_mode_handle_indirect_view_on_rebase(self):
-        for requires_grad in (True, False):
-            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
-            a = s.clone()
+        def run_test(fn):
+            for requires_grad in (True, False):
+                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+                a = s.clone()
 
-            with torch.inference_mode():
-                view_out = a.view(-1)
+                with torch.inference_mode():
+                    view_out = a.view(-1)
 
-            a.add_(2)
-            if requires_grad:
-                err_msg = "A view was created in inference mode and its base or another view "
-                with self.assertRaisesRegex(RuntimeError, err_msg):
+                fn(a)
+                if requires_grad:
+                    err_msg = "A view was created in inference mode and its base or another view "
+                    with self.assertRaisesRegex(RuntimeError, err_msg):
+                        view_out.grad_fn
+                    pass
+                else:
                     view_out.grad_fn
-                pass
-            else:
-                view_out.grad_fn
+        run_test(lambda x: x.add_(2))
+        run_test(lambda x: x.transpose_(0, 1))
 
 class TestMultithreadAutograd(TestCase):
     def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None):
diff --git a/test/test_nn.py b/test/test_nn.py
index db3e05b..a93aeee 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8215,7 +8215,7 @@
         test_pixel_shuffle_unshuffle_4D()
         test_pixel_shuffle_unshuffle_5D()
 
-    def test_elu_inplace_view(self):
+    def test_elu_inplace_on_view(self):
         v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
 
         def func(root):
@@ -8228,7 +8228,7 @@
         gradcheck(func, [v])
         gradgradcheck(func, [v])
 
-    def test_relu_inplace_view(self):
+    def test_relu_inplace_on_view(self):
         v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
 
         def func(root):
diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index 6b02c2c..48f6672 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -449,6 +449,28 @@
             v[0, 1] = 0
             self.assertEqual(t[1, 0], v[0, 1])
 
+    def test_transpose_inplace_view(self, device):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.swapdims_(0, 1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.swapaxes_(0, 1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.transpose_(0, 1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
     def test_t_view(self, device):
         t = torch.ones((5, 5), device=device)
         v = t.t()
@@ -457,6 +479,14 @@
         v[0, 1] = 0
         self.assertEqual(t[1, 0], v[0, 1])
 
+    def test_t_inplace_view(self, device):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.t_()
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
     def test_T_view(self, device):
         t = torch.ones((5, 5), device=device)
         v = t.T
@@ -480,6 +510,14 @@
         v[0, 1] = 0
         self.assertEqual(t, v._base)
 
+    def test_squeeze_inplace_view(self, device):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.squeeze_()
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t, v._base)
+
     def test_unsqueeze_view(self, device):
         t = torch.ones(5, 5, device=device)
         v = torch.unsqueeze(t, 1)
@@ -488,6 +526,14 @@
         v[0, 0, 1] = 0
         self.assertEqual(t[0, 1], v[0, 0, 1])
 
+    def test_unsqueeze_inplace_view(self, device):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.unsqueeze_(1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 0, 1] = 0
+        self.assertEqual(t[0, 1], v[0, 0, 1])
+
     def test_as_strided_view(self, device):
         t = torch.ones(5, 5, device=device)
         v = torch.as_strided(t, (25,), (1,))
@@ -496,6 +542,14 @@
         v[6] = 0
         self.assertEqual(t[1, 1], v[6])
 
+    def test_as_strided_inplace_view(self, device):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.as_strided_((25,), (1,))
+        self.assertTrue(self.is_view_of(t, v))
+        v[6] = 0
+        self.assertEqual(t[1, 1], v[6])
+
     def test_view_view(self, device):
         t = torch.ones(5, 5, device=device)
         v = t.view(25)
diff --git a/test/test_vmap.py b/test/test_vmap.py
index e9839f7..35b28db 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -2426,7 +2426,7 @@
 
 
     @allowVmapFallbackUsage
-    def test_inplace_view(self, device):
+    def test_inplace_on_view(self, device):
         leaf = torch.randn(4, 5, requires_grad=True)
 
         def func(leaf):