Adds amax and amin references

Also extends reference testing to error inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76855
Approved by: https://github.com/mruberry
diff --git a/test/test_ops.py b/test/test_ops.py
index 766b3c5..cc9f718 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -342,7 +342,7 @@
     @onlyNativeDeviceTypes
     @ops(python_ref_db)
     def test_python_reference_consistency(self, device, dtype, op):
-        for sample in op.reference_inputs(device, dtype, requires_grad=False):
+        for sample in op.torch_opinfo.reference_inputs(device, dtype, requires_grad=False):
             actual = op(sample.input, *sample.args, **sample.kwargs)
             expected = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
 
@@ -357,7 +357,7 @@
 
     @skipMeta
     @onlyNativeDeviceTypes
-    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
+    @ops([op for op in ops_and_refs if op.error_inputs_func is not None], dtypes=OpDTypes.none)
     def test_errors(self, device, op):
         error_inputs = op.error_inputs(device)
         for ei in error_inputs:
diff --git a/torch/_prims/utils.py b/torch/_prims/utils.py
index 9283df8..7430a9a 100644
--- a/torch/_prims/utils.py
+++ b/torch/_prims/utils.py
@@ -677,5 +677,6 @@
     if dims is None:
         return tuple(range(len(shape)))
     dims = tuple(canonicalize_idx(len(shape), idx) for idx in dims)
-    assert len(dims) == len(set(dims)), "duplicate value in dims"
+    if len(dims) != len(set(dims)):
+        raise RuntimeError("duplicate value in the list of dims")
     return dims
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 8e2d510..198ad78 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -23,7 +23,7 @@
 # Experimental module containing prototype Python references for existing
 #   PyTorch operations.
 
-all = [
+__all__ = [
     #
     # Elementwise Unary References
     #
@@ -119,7 +119,9 @@
     #
     # Reduction ops
     #
-    "sum",  # TODO: add opinfo
+    "sum",
+    "amax",
+    "amin",
     #
     # View & Shape Ops
     #
@@ -616,7 +618,7 @@
 )
 
 
-def _make_elementwise_binary_reference(prim: Callable, *, type_promotion) -> Callable:
+def _make_elementwise_binary_reference(prim: Callable, *, type_promotion, wrap_scalars=False) -> Callable:
     def _ref(
         a: Union[Tensor, NumberType],
         b: Union[Tensor, NumberType],
@@ -629,7 +631,10 @@
 
         # Special-cases Number x Number case
         if isinstance(a, Number) and isinstance(b, Number):
-            a, b = utils.wrap_scalars(a, b)
+            if wrap_scalars:
+                a, b = utils.wrap_scalars(a, b)
+            else:
+                raise RuntimeError("got two scalar arguments, while expected at least one TensorLike")
 
         # Handles type promotion
         computation_dtype, result_dtype = _elementwise_dtypes(
@@ -752,10 +757,6 @@
     assert isinstance(b, (TensorLike, Number))
     assert out is None or isinstance(out, TensorLike)
 
-    # Special-cases Number x Number case
-    if isinstance(a, Number) and isinstance(b, Number):
-        a, b = utils.wrap_scalars(a, b)
-
     # Handles type promotion
     dtype = utils.get_higher_dtype(a, b)
     assert dtype is not None
@@ -820,7 +821,8 @@
 
 # TODO: add docstring
 mul = _make_elementwise_binary_reference(
-    prims.mul, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
+    prims.mul, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH,
+    wrap_scalars=True
 )
 
 # TODO: add docstring
@@ -890,7 +892,8 @@
 
 # TODO: add docstring
 true_divide = _make_elementwise_binary_reference(
-    prims.div, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+    prims.div, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+    wrap_scalars=True
 )
 
 #
@@ -982,10 +985,9 @@
         dims = (dims,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dims)
     if not has_identity:
-        valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims)  # type: ignore[operator]
-        assert (
-            valid_shape
-        ), "reducing over zero-size dimension for reduction operation without identity"
+        valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims)
+        if not valid_shape:
+            raise RuntimeError("reducing over zero-size dimension for reduction operation without identity")
     # even though some reductions, like amin or amax, don't strictly require type promotion,
     # all the math ops (including comparisons) are still defined only for a computation type,
     # so promotion will still happen. We are doing it explicitly here
@@ -1003,12 +1005,12 @@
             if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
                 if out.dtype != a.dtype:
                     raise RuntimeError(
-                        "out dtype and output type of reduction must match"
+                        "Expected the dtype for input and out to match"
                     )
             elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL:
                 if out.dtype != torch.bool:
                     raise RuntimeError(
-                        "out dtype and output type of reduction must match"
+                        "Expected the dtype for input and out to match"
                     )
         out = _maybe_resize_out(out, result.shape)
         return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]
