[primTorch] Add prim and ref for as_strided_scatter (#88426)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88426
Approved by: https://github.com/mruberry
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index cb17591..b6d0214 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1981,6 +1981,7 @@
             "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
             "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
             "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
+            "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
         }:
             pass
         else:
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 3d4d44d..a7cc65e 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -151,6 +151,10 @@
     "transpose",
     "view_of",
     #
+    # Functionalized view mutations
+    #
+    "as_strided_scatter",
+    #
     # Shape prims
     #
     "collapse",
@@ -1795,6 +1799,53 @@
 )
 
 #
+# Functionalized view mutations
+#
+
+
+def _as_strided_scatter_meta(
+    input: TensorLikeType,
+    src: TensorLikeType,
+    size: ShapeType,
+    stride: StrideType,
+    storage_offset: int,
+) -> TensorLikeType:
+    utils.validate_shape(size)
+    utils.validate_strides(stride)
+
+    required_size = utils.compute_required_storage_length(size, stride, storage_offset)
+    utils.check(
+        input.numel() >= required_size,
+        lambda: (
+            f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
+            f" and itemsize {input.element_size()} requiring a storage size of "
+            f"{required_size * input.element_size()} are out of bounds "
+            f"for storage of size {input.numel() * input.element_size()}"
+        ),
+    )
+    utils.check(
+        utils.is_same_shape(src.shape, size),
+        lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
+    )
+
+    return _clone_meta(input)
+
+
+_as_strided_scatter_doc = """
+    Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
+    ``out.as_strided(size, stride, storage_offset).copy_(src)``.
+"""
+
+as_strided_scatter = _make_prim(
+    schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
+    meta=_as_strided_scatter_meta,
+    impl_aten=torch.as_strided_scatter,
+    return_type=RETURN_TYPE.NEW,
+    doc=_as_strided_scatter_doc,
+)
+
+
+#
 # Shape operations
 #
 def collapse(a: Tensor, start: int, end: int) -> Tensor:
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index dd90d25..b34f109 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -1483,6 +1483,44 @@
     return correction
 
 
+def compute_required_storage_length(
+    shape: ShapeType, strides: StrideType, storage_offset: int
+) -> int:
+    """Computes the minimum storage size to hold the given tensor geometry.
+
+    Example
+    =======
+
+    This is the size of a newly allocated tensor's storage, in units of elements
+
+    >>> t = torch.empty((10, 20))
+    >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
+    200
+
+    >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
+    >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
+    >>> size == t.storage().size()
+    True
+
+    A valid tensor may have a larger storage size, but never smaller
+
+    >>> slice = torch.empty(100)[20:40]
+    >>> slice.storage().size()
+    100
+
+    >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
+    40
+
+    """
+    # Short-circuits if the shape has no elements
+    if reduce(operator.mul, shape, 1) == 0:
+        return 0
+
+    max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
+    # +1 to account for the first element which offsets are taken from
+    return 1 + storage_offset + max_offset
+
+
 def check_in_bounds_for_storage(
     a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
 ):
@@ -1490,17 +1528,8 @@
     Determines if the given shape, strides, and offset are valid for the given storage.
     """
 
-    # Short-circuits if the shape has no elements
-    if reduce(operator.mul, shape) == 0:
-        return
-
-    length = a.size() - storage_offset
-    max_offset = 0
-    for x, y in zip(shape, strides):
-        max_offset = max_offset + (x - 1) * y
-
-    if max_offset >= length:
-        required_length = max_offset + storage_offset
+    required_length = compute_required_storage_length(shape, strides, storage_offset)
+    if a.size() < required_length:
         msg = (
             "Can't view a storage of size {0} with an offset of {1}, shape of {2}, and strides of {3}, "
             "which requires a storage of size {4}".format(
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 354ef9c..f06f5ba 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2520,6 +2520,18 @@
     return prims.as_strided(a, size, stride, storage_offset_int)
 
 
+@register_decomposition(torch.ops.aten.as_strided_scatter)
+def as_strided_scatter(
+    input: TensorLikeType,
+    src: TensorLikeType,
+    size: ShapeType,
+    stride: StrideType,
+    storage_offset: Optional[int] = None,
+) -> TensorLikeType:
+    storage_offset_int = 0 if storage_offset is None else storage_offset
+    return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)
+
+
 def broadcast_shapes(*shapes) -> ShapeType:
     return torch.Size(_broadcast_shapes(*shapes))
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 527c079..c01f476 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -278,6 +278,7 @@
 
     # input shape, output shape, output stride, output storage offset
     test_cases = [
+        ((1,), (), (), 0),
         ((1,), (1,), (1,), 0),
         ((3, 3), (2, 2), (1, 2), 0),
         ((3, 3), (2, 2), (1, 2), 1),
@@ -293,6 +294,7 @@
         input_src = make_arg(output_shape)
         yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset)
 
+
 def error_inputs_as_strided_scatter(op_info, device, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
 
@@ -10821,8 +10823,6 @@
                             'test_non_standard_bool_values'),
            )),
     OpInfo('as_strided_scatter',
-           op=lambda x, src, size, stride, storage_offset=0:
-               torch.as_strided_scatter(x, src, size, stride, storage_offset=storage_offset),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
            supports_out=False,
            supports_forward_ad=True,
@@ -18425,6 +18425,11 @@
         ),
     ),
     PythonRefInfo(
+        "_refs.as_strided_scatter",
+        torch_opinfo_name="as_strided_scatter",
+        supports_nvfuser=False,
+    ),
+    PythonRefInfo(
         "_refs.broadcast_shapes",
         torch_opinfo_name="broadcast_shapes",
         supports_nvfuser=False,