[primTorch] Add prim and ref for as_strided_scatter (#88426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88426
Approved by: https://github.com/mruberry
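
This adds a prim and a Python reference for `aten::as_strided_scatter`, along with a shared `compute_required_storage_length` helper that both the new meta function and `check_in_bounds_for_storage` now use for bounds checking.

For context, the op's semantics as stated in the prim's docstring below, shown as a minimal sketch (the tensors, sizes, and strides here are illustrative, not taken from the PR; assumes a PyTorch build that already ships `torch.as_strided_scatter`):

```python
import torch

inp = torch.zeros(4)
src = torch.ones(2)

# Equivalent to:
#   out = inp.clone()
#   out.as_strided((2,), (2,), 0).copy_(src)
out = torch.as_strided_scatter(inp, src, size=(2,), stride=(2,), storage_offset=0)
print(out)  # tensor([1., 0., 1., 0.])
```

The helper computes the minimum storage size, in elements, as `1 + storage_offset + sum((size[i] - 1) * stride[i])`, returning 0 for shapes with no elements.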
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index cb17591..b6d0214 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1981,6 +1981,7 @@
"aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950
"aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950
"aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950
+ "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950
}:
pass
else:
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 3d4d44d..a7cc65e 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -151,6 +151,10 @@
"transpose",
"view_of",
#
+ # Functionalized view mutations
+ #
+ "as_strided_scatter",
+ #
# Shape prims
#
"collapse",
@@ -1795,6 +1799,53 @@
)

#
+# Functionalized view mutations
+#
+
+
+def _as_strided_scatter_meta(
+ input: TensorLikeType,
+ src: TensorLikeType,
+ size: ShapeType,
+ stride: StrideType,
+ storage_offset: int,
+) -> TensorLikeType:
+ utils.validate_shape(size)
+ utils.validate_strides(stride)
+
+ required_size = utils.compute_required_storage_length(size, stride, storage_offset)
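+    # The window described by (size, stride, storage_offset) must fit within input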
+ utils.check(
+ input.numel() >= required_size,
+ lambda: (
+            f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset}, "
+            f"and itemsize {input.element_size()} requiring a storage size of "
+ f"{required_size * input.element_size()} are out of bounds "
+ f"for storage of size {input.numel() * input.element_size()}"
+ ),
+ )
+ utils.check(
+ utils.is_same_shape(src.shape, size),
+ lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
+ )
+
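+    # The output has input's exact metadata, so clone's meta applies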
+ return _clone_meta(input)
+
+
+_as_strided_scatter_doc = """
+    Creates a new tensor equivalent to ``out = input.clone()`` followed by the
+    mutation ``out.as_strided(size, stride, storage_offset).copy_(src)``.
+"""
+
+as_strided_scatter = _make_prim(
+ schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
+ meta=_as_strided_scatter_meta,
+ impl_aten=torch.as_strided_scatter,
+ return_type=RETURN_TYPE.NEW,
+ doc=_as_strided_scatter_doc,
+)
+
+
+#
# Shape operations
#
def collapse(a: Tensor, start: int, end: int) -> Tensor:
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index dd90d25..b34f109 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -1483,6 +1483,44 @@
    return correction


+def compute_required_storage_length(
+ shape: ShapeType, strides: StrideType, storage_offset: int
+) -> int:
+ """Computes the minimum storage size to hold the given tensor geometry.
+
+ Example
+ =======
+
+ This is the size of a newly allocated tensor's storage, in units of elements
+
+ >>> t = torch.empty((10, 20))
+ >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
+ 200
+
+ >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
+ >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
+    >>> size == t2.storage().size()
+ True
+
+ A valid tensor may have a larger storage size, but never smaller
+
+ >>> slice = torch.empty(100)[20:40]
+ >>> slice.storage().size()
+ 100
+
+ >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
+ 40
+
+ """
+ # Short-circuits if the shape has no elements
+ if reduce(operator.mul, shape, 1) == 0:
+ return 0
+
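+    # Element offset of the farthest element addressed by this geometry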
+ max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
+    # +1 because storage_offset + max_offset is an index, not a length
+ return 1 + storage_offset + max_offset
+
+
def check_in_bounds_for_storage(
a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):
@@ -1490,17 +1528,8 @@
    """
    Determines if the given shape, strides, and offset are valid for the given storage.
"""
- # Short-circuits if the shape has no elements
- if reduce(operator.mul, shape) == 0:
- return
-
- length = a.size() - storage_offset
- max_offset = 0
- for x, y in zip(shape, strides):
- max_offset = max_offset + (x - 1) * y
-
- if max_offset >= length:
- required_length = max_offset + storage_offset
+ required_length = compute_required_storage_length(shape, strides, storage_offset)
+ if a.size() < required_length:
msg = (
"Can't view a storage of size {0} with an offset of {1}, shape of {2}, and strides of {3}, "
"which requires a storage of size {4}".format(
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 354ef9c..f06f5ba 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2520,6 +2520,18 @@
    return prims.as_strided(a, size, stride, storage_offset_int)


+@register_decomposition(torch.ops.aten.as_strided_scatter)
+def as_strided_scatter(
+ input: TensorLikeType,
+ src: TensorLikeType,
+ size: ShapeType,
+ stride: StrideType,
+ storage_offset: Optional[int] = None,
+) -> TensorLikeType:
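+    # The prim requires a concrete offset, so default a missing storage_offset to 0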
+ storage_offset_int = 0 if storage_offset is None else storage_offset
+ return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)
+
+
def broadcast_shapes(*shapes) -> ShapeType:
return torch.Size(_broadcast_shapes(*shapes))
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 527c079..c01f476 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -278,6 +278,7 @@
# input shape, output shape, output stride, output storage offset
test_cases = [
+ ((1,), (), (), 0),
((1,), (1,), (1,), 0),
((3, 3), (2, 2), (1, 2), 0),
((3, 3), (2, 2), (1, 2), 1),
@@ -293,6 +294,7 @@
input_src = make_arg(output_shape)
        yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset)

+
def error_inputs_as_strided_scatter(op_info, device, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
@@ -10821,8 +10823,6 @@
'test_non_standard_bool_values'),
)),
OpInfo('as_strided_scatter',
- op=lambda x, src, size, stride, storage_offset=0:
- torch.as_strided_scatter(x, src, size, stride, storage_offset=storage_offset),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
supports_forward_ad=True,
@@ -18425,6 +18425,11 @@
),
),
PythonRefInfo(
+ "_refs.as_strided_scatter",
+ torch_opinfo_name="as_strided_scatter",
+ supports_nvfuser=False,
+ ),
+ PythonRefInfo(
"_refs.broadcast_shapes",
torch_opinfo_name="broadcast_shapes",
supports_nvfuser=False,