@@ -1032,7 +1034,7 @@
             dtype = torch.int64
         else:
             dtype = a.dtype
-    # sum reduces over all dimensions if dim=() is passed
+    # reduces over all dimensions if dim=() is passed
     if dim == () or dim == []:
         dim = None
     return _reduction(
@@ -1045,6 +1047,48 @@
         output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
     )
 
+def amin(
+    a: Tensor,
+    dim: Union[Optional[int], Optional[List[int]]] = None,
+    keepdim: bool = False,
+    *,
+    out: Optional[Tensor] = None
+):
+    # reduces over all dimensions if dim=() is passed
+    if dim == () or dim == []:
+        dim = None
+    return _reduction(
+        a,
+        prims.amin,
+        dims=dim,
+        keepdims=keepdim,
+        dtype=None,
+        out=out,
+        has_identity=False,
+        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    )
+
+def amax(
+    a: Tensor,
+    dim: Union[Optional[int], Optional[List[int]]] = None,
+    keepdim: bool = False,
+    *,
+    out: Optional[Tensor] = None
+):
+    # reduces over all dimensions if dim=() is passed
+    if dim == () or dim == []:
+        dim = None
+    return _reduction(
+        a,
+        prims.amax,
+        dims=dim,
+        keepdims=keepdim,
+        dtype=None,
+        out=out,
+        has_identity=False,
+        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
+    )
+
 
 def cat(
     tensors: TensorSequenceType, dim: int = 0, out: TensorLikeType = None
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b7d0148..cf9a16a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4026,9 +4026,9 @@
 
     # Error Inputs for zero-dim tensors, when 'dim' arg is not provided.
     shape = (S, 0, S)
-    err_msg_amax_amin = "Specify the reduction dim with the 'dim' argument."
+    err_msg_amax_amin = "reduction"
     err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
-    if op_info.name in ['amax', 'amin']:
+    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
     elif op_info.name in ['aminmax']:
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
@@ -4042,9 +4042,9 @@
                      error_regex=err_msg1)
 
     # Error Inputs for repeated 'dim'
-    if op_info.name in ['amax', 'amin']:
+    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
         dims = [(0, 0), (0, -4)]
-        err_msg2 = "dim 0 appears multiple times in the list of dims"
+        err_msg2 = "in the list of dims"
         x = torch.randn(S, S, S, S, device=device)
         for dim in dims:
             yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2)
@@ -4058,7 +4058,7 @@
     err_msg_amax_amin2 = "Expected the dtype for input and out to match"
     err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead"
 
-    if op_info.name in ['amax', 'amin']:
+    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
         yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
                          error_regex=err_msg_amax_amin2)
     elif op_info.name in ['aminmax']:
@@ -4066,9 +4066,11 @@
                          error_regex=err_msg_aminmax2)
 
     # Error Inputs for functions to raise an error on specified zero'd dimension as reduction dim
-    err_msg3 = "Expected reduction dim 1 to have non-zero size"
+    err_msg3 = "reduction"
+    # FIXME: eager and ref impl throw different types of errors
+    error_type = IndexError if 'refs' not in op_info.name else RuntimeError
     yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}),
-                     error_type=IndexError, error_regex=err_msg3)
+                     error_type=error_type, error_regex=err_msg3)
 
 def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs):
     test_cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
@@ -16655,7 +16657,7 @@
         dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
         ref=reference_reduction_numpy(np.amax),
         skips=(
-            # FIXME: sum reduces all dimensions when dim=[]
+            # FIXME: reduces all dimensions when dim=[]
             DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
             DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
         ),
@@ -16667,7 +16669,7 @@
         dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
         ref=reference_reduction_numpy(np.amin),
         skips=(
-            # FIXME: sum reduces all dimensions when dim=[]
+            # FIXME: reduces all dimensions when dim=[]
             DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
             DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
         ),
@@ -17892,10 +17894,14 @@
     ElementwiseBinaryPythonRefInfo(
         "_refs.maximum",
         torch_opinfo_name="maximum",
+        supports_rhs_python_scalar=True,
+        supports_one_python_scalar=True
     ),
     ElementwiseBinaryPythonRefInfo(
         "_refs.minimum",
         torch_opinfo_name="minimum",
+        supports_rhs_python_scalar=True,
+        supports_one_python_scalar=True
     ),
     ElementwiseBinaryPythonRefInfo(
         "_refs.mul",
@@ -17976,7 +17982,18 @@
         "_refs.sum",
         torch_opinfo_name="sum",
         supports_out=True
+    ),
+
+    ReductionPythonRefInfo(
+        "_refs.amin",
+        torch_opinfo_name="amin",
+    ),
+
+    ReductionPythonRefInfo(
+        "_refs.amax",
+        torch_opinfo_name="amax",
     )
+
 ]
 
 # Common operator groupings