Adds Python references for amax and amin.
Also extends error input testing to cover Python references.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76855
Approved by: https://github.com/mruberry
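
A minimal usage sketch of the new references, assuming a build where `torch._refs` is importable; values are illustrative only:

```python
# Sketch: expected behavior of the new amax/amin references, assuming
# torch._refs is importable in this build. Values are illustrative only.
import torch
import torch._refs as refs

x = torch.arange(6.0).reshape(2, 3)

# The references should agree with the eager ops.
assert torch.equal(refs.amax(x, dim=1), torch.amax(x, dim=1))
assert torch.equal(refs.amin(x, dim=0), torch.amin(x, dim=0))

# Like the eager ops, dim=() and dim=[] reduce over all dimensions.
assert torch.equal(refs.amax(x, dim=()), torch.amax(x))
```
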
diff --git a/test/test_ops.py b/test/test_ops.py
index 766b3c5..cc9f718 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -342,7 +342,7 @@
@onlyNativeDeviceTypes
@ops(python_ref_db)
def test_python_reference_consistency(self, device, dtype, op):
- for sample in op.reference_inputs(device, dtype, requires_grad=False):
+ for sample in op.torch_opinfo.reference_inputs(device, dtype, requires_grad=False):
actual = op(sample.input, *sample.args, **sample.kwargs)
expected = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
@@ -357,7 +357,7 @@
@skipMeta
@onlyNativeDeviceTypes
- @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
+ @ops([op for op in ops_and_refs if op.error_inputs_func is not None], dtypes=OpDTypes.none)
def test_errors(self, device, op):
error_inputs = op.error_inputs(device)
for ei in error_inputs:
diff --git a/torch/_prims/utils.py b/torch/_prims/utils.py
index 9283df8..7430a9a 100644
--- a/torch/_prims/utils.py
+++ b/torch/_prims/utils.py
@@ -677,5 +677,6 @@
if dims is None:
return tuple(range(len(shape)))
dims = tuple(canonicalize_idx(len(shape), idx) for idx in dims)
- assert len(dims) == len(set(dims)), "duplicate value in dims"
+ if len(dims) != len(set(dims)):
+ raise RuntimeError("duplicate value in the list of dims")
return dims
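
With the assert replaced by a RuntimeError, duplicate entries in `dim` now surface as a catchable exception whose message matches the relaxed regex used by the error-input tests further down. A hedged sketch of the intended behavior (again assuming `torch._refs` is importable):

```python
import torch
import torch._refs as refs

x = torch.randn(4, 4, 4, 4)

# dim=(0, -4) canonicalizes to (0, 0); the reference now raises a
# RuntimeError instead of tripping an assert.
try:
    refs.amax(x, dim=(0, -4))
except RuntimeError as e:
    assert "in the list of dims" in str(e)
```
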
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 8e2d510..198ad78 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -23,7 +23,7 @@
# Experimental module containing prototype Python references for existing
# PyTorch operations.
-all = [
+__all__ = [
#
# Elementwise Unary References
#
@@ -119,7 +119,9 @@
#
# Reduction ops
#
- "sum", # TODO: add opinfo
+ "sum",
+ "amax",
+ "amin",
#
# View & Shape Ops
#
@@ -616,7 +618,7 @@
)
-def _make_elementwise_binary_reference(prim: Callable, *, type_promotion) -> Callable:
+def _make_elementwise_binary_reference(prim: Callable, *, type_promotion, wrap_scalars=False) -> Callable:
def _ref(
a: Union[Tensor, NumberType],
b: Union[Tensor, NumberType],
@@ -629,7 +631,10 @@
# Special-cases Number x Number case
if isinstance(a, Number) and isinstance(b, Number):
- a, b = utils.wrap_scalars(a, b)
+ if wrap_scalars:
+ a, b = utils.wrap_scalars(a, b)
+ else:
+ raise RuntimeError("got two scalar arguments, while expected at least one TensorLike")
# Handles type promotion
computation_dtype, result_dtype = _elementwise_dtypes(
@@ -752,10 +757,6 @@
assert isinstance(b, (TensorLike, Number))
assert out is None or isinstance(out, TensorLike)
- # Special-cases Number x Number case
- if isinstance(a, Number) and isinstance(b, Number):
- a, b = utils.wrap_scalars(a, b)
-
# Handles type promotion
dtype = utils.get_higher_dtype(a, b)
assert dtype is not None
@@ -820,7 +821,8 @@
# TODO: add docstring
mul = _make_elementwise_binary_reference(
- prims.mul, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
+ prims.mul, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH,
+ wrap_scalars=True
)
# TODO: add docstring
@@ -890,7 +892,8 @@
# TODO: add docstring
true_divide = _make_elementwise_binary_reference(
- prims.div, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ prims.div, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ wrap_scalars=True
)
#
@@ -982,10 +985,9 @@
dims = (dims,) # type: ignore[assignment]
dims = utils.reduction_dims(a.shape, dims)
if not has_identity:
- valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims) # type: ignore[operator]
- assert (
- valid_shape
- ), "reducing over zero-size dimension for reduction operation without identity"
+ valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims)
+ if not valid_shape:
+ raise RuntimeError("reducing over zero-size dimension for reduction operation without identity")
# even though some reductions, like amin or amax, don't strictly require type promotion,
# all the math ops (including comparisons) are still defined only for a computation type,
# so promotion will still happen. We are doing it explicitly here
@@ -1003,12 +1005,12 @@
if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
if out.dtype != a.dtype:
raise RuntimeError(
- "out dtype and output type of reduction must match"
+ "Expected the dtype for input and out to match"
)
elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL:
if out.dtype != torch.bool:
raise RuntimeError(
- "out dtype and output type of reduction must match"
+ "Expected the dtype for input and out to match"
)
out = _maybe_resize_out(out, result.shape)
return copy_to(out, result, allow_cross_device=False) # type: ignore[arg-type]
@@ -1032,7 +1034,7 @@
dtype = torch.int64
else:
dtype = a.dtype
- # sum reduces over all dimensions if dim=() is passed
+ # reduces over all dimensions if dim=() is passed
if dim == () or dim == []:
dim = None
return _reduction(
@@ -1045,6 +1047,48 @@
output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
)
+def amin(
+ a: Tensor,
+ dim: Union[Optional[int], Optional[List[int]]] = None,
+ keepdim: bool = False,
+ *,
+ out: Optional[Tensor] = None
+):
+ # reduces over all dimensions if dim=() is passed
+ if dim == () or dim == []:
+ dim = None
+ return _reduction(
+ a,
+ prims.amin,
+ dims=dim,
+ keepdims=keepdim,
+ dtype=None,
+ out=out,
+ has_identity=False,
+ output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
+ )
+
+def amax(
+ a: Tensor,
+ dim: Union[Optional[int], Optional[List[int]]] = None,
+ keepdim: bool = False,
+ *,
+ out: Optional[Tensor] = None
+):
+ # reduces over all dimensions if dim=() is passed
+ if dim == () or dim == []:
+ dim = None
+ return _reduction(
+ a,
+ prims.amax,
+ dims=dim,
+ keepdims=keepdim,
+ dtype=None,
+ out=out,
+ has_identity=False,
+ output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
+ )
+
def cat(
tensors: TensorSequenceType, dim: int = 0, out: TensorLikeType = None
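
The new `wrap_scalars` switch keeps the Number x Number fast path only for the references that opt in (`mul`, `true_divide`), while the remaining binary references now reject two scalar inputs explicitly. A hedged sketch of the intended behavior (names as in this diff; assuming `torch._refs` is importable):

```python
import torch
import torch._refs as refs

# mul and true_divide pass wrap_scalars=True, so two Python scalars are
# wrapped into tensors and the call succeeds (e.g. tensor(6), tensor(0.2500)).
print(refs.mul(2, 3))
print(refs.true_divide(1, 4))

# References built without wrap_scalars should raise instead of silently
# wrapping, per the new RuntimeError in _ref.
try:
    refs.maximum(2, 3)  # assumption: maximum uses the shared binary helper
except RuntimeError as e:
    assert "at least one TensorLike" in str(e)
```
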
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b7d0148..cf9a16a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4026,9 +4026,9 @@
# Error Inputs for zero-dim tensors, when 'dim' arg is not provided.
shape = (S, 0, S)
- err_msg_amax_amin = "Specify the reduction dim with the 'dim' argument."
+ err_msg_amax_amin = "reduction"
err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
- if op_info.name in ['amax', 'amin']:
+ if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
elif op_info.name in ['aminmax']:
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
@@ -4042,9 +4042,9 @@
error_regex=err_msg1)
# Error Inputs for repeated 'dim'
- if op_info.name in ['amax', 'amin']:
+ if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
dims = [(0, 0), (0, -4)]
- err_msg2 = "dim 0 appears multiple times in the list of dims"
+ err_msg2 = "in the list of dims"
x = torch.randn(S, S, S, S, device=device)
for dim in dims:
yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2)
@@ -4058,7 +4058,7 @@
err_msg_amax_amin2 = "Expected the dtype for input and out to match"
err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead"
- if op_info.name in ['amax', 'amin']:
+ if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
error_regex=err_msg_amax_amin2)
elif op_info.name in ['aminmax']:
@@ -4066,9 +4066,11 @@
error_regex=err_msg_aminmax2)
# Error Inputs for functions to raise an error on specified zero'd dimension as reduction dim
- err_msg3 = "Expected reduction dim 1 to have non-zero size"
+ err_msg3 = "reduction"
+ # FIXME: eager and ref impl throw different types of errors
+ error_type = IndexError if 'refs' not in op_info.name else RuntimeError
yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}),
- error_type=IndexError, error_regex=err_msg3)
+ error_type=error_type, error_regex=err_msg3)
def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs):
test_cases: Tuple[tuple, dict] = ( # type: ignore[assignment]
@@ -16655,7 +16657,7 @@
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
ref=reference_reduction_numpy(np.amax),
skips=(
- # FIXME: sum reduces all dimensions when dim=[]
+ # FIXME: reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
),
@@ -16667,7 +16669,7 @@
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
ref=reference_reduction_numpy(np.amin),
skips=(
- # FIXME: sum reduces all dimensions when dim=[]
+ # FIXME: reduces all dimensions when dim=[]
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
),
@@ -17892,10 +17894,14 @@
ElementwiseBinaryPythonRefInfo(
"_refs.maximum",
torch_opinfo_name="maximum",
+ supports_rhs_python_scalar=True,
+ supports_one_python_scalar=True
),
ElementwiseBinaryPythonRefInfo(
"_refs.minimum",
torch_opinfo_name="minimum",
+ supports_rhs_python_scalar=True,
+ supports_one_python_scalar=True
),
ElementwiseBinaryPythonRefInfo(
"_refs.mul",
@@ -17976,7 +17982,18 @@
"_refs.sum",
torch_opinfo_name="sum",
supports_out=True
+ ),
+
+ ReductionPythonRefInfo(
+ "_refs.amin",
+ torch_opinfo_name="amin",
+ ),
+
+ ReductionPythonRefInfo(
+ "_refs.amax",
+ torch_opinfo_name="amax",
)
+
]
# Common operator groupings
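
With `ops_and_refs` feeding `test_errors`, the new error inputs for the references can be exercised directly, e.g. with `pytest test/test_ops.py -k "test_errors and amax"` (a typical invocation; adjust the `-k` filter to the local test naming). The same filter style also works for `test_python_reference_consistency`.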