Improve testing of inplace views (#59891)
Summary:
Partially addresses https://github.com/pytorch/pytorch/issues/49825 by improving the testing of in-place view operations:
- Renames some of the old tests that had "inplace_view" in their names but actually exercise "inplace_[update_]on_view", so there is no confusion in the naming (a short sketch of the distinction follows this list)
- Adds some tests in test_view_ops that verify the basic view behavior of the in-place view ops
- Adds tests that creation_meta is properly handled for the no-grad, multi-output, and custom Function cases
- Adds a test verifying that, in the cross-dtype view case, the in-place views are not accounted for in the backward graph on rebase, as mentioned in the issue.
- Updates the inference mode tests to also check in-place view ops
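For context, here is a minimal sketch of the distinction the renames encode (the tensors, shapes, and ops below are illustrative only, not taken from this diff): an "inplace view" op is an in-place op that is itself a view op, i.e. it changes the tensor's geometry while it keeps aliasing the same storage, whereas "inplace [update] on [a] view" means an ordinary in-place update applied to a tensor that is already a view of some base.

    import torch

    # "inplace on view": an ordinary in-place update applied to a view;
    # autograd rebases the history of the base and its views to record it.
    base = torch.randn(4, 4, requires_grad=True).clone()
    view = base.narrow(0, 0, 2)   # `view` aliases `base`
    view.add_(1)                  # in-place update *on* the view

    # "inplace view" op: an in-place op that is itself a view op; `v` still
    # aliases `base`, only its size/stride metadata changes.
    v = base.view_as(base)
    v.transpose_(0, 1)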
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59891
Reviewed By: albanD
Differential Revision: D29272546
Pulled By: soulitzer
fbshipit-source-id: b12acf5f0e3f788167ebe268423cdb58481b56f6
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 47fc030..b0b0fcc 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -3407,7 +3407,7 @@
test_reduction(torch.cumprod, False)
test_reduction(torch.logcumsumexp, False, takes_dtype=False)
- def test_inplace_view_saved_output(self):
+ def test_inplace_on_view_saved_output(self):
# Test an in-place operation on a view in which the in-place op saves
# its output. Previously, this created a reference cycle.
dealloc = [0]
@@ -3426,7 +3426,7 @@
test()
self.assertEqual(dealloc[0], 1)
- def test_inplace_view_leaf_errors(self):
+ def test_inplace_on_view_leaf_errors(self):
# Issue #21875: Fail faster (when we try to modify the view vs. in backward())
x = torch.zeros(1, requires_grad=True)
y = x.view_as(x)
@@ -3436,7 +3436,7 @@
"an in-place operation."):
y.add_(1)
- def test_inplace_view_backward(self):
+ def test_inplace_on_view_backward(self):
# Issue #10532: Make sure that this does not raise RuntimeError.
net = nn.Sequential(
nn.InstanceNorm2d(2),
@@ -3465,7 +3465,7 @@
fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
self.assertEqual(fn.name(), "ThresholdBackwardBackward")
- def test_inplace_view_weak_grad_fn(self):
+ def test_inplace_on_view_weak_grad_fn(self):
# Issue 23502: Test that b's grad_fn is preserved.
a = torch.arange(10.0, requires_grad=True)
@@ -4919,7 +4919,7 @@
res.select(0, 0).copy_(grad)
return res, None
- fn_id_to_inplace_view_err_msg = {
+ fn_id_to_inplace_on_view_err_msg = {
"one_output": ("Output 0 of IdOneOutputBackward is a view and is being "
"modified inplace. This view was created inside a custom Function"),
"two_output": ("Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
@@ -4962,7 +4962,7 @@
a = torch.ones(2, dtype=dtype, requires_grad=True)
b = torch.ones(2, dtype=dtype, requires_grad=True)
- err_msg = fn_id_to_inplace_view_err_msg[fn_id]
+ err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
if not inplace or not output_is_a_view:
gradcheck(fn, (a, b), check_batched_grad=False)
@@ -4990,7 +4990,119 @@
self._do_test_autograd_simple_views_python(torch.double)
self._do_test_autograd_simple_views_python(torch.cdouble)
- def test_autograd_complex_views_python(self):
+ def test_autograd_inplace_views_creation_meta(self):
+        # Tests that creation_meta is properly handled for inplace views
+
+ class Func(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x):
+ return x.view_as(x)
+
+ @staticmethod
+ def backward(ctx, x):
+ return x
+ view_custom = Func.apply
+
+ def run_test(fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2):
+            # This test checks the behavior of inplace-view functions depending on
+            # whether the views are created in grad mode or not
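+            # clone so `base` is not a leaf; in-place ops on views of a leaf that requires grad fail with an unrelated error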
+ base = torch.rand(2, 3, requires_grad=requires_grad).clone()
+ # 1. Create a view with `grad_mode=grad_mode_view`
+ with torch.set_grad_enabled(grad_mode_view):
+ if fn_type == "multi_view":
+ inp = base.unbind()[0]
+                elif fn_type == "custom":
+ inp = view_custom(base)
+ else:
+ inp = base.view_as(base)
+
+ # 2. Perform inplace view with `grad_mode=grad_mode_iview`
+ with torch.set_grad_enabled(grad_mode_iview):
+ if error1 is not None:
+ with self.assertRaisesRegex(RuntimeError, error1):
+ fn(inp)
+ return
+ else:
+                    # If error is None, check that it runs without error
+ fn(inp)
+ # 3. Do inplace on the (new) view
+ if error2 is not None:
+ with self.assertRaisesRegex(RuntimeError, error2):
+ inp.add_(1)
+ else:
+                # If error is None, check that it runs without error
+ inp.add_(1)
+
+ no_grad_err = "A view was created in no_grad mode"
+ multi_view_err = "function that returns multiple views"
+ custom_err = "view was created inside a custom Function"
+
+ def run_tests(fn):
+ for fn_type in ("normal", "multi_view", "custom"):
+ for grad_mode_view in (True, False):
+ for grad_mode_iview in (True, False):
+ for requires_grad in (True, False):
+ error1 = None # expected error when we do inplace_view on original view
+ error2 = None # expected error when we do inplace on the resulting view
+
+ if requires_grad:
+ if not grad_mode_view and grad_mode_iview:
+ error1 = no_grad_err
+ if not grad_mode_view and not grad_mode_iview:
+ error2 = no_grad_err
+
+ if fn_type == "multi_view":
+ if grad_mode_view and grad_mode_iview:
+ error1 = multi_view_err
+ if grad_mode_view and not grad_mode_iview:
+ error2 = multi_view_err
+
+ if fn_type == "custom":
+ if grad_mode_view and grad_mode_iview:
+ error1 = custom_err
+ if grad_mode_view and not grad_mode_iview:
+ error2 = custom_err
+
+ run_test(fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2)
+
+ # This list was created by logging gen_inplace_or_view_type.py
+ # detach_ is excluded for this test because it cannot be applied to
+ # views and thus does not return a view
+ run_tests(lambda v: v.as_strided_((1, 0), (2, 2)))
+ run_tests(lambda v: v.transpose_(0, 0))
+ run_tests(lambda v: v.t_())
+ run_tests(lambda v: v.squeeze_(0))
+ run_tests(lambda v: v.unsqueeze_(0))
+ run_tests(lambda v: v.swapdims_(0, 0))
+ run_tests(lambda v: v.swapaxes_(0, 0))
+
+ # TODO This is not the correct behavior -
+ # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627
+ def test_autograd_inplace_views_cross_dtype(self):
+ # This test is here to make sure that any change to this behavior is detected
+ # and not silent. The TODOs below mark the places with unexpected behavior.
+ a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
+ a = a_orig.clone()
+ b = torch.view_as_real(a)
+ b = b.transpose(0, 1)
+ b += 1
+ b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
+ non_inplace_grad = a_orig.grad
+
+ a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
+ a = a_orig.clone()
+ b = torch.view_as_real(a)
+ b.transpose_(0, 1)
+ b += 1
+ b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
+ inplace_grad = a_orig.grad
+
+ # TODO: this is a bug!
+ # once this is fixed, it should have the transpose removed:
+ # self.assertTrue(torch.allclose(non_inplace_grad, inplace_grad))
+ self.assertEqual(non_inplace_grad.T, inplace_grad)
+
+ def test_autograd_multiple_views_python(self):
# This is not necessarily the absolute correct behavior, but this is the current
# one. This test is here to make sure that any change to this behavior is detected
# and not silent. The TODOs below mark the places with unexpected behavior.
@@ -5032,7 +5144,7 @@
"Output 0 of ComplexViewBackward is a view and is being modified inplace"):
out += 1
- def test_autograd_inplace_views_python(self):
+ def test_autograd_python_custom_function_inplace(self):
# This is not necessarily the absolute correct behavior, but this is the current
# one. This test is here to make sure that any change to this behavior is detected
# and not silent. The TODOs below mark the places with unexpected behavior.
@@ -8078,7 +8190,7 @@
# gpu thread ReadyQueue
out.sum().backward()
- def test_inplace_view_backprop_base(self, device):
+ def test_inplace_on_view_backprop_base(self, device):
# modify view and back-prop through base
root = torch.randn(2, 2, device=device, requires_grad=True)
x = root.clone()
@@ -8087,7 +8199,7 @@
x.sum().backward()
self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]])
- def test_inplace_view_backprop_view_of_view(self, device):
+ def test_inplace_on_view_backprop_view_of_view(self, device):
# modify view and backprop through view-of-view
root = torch.randn(2, 2, device=device, requires_grad=True)
x = root.clone()
@@ -8097,7 +8209,7 @@
v2.sum().backward()
self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]])
- def test_inplace_view_of_view(self, device):
+ def test_inplace_on_view_of_view(self, device):
# modify view-of-view and backprop through base
root = torch.randn(2, 2, device=device, requires_grad=True)
x = root.clone()
@@ -8107,7 +8219,7 @@
x.sum().backward()
self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]])
- def test_inplace_view_then_no_grad(self, device):
+ def test_inplace_on_view_then_no_grad(self, device):
# Perform an in-place operation on a view of a non-leaf variable.
a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True)
b = a * 2
@@ -8120,7 +8232,7 @@
c.sum().backward()
- def test_inplace_view_gradcheck(self, device):
+ def test_inplace_on_view_gradcheck(self, device):
# gradcheck modifications to views
a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
@@ -8135,14 +8247,14 @@
go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
gradgradcheck(func, (a, b), (go,))
- def test_inplace_view_multiple_outputs(self, device):
+ def test_inplace_on_view_multiple_outputs(self, device):
root = torch.arange(9., dtype=torch.double).reshape(3, 3).requires_grad_()
x = root.clone()
v1 = x.unbind()
with self.assertRaises(RuntimeError):
v1[0].mul_(2)
- def test_inplace_view_of_multiple_output_view(self, device):
+ def test_inplace_on_view_of_multiple_output_view(self, device):
a = torch.rand(10, dtype=torch.double, device=device, requires_grad=True).clone()
b = a.unbind(0)
c = b[0].view_as(b[0])
@@ -8156,7 +8268,7 @@
with self.assertRaises(RuntimeError):
c[0].mul_(2)
- def test_inplace_view_makes_base_require_grad(self, device):
+ def test_inplace_on_view_makes_base_require_grad(self, device):
# in-place modification to view makes base require grad
a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False)
b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True)
@@ -8172,7 +8284,7 @@
go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
gradgradcheck(func, (a, b), (go,))
- def test_inplace_view_backprop_view(self, device):
+ def test_inplace_on_view_backprop_view(self, device):
# modify view and backprop through view
a = torch.tensor([2., 5.], device=device, requires_grad=False)
b = torch.tensor([3.], device=device, requires_grad=True)
@@ -8181,7 +8293,7 @@
self.assertEqual(b.grad.tolist(), [5])
self.assertIsNone(a.grad)
- def test_inplace_view_modify_base(self, device):
+ def test_inplace_on_view_modify_base(self, device):
# Test that an in-place operation on a base that forced it to require
# grad also forces any previous views to require grad and backprop
# correctly
@@ -8199,7 +8311,7 @@
gradcheck(fn, [r])
gradgradcheck(fn, [r])
- def test_inplace_view_python(self, device):
+ def test_inplace_on_view_python(self, device):
# in-place modifications of Python-autograd created view
a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
@@ -8225,7 +8337,7 @@
go = torch.randn(a.size(), dtype=torch.double, device=device, requires_grad=True)
gradgradcheck(func, (a, b), (go,))
- def test_inplace_view_non_contig(self, device):
+ def test_inplace_on_view_non_contig(self, device):
root = torch.ones(2, 3, 2, device=device).select(2, 1).t().requires_grad_(True)
x = root.clone()
v1 = x.narrow(0, 0, 1)
@@ -8234,7 +8346,7 @@
x.sum().backward()
self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]])
- def test_inplace_view_multi_output_unsafe(self, device):
+ def test_inplace_on_view_multi_output_unsafe(self, device):
for f in [lambda t: t.unsafe_split(1),
lambda t: t.unsafe_split_with_sizes((1, 1, 1)),
lambda t: t.unsafe_chunk(3)]:
@@ -8244,7 +8356,7 @@
s1.mul_(s2)
s1.sum().backward()
- def test_inplace_view_multi_output_safe(self, device):
+ def test_inplace_on_view_multi_output_safe(self, device):
for f in [lambda t: t.split(1),
lambda t: t.split_with_sizes((1, 1, 1)),
lambda t: t.chunk(3)]:
@@ -8482,15 +8594,18 @@
self.assertFalse(func_out.requires_grad)
def test_inference_mode_inf_tensor_in_inf_mode_inplace_op(self):
- with torch.inference_mode():
+ @torch.inference_mode()
+ def run_test(fn):
for requires_grad in (True, False):
c = torch.ones(1, 2, 3, requires_grad=requires_grad)
- # after perform inplace operation, tensor is still
+ # after performing inplace operation, tensor is still
# an inference tensor
- c.add_(2)
+ fn(c)
self.assertTrue(torch.is_inference(c))
self.assertEqual(c.requires_grad, requires_grad)
+ run_test(lambda x: x.add_(2))
+ run_test(lambda x: x.transpose_(0, 1))
def test_inference_mode_inf_tensor_in_inf_mode_view_op(self):
with torch.inference_mode():
@@ -8517,18 +8632,21 @@
self.assertTrue(func_out.is_leaf)
def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self):
- for requires_grad in (False, True):
- with torch.inference_mode():
- c = torch.ones(1, 2, 3, requires_grad=requires_grad)
+ def run_test(fn):
+ for requires_grad in (False, True):
+ with torch.inference_mode():
+ c = torch.ones(1, 2, 3, requires_grad=requires_grad)
- if requires_grad:
- # leaf variable that requires grad is being used in an inplace
- # operation when requires_grad=True
- pass
- else:
- err_msg = "Inplace update to inference tensor outside InferenceMode"
- with self.assertRaisesRegex(RuntimeError, err_msg):
- c.add_(2)
+ if requires_grad:
+ # leaf variable that requires grad is being used in an inplace
+ # operation when requires_grad=True
+ pass
+ else:
+ err_msg = "Inplace update to inference tensor outside InferenceMode"
+ with self.assertRaisesRegex(RuntimeError, err_msg):
+ fn(c)
+ run_test(lambda x: x.add_(2))
+ run_test(lambda x: x.transpose_(0, 1))
def test_inference_mode_inf_tensor_in_normal_mode_view_op(self):
for requires_grad in (True, False):
@@ -8542,17 +8660,45 @@
self.assertTrue(out.is_leaf)
def test_normal_tensor_inplace_output_in_inference_mode(self):
- for requires_grad in (True, False):
- s = torch.ones(1, 2, 3, requires_grad=requires_grad)
- a = s.clone()
+ def run_test(fn):
+ for requires_grad in (True, False):
+ s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+ a = s.clone()
- with torch.inference_mode():
- a.add_(2)
+ with torch.inference_mode():
+ fn(a)
+ self.assertFalse(torch.is_inference(a))
+ self.assertEqual(a.requires_grad, requires_grad)
+
+ # inplace -> inplace
+ fn(a)
+ self.assertFalse(torch.is_inference(a))
+ self.assertEqual(a.requires_grad, requires_grad)
+
+ # inplace -> inplace -> view
+ view_out = a.view(-1)
+ self.assertFalse(torch.is_inference(view_out))
+ self.assertEqual(view_out.requires_grad, requires_grad)
+ run_test(lambda x: x.add_(2))
+ run_test(lambda x: x.transpose_(0, 1))
+
+ def test_normal_tensor_inplace_output_in_normal_mode(self):
+ def run_test(fn):
+ for requires_grad in (True, False):
+ s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+ a = s.clone()
+
+ with torch.inference_mode():
+ fn(a)
+ self.assertFalse(torch.is_inference(a))
+ self.assertEqual(a.requires_grad, requires_grad)
+
+ fn(a)
self.assertFalse(torch.is_inference(a))
self.assertEqual(a.requires_grad, requires_grad)
# inplace -> inplace
- a.add_(2)
+ fn(a)
self.assertFalse(torch.is_inference(a))
self.assertEqual(a.requires_grad, requires_grad)
@@ -8560,30 +8706,8 @@
view_out = a.view(-1)
self.assertFalse(torch.is_inference(view_out))
self.assertEqual(view_out.requires_grad, requires_grad)
-
- def test_normal_tensor_inplace_output_in_normal_mode(self):
- for requires_grad in (True, False):
- s = torch.ones(1, 2, 3, requires_grad=requires_grad)
- a = s.clone()
-
- with torch.inference_mode():
- a.add_(2)
- self.assertFalse(torch.is_inference(a))
- self.assertEqual(a.requires_grad, requires_grad)
-
- a.add_(2)
- self.assertFalse(torch.is_inference(a))
- self.assertEqual(a.requires_grad, requires_grad)
-
- # inplace -> inplace
- a.add_(2)
- self.assertFalse(torch.is_inference(a))
- self.assertEqual(a.requires_grad, requires_grad)
-
- # inplace -> inplace -> view
- view_out = a.view(-1)
- self.assertFalse(torch.is_inference(view_out))
- self.assertEqual(view_out.requires_grad, requires_grad)
+ run_test(lambda x: x.add_(2))
+ run_test(lambda x: x.transpose_(0, 1))
def test_normal_tensor_view_output_in_inference_mode(self):
for requires_grad in (True, False):
@@ -8718,37 +8842,43 @@
self.assertEqual(tmp2.requires_grad, requires_grad)
def test_inference_mode_handle_direct_view_on_rebase(self):
- for requires_grad in (True, False):
- s = torch.ones(1, 2, 3, requires_grad=requires_grad)
- a = s.clone()
+ def run_test(fn):
+ for requires_grad in (True, False):
+ s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+ a = s.clone()
- with torch.inference_mode():
- view_out = a.view(-1)
+ with torch.inference_mode():
+ view_out = a.view_as(a)
- if requires_grad:
- err_msg = "A view was created in inference mode and is being modified inplace"
- with self.assertRaisesRegex(RuntimeError, err_msg):
- view_out.add_(2)
- pass
- else:
- view_out.add_(2)
+ if requires_grad:
+ err_msg = "A view was created in inference mode and is being modified inplace"
+ with self.assertRaisesRegex(RuntimeError, err_msg):
+ fn(view_out)
+ pass
+ else:
+ fn(view_out)
+ run_test(lambda x: x.add_(2))
+ run_test(lambda x: x.transpose_(0, 1))
def test_inference_mode_handle_indirect_view_on_rebase(self):
- for requires_grad in (True, False):
- s = torch.ones(1, 2, 3, requires_grad=requires_grad)
- a = s.clone()
+ def run_test(fn):
+ for requires_grad in (True, False):
+ s = torch.ones(1, 2, 3, requires_grad=requires_grad)
+ a = s.clone()
- with torch.inference_mode():
- view_out = a.view(-1)
+ with torch.inference_mode():
+ view_out = a.view(-1)
- a.add_(2)
- if requires_grad:
- err_msg = "A view was created in inference mode and its base or another view "
- with self.assertRaisesRegex(RuntimeError, err_msg):
+ fn(a)
+ if requires_grad:
+ err_msg = "A view was created in inference mode and its base or another view "
+ with self.assertRaisesRegex(RuntimeError, err_msg):
+ view_out.grad_fn
+ pass
+ else:
view_out.grad_fn
- pass
- else:
- view_out.grad_fn
+ run_test(lambda x: x.add_(2))
+ run_test(lambda x: x.transpose_(0, 1))
class TestMultithreadAutograd(TestCase):
def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None):
diff --git a/test/test_nn.py b/test/test_nn.py
index db3e05b..a93aeee 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8215,7 +8215,7 @@
test_pixel_shuffle_unshuffle_4D()
test_pixel_shuffle_unshuffle_5D()
- def test_elu_inplace_view(self):
+ def test_elu_inplace_on_view(self):
v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
def func(root):
@@ -8228,7 +8228,7 @@
gradcheck(func, [v])
gradgradcheck(func, [v])
- def test_relu_inplace_view(self):
+ def test_relu_inplace_on_view(self):
v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True)
def func(root):
diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index 6b02c2c..48f6672 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -449,6 +449,28 @@
v[0, 1] = 0
self.assertEqual(t[1, 0], v[0, 1])
+ def test_transpose_inplace_view(self, device):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.swapdims_(0, 1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.swapaxes_(0, 1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.transpose_(0, 1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
def test_t_view(self, device):
t = torch.ones((5, 5), device=device)
v = t.t()
@@ -457,6 +479,14 @@
v[0, 1] = 0
self.assertEqual(t[1, 0], v[0, 1])
+ def test_t_inplace_view(self, device):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.t_()
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
def test_T_view(self, device):
t = torch.ones((5, 5), device=device)
v = t.T
@@ -480,6 +510,14 @@
v[0, 1] = 0
self.assertEqual(t, v._base)
+ def test_squeeze_inplace_view(self, device):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.squeeze_()
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t, v._base)
+
def test_unsqueeze_view(self, device):
t = torch.ones(5, 5, device=device)
v = torch.unsqueeze(t, 1)
@@ -488,6 +526,14 @@
v[0, 0, 1] = 0
self.assertEqual(t[0, 1], v[0, 0, 1])
+ def test_unsqueeze_inplace_view(self, device):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.unsqueeze_(1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 0, 1] = 0
+ self.assertEqual(t[0, 1], v[0, 0, 1])
+
def test_as_strided_view(self, device):
t = torch.ones(5, 5, device=device)
v = torch.as_strided(t, (25,), (1,))
@@ -496,6 +542,14 @@
v[6] = 0
self.assertEqual(t[1, 1], v[6])
+ def test_as_strided_inplace_view(self, device):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.as_strided_((25,), (1,))
+ self.assertTrue(self.is_view_of(t, v))
+ v[6] = 0
+ self.assertEqual(t[1, 1], v[6])
+
def test_view_view(self, device):
t = torch.ones(5, 5, device=device)
v = t.view(25)
diff --git a/test/test_vmap.py b/test/test_vmap.py
index e9839f7..35b28db 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -2426,7 +2426,7 @@
@allowVmapFallbackUsage
- def test_inplace_view(self, device):
+ def test_inplace_on_view(self, device):
leaf = torch.randn(4, 5, requires_grad=True)
def func(leaf):