Revert "Refs and decompositions for index_{add,copy,select,fill} (#85002)"
This reverts commit 2f0b3de443dd8d4477d70c5a56fa14496d1eebe3.
Reverted https://github.com/pytorch/pytorch/pull/85002 on behalf of https://github.com/huydhn because it broke trunk slow tests
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 101803c..f263c2c 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -1095,6 +1095,8 @@
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index");
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"index_select(): self and result must have the same scalar type");
+ TORCH_CHECK(dim == 0 || dim < self.dim(),
+ "index_select(): Indexing dim ", dim, " is out of bounds of tensor");
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
at::assert_no_overlap(result, index);
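
For context, the restored TORCH_CHECK guards index_select against an out-of-range dim. A minimal sketch of the behavior it enforces (the exception type and message here are assumptions; only the bounds check matters):

    import torch

    t = torch.randn(2, 3)
    idx = torch.tensor([0, 1])

    # In-bounds dim: picks rows 0 and 1 along dim 0
    print(torch.index_select(t, 0, idx).shape)  # torch.Size([2, 3])

    # Out-of-bounds dim on a 2-d tensor: expected to raise with the check in place
    try:
        torch.index_select(t, 5, idx)
    except (IndexError, RuntimeError) as e:
        print("raised as expected:", e)
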
diff --git a/test/test_ops.py b/test/test_ops.py
index 1a41365..7384a81 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -245,9 +245,7 @@
if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
prims.utils.compare_tensor_meta(a, b)
if getattr(op, 'validate_view_consistency', True) and not skip_view_consistency:
- msg = (f"The torch implementation {'returns' if b._is_view() else 'does not return'} "
- f"a view, while the reference {'does' if a._is_view() else 'does not'}")
- self.assertEqual(a._is_view(), b._is_view(), msg)
+ self.assertEqual(a._is_view(), b._is_view())
# Computes the dtype in which the more precise computation would occur
precise_dtype = torch.bool
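
The view-consistency assertion above compares Tensor._is_view() between the torch op and its reference. A quick sketch of what that predicate reports (note that _is_view is internal API):

    import torch

    base = torch.arange(6).reshape(2, 3)
    v = base.view(3, 2)   # shares storage with base
    c = base.clone()      # owns its own storage

    print(v._is_view())   # True: v is a view of base
    print(c._is_view())   # False: c is an independent tensor
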
@@ -1600,10 +1598,6 @@
'_refs.rfloordiv',
'_refs.rtruediv',
'_refs.rpow',
- # These should be tested with their out-of-place counterparts
- '_refs.index_add_',
- '_refs.index_copy_',
- '_refs.index_fill_',
}
not_in_decomp_table = {
@@ -1612,8 +1606,6 @@
'_refs.nn.functional.mse_loss',
'_refs.var',
'_refs.rsub',
- # duplicated due to efficiency concerns of the ref vs the decomp
- '_refs.index_add_',
# these are not aten ops?
'_refs.broadcast_shapes',
'_refs.broadcast_tensors',
@@ -1692,11 +1684,11 @@
op_impl = getattr(import_module(f"torch.{module_path}"), op_name)
if op in self.not_in_decomp_table:
- self.assertNotIn(op_impl, torch._decomp.decomposition_table.values(),
+ self.assertFalse(op_impl in torch._decomp.decomposition_table.values(),
f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()")
else:
- self.assertIn(op_impl, torch._decomp.decomposition_table.values(),
- f"Did not find {op} in torch._decomp.decomposition_table.values()")
+ self.assertTrue(op_impl in torch._decomp.decomposition_table.values(),
+ f"Did not find {op} in torch._decomp.decomposition_table.values()")
fake_skips = (
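
On the assertIn/assertTrue swap above: both spell the same membership check, but they fail differently. assertIn reports the container in its default failure message, while assertTrue reduces to "False is not true" unless an explicit msg is passed, which is why the reverted code keeps the f-string messages. A minimal sketch:

    import unittest

    class Demo(unittest.TestCase):
        def test_membership(self):
            table = {"a", "b"}
            # On failure this reads: AssertionError: 'x' not found in {'a', 'b'}
            self.assertIn("a", table)
            # On failure this reads only: AssertionError: 'a' missing
            # (or "False is not true" if no msg is given)
            self.assertTrue("a" in table, "'a' missing")

    if __name__ == "__main__":
        unittest.main()
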
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index e00243f..402b317 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1559,20 +1559,6 @@
return ret / (length_h * length_w)
-@register_decomposition(aten.index_add_)
-def index_add_(x, dim, index, tensor, *, alpha=1):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- idx = (slice(None),) * dim + (index,)
- if alpha != 1:
- tensor = tensor * alpha
- torch.ops.aten.index_put_(x, idx, tensor, accumulate=True)
- return x
-
-
def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
ndim = self.dim()
wrapped_dims = utils.canonicalize_dims(ndim, dims)
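
For reference, the decomposition removed above reduces index_add_ to index_put_ with accumulate=True, which sums contributions at duplicate indices exactly as index_add_ does. A minimal dim=0 sketch of the equivalence (variable names are just for the demo; at dim 0 no slice(None) padding is needed):

    import torch

    x1 = torch.zeros(5)
    x2 = x1.clone()
    index = torch.tensor([0, 2, 2])        # note the repeated index
    src = torch.tensor([1.0, 2.0, 3.0])

    x1.index_add_(0, index, src)
    # accumulate=True makes duplicate positions sum rather than overwrite
    x2.index_put_((index,), src, accumulate=True)

    print(torch.equal(x1, x2))  # True -> tensor([1., 0., 5., 0., 0.])
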
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index fbf3b24..e9ad580 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -70,13 +70,6 @@
"fill",
"floor",
"frac",
- "index_add",
- "index_add_",
- "index_copy",
- "index_copy_",
- "index_select",
- "index_fill",
- "index_fill_",
"isfinite",
"isinf",
"isnan",
@@ -2956,104 +2949,6 @@
)
-@register_decomposition(torch.ops.aten.index_copy)
-@out_wrapper()
-def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- return x.clone().index_copy_(dim, index, tensor)
-
-
-@register_decomposition(torch.ops.aten.index_copy_)
-def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- # Treat scalars as elements of \R^1
- y = x.unsqueeze(0) if x.ndim == 0 else x
- idx = (slice(None),) * dim + (index,)
- y[idx] = tensor
- return x
-
-
-@register_decomposition(torch.ops.aten.index_fill)
-def index_fill(
- x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
-):
- return x.clone().index_fill_(dim, index, value) # type: ignore[arg-type]
-
-
-@register_decomposition(torch.ops.aten.index_fill_)
-def index_fill_(
- x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
-):
- if isinstance(value, TensorLike):
- utils.check(
- value.ndim == 0,
- lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
- f"Got a tensor with {value.ndim} dimensions.",
- ) # type: ignore[arg-type]
- return x.index_copy_(dim, index, value)
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- idx = (slice(None),) * dim + (index,)
- # Treat scalars as elements of \R^1
- y = x.unsqueeze(0) if x.ndim == 0 else x
- y[idx] = value # type: ignore[assignment]
- return x
-
-
-@register_decomposition(torch.ops.aten.index_add)
-@out_wrapper()
-def index_add(
- x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, alpha: float = 1
-):
- return x.clone().index_add_(dim, index, tensor, alpha=alpha)
-
-
-# The decomposition of this function dispatches to aten.index_put_ for efficiency.
-# We cannot do that in Python, as torch.index_put_ does not support slice(None)s; see
-# https://github.com/pytorch/pytorch/pull/85002#issuecomment-1248524492
-def index_add_(
- x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, alpha: float = 1
-):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- if alpha != 1:
- python_type = utils.dtype_to_type(x.dtype)
- utils.check(
- utils.is_weakly_lesser_type(type(alpha), python_type),
- lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
- )
- tensor = prims.mul(tensor, alpha)
- # Treat scalars as elements of \R^1
- y = x.unsqueeze(0) if x.ndim == 0 else x
- idx = (slice(None),) * dim + (index,)
- y[idx] += tensor
- return x
-
-
-@register_decomposition(torch.ops.aten.index_select, disable_meta=True)
-@out_wrapper()
-def index_select(x: TensorLike, dim: int, index: TensorLike):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- # Treat scalars as elements of \R^1
- if x.ndim == 0:
- return x.unsqueeze(0)[index].squeeze(0).clone()
- idx = (slice(None),) * dim + (index,)
- return x[idx]
-
-
# Note: although squeeze is documented as having an out= kwarg, it doesn't
@register_decomposition(torch.ops.aten.squeeze, disable_meta=True)
def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType:
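
All of the removed refs build their indexing tuple the same way: (slice(None),) * dim + (index,) is the programmatic form of x[:, ..., :, index] with the index tensor sitting at position dim. A small sketch of its equivalence with index_select:

    import torch

    x = torch.arange(24).reshape(2, 3, 4)
    index = torch.tensor([2, 0])
    dim = 1

    # Tuple form of x[:, index]: one full slice per leading dim, then the index
    idx = (slice(None),) * dim + (index,)

    print(torch.equal(x[idx], torch.index_select(x, dim, index)))  # True
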
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 83f6950..23cc637 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4073,13 +4073,13 @@
# https://github.com/pytorch/pytorch/issues/53352
def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
# target.index_select(dim, idx)
- select = "index_select" in op_info.name
+ select = op_info.name == "index_select"
# target.index_add(dim, idx, source, *, alpha=1)
- add = "index_add" in op_info.name
+ add = op_info.name == "index_add"
# target.index_copy(dim, idx, source)
- copy = "index_copy" in op_info.name
+ copy = op_info.name == "index_copy"
# target.index_fill(dim, idx, value)
- fill = "index_fill" in op_info.name
+ fill = op_info.name == "index_fill"
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
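
The revert tightens the sampler's dispatch from substring to exact name matching. The difference matters because the substring form also fires for related OpInfo entries; an illustrative comparison (the names below are just examples):

    names = ["index_add", "index_add_", "_refs.index_add", "index_select"]

    # Substring match: also catches the in-place and ref variants
    print([n for n in names if "index_add" in n])
    # -> ['index_add', 'index_add_', '_refs.index_add']

    # Exact match: only the base op
    print([n for n in names if n == "index_add"])
    # -> ['index_add']
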
@@ -4090,13 +4090,7 @@
shapes = [(), (1,), (S, S)]
# extra parameter for add
- if add:
- if dtype == torch.bool:
- alphas = (True, False)
- else:
- alphas = (-1, 0, 2)
- else:
- alphas = (None,)
+ alphas = (-1, 0, 2) if add else (None,)
for shape, alpha in product(shapes, alphas):
t = make_arg(shape)
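
The dropped branch special-cased torch.bool because the refs only accept an alpha whose Python type sits no higher than the tensor's computation type on the promotion ladder (bool < int < float < complex): for a bool tensor that leaves bool alphas, so -1, 0, and 2 (ints) would be rejected. A plain-Python sketch of that filter:

    # Only alphas whose Python type is bool survive for a bool tensor;
    # ints such as -1, 0, 2 sit higher on the promotion ladder
    candidates = (True, False, -1, 0, 2)
    safe_for_bool = [a for a in candidates if type(a) is bool]
    print(safe_for_bool)  # [True, False]
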
@@ -17497,45 +17491,6 @@
op=lambda self, condition, other: refs.where(condition, self, other),
supports_nvfuser=False,
),
- PythonRefInfo(
- "_refs.index_select",
- torch_opinfo_name="index_select",
- # empty_strided
- supports_nvfuser=False,
- skips=(
- # no _refs support for Tensor.__setitem__
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
- # Sample out= with a stride of zero. This _out operation checks that the input has no
- # inner overlap
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),)
- ),
- PythonRefInfo(
- "_refs.index_copy",
- torch_opinfo_name="index_copy",
- # empty_strided
- supports_nvfuser=False,
- skips=(
- # no _refs support for Tensor.__setitem__
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
- ),
- PythonRefInfo(
- "_refs.index_add",
- torch_opinfo_name="index_add",
- # empty_strided
- supports_nvfuser=False,
- skips=(
- # no _refs support for Tensor.__setitem__
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
- ),
- PythonRefInfo(
- "_refs.index_fill",
- torch_opinfo_name="index_fill",
- # empty_strided
- supports_nvfuser=False,
- skips=(
- # no _refs support for Tensor.__setitem__
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
- ),
#
# Test-related functions
#
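
The removed PythonRefInfo entries all lean on DecorateInfo(unittest.expectedFailure, ...) to track known gaps such as the missing _refs support for Tensor.__setitem__. A minimal sketch of the underlying unittest mechanism:

    import unittest

    class Demo(unittest.TestCase):
        # A failing body is reported as an expected failure; a passing
        # body is flagged as an unexpected success, which is how stale
        # skips get noticed once the underlying gap is fixed
        @unittest.expectedFailure
        def test_known_gap(self):
            self.assertTrue(False)

    if __name__ == "__main__":
        unittest.main()
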