Add masked argmin and argmax reductions to torch._masked

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74525

Approved by: https://github.com/cpuhrsch
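
Example usage, as a minimal sketch assuming a build that includes this patch
(it mirrors the docstring examples added below; torch._masked is a private
namespace, so the entry point may change):

    import torch

    # Two rows; the second row is fully masked out.
    input = torch.tensor([[-3, -2, -1],
                          [ 0,  1,  2]])
    mask = torch.tensor([[ True, False,  True],
                         [False, False, False]])

    # Masked-out elements are replaced by the reduction identity (dtype
    # minimum for argmax, dtype maximum for argmin) before the index is taken.
    torch._masked.argmax(input, 1, mask=mask)  # tensor([2, 0])
    torch._masked.argmin(input, 1, mask=mask)  # tensor([0, 0])

    # A fully masked-out slice yields an undefined index; the output mask
    # marks which result entries are valid.
    torch.any(mask, 1)  # tensor([ True, False])
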
diff --git a/torch/_masked/__init__.py b/torch/_masked/__init__.py
index 24f984f..6f962d0 100644
--- a/torch/_masked/__init__.py
+++ b/torch/_masked/__init__.py
@@ -163,6 +163,8 @@
         prod=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
         amin=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
         amax=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
+        argmin=(('dim__as_int',), ('keepdim=False', 'dtype=None', 'mask=None')),
+        argmax=(('dim__as_int',), ('keepdim=False', 'dtype=None', 'mask=None')),
         mean=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
         norm=(('ord', 'dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
         var=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
@@ -227,6 +229,8 @@
         prod='product',
         amax='maximum',
         amin='minimum',
+        argmax='argmax',
+        argmin='argmin',
         mean='mean',
         norm='norm',
         var='variance',
@@ -345,12 +349,12 @@
         return torch.tensor(0, dtype=dtype, device=device)
     elif op_name == 'prod':
         return torch.tensor(1, dtype=dtype, device=device)
-    elif op_name == 'amax':
+    elif op_name in {'amax', 'argmax'}:
         if torch.is_floating_point(input):
             return torch.tensor(-torch.inf, dtype=dtype, device=device)
         elif torch.is_signed(input) or dtype == torch.uint8:
             return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
-    elif op_name == 'amin':
+    elif op_name in {'amin', 'argmin'}:
         if torch.is_floating_point(input):
             return torch.tensor(torch.inf, dtype=dtype, device=device)
         elif torch.is_signed(input) or dtype == torch.uint8:
@@ -621,7 +625,7 @@
     """Return output mask of masked operation applied to given arguments.
     """
     if callable(op):
-        is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm', 'var', 'std'}
+        is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'argmax', 'argmin', 'mean', 'norm', 'var', 'std'}
         is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize'}
         if is_reduction:
             if op.__name__ == 'norm':
@@ -780,6 +784,50 @@
 
 
 @_apply_docstring_templates
+def argmax(input: Tensor,
+           dim: Optional[int] = None,
+           *,
+           keepdim: Optional[bool] = False,
+           dtype: Optional[DType] = None,
+           mask: Optional[Tensor] = None) -> Tensor:
+    """\
+{reduction_signature}
+{reduction_descr}
+{reduction_identity_dtype}
+{reduction_args}
+{reduction_example}"""
+    if dtype is None:
+        dtype = input.dtype
+    mask_input = _combine_input_and_mask(argmax, input, mask)
+    if input.layout == torch.strided:
+        return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
+    else:
+        raise ValueError(f'masked argmax expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
+def argmin(input: Tensor,
+           dim: Optional[int] = None,
+           *,
+           keepdim: Optional[bool] = False,
+           dtype: Optional[DType] = None,
+           mask: Optional[Tensor] = None) -> Tensor:
+    """\
+{reduction_signature}
+{reduction_descr}
+{reduction_identity_dtype}
+{reduction_args}
+{reduction_example}"""
+    if dtype is None:
+        dtype = input.dtype
+    mask_input = _combine_input_and_mask(argmin, input, mask)
+    if input.layout == torch.strided:
+        return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
+    else:
+        raise ValueError(f'masked argmin expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
 def mean(input: Tensor,
          dim: DimOrDims = None,
          *,
diff --git a/torch/_masked/_docs.py b/torch/_masked/_docs.py
index f3b3885..40b58ed 100644
--- a/torch/_masked/_docs.py
+++ b/torch/_masked/_docs.py
@@ -149,6 +149,136 @@
     tensor([                 -3, 9223372036854775807])
 """
 
+argmax_docstring = """argmax(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
+Returns the argmax of all the elements in the :attr:`input`
+tensor along the given dimension :attr:`dim` while the :attr:`input`
+elements are masked out according to the boolean tensor
+:attr:`mask`.
+The identity value of the argmax operation, which is used to start the
+reduction, depends on the input dtype. For instance, for float32, uint8,
+and int32 dtypes, the identity values are ``-inf``, ``0``, and ``-2147483648``, respectively.
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension :attr:`dim` where it is of
+size 1. Otherwise, :attr:`dim` is squeezed (see
+:func:`torch.squeeze`), resulting in the output tensor having one
+fewer dimension.
+
+The boolean tensor :attr:`mask` defines the "validity" of
+:attr:`input` tensor elements: if a :attr:`mask` element is True
+then the corresponding element of the :attr:`input` tensor will be
+included in the argmax computation; otherwise the element is
+ignored.
+
+When all elements of :attr:`input` along the given dimension
+:attr:`dim` are ignored (fully masked-out), the corresponding element
+of the output tensor will have an undefined value: it may or may not
+correspond to the identity value of the argmax operation; the
+choice may correspond to the value that leads to the most efficient
+storage of the :attr:`output` tensor.
+
+The mask of the output tensor can be computed as
+``torch.any(torch.broadcast_to(mask, input.shape), dim,
+keepdim=keepdim)``.
+
+The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
+don't need to match, but they must be :ref:`broadcastable
+<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
+tensor must not be greater than that of the :attr:`input` tensor.
+
+Args:
+    input (Tensor): the input tensor
+    dim (int): the dimension along which argmax is computed.
+
+Keyword args:
+    keepdim (bool, optional): whether the output tensor has
+      :attr:`dim` retained or not. Default: False.
+    dtype (:class:`torch.dtype`, optional): the desired data type
+      of returned tensor.  If specified, the input tensor is
+      cast to :attr:`dtype` before the operation is
+      performed. Default: None.
+    mask (:class:`torch.Tensor`, optional): the boolean tensor
+      containing the binary mask of validity of input tensor
+      elements.
+      Default: None, which is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
+Example::
+
+    >>> input = torch.tensor([[-3, -2, -1], [ 0, 1, 2]])
+    >>> input
+    tensor([[-3, -2, -1],
+            [ 0,  1,  2]])
+    >>> mask = torch.tensor([[ True, False, True], [False, False, False]])
+    >>> mask
+    tensor([[ True, False,  True],
+            [False, False, False]])
+    >>> torch._masked.argmax(input, 1, mask=mask)
+    tensor([2, 0])
+"""
+
+argmin_docstring = """argmin(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
+Returns the argmin of all the elements in the :attr:`input`
+tensor along the given dimension :attr:`dim` while the :attr:`input`
+elements are masked out according to the boolean tensor
+:attr:`mask`.
+The identity value of the argmin operation, which is used to start the
+reduction, depends on the input dtype. For instance, for float32, uint8,
+and int32 dtypes, the identity values are ``inf``, ``255``, and ``2147483647``, respectively.
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension :attr:`dim` where it is of
+size 1. Otherwise, :attr:`dim` is squeezed (see
+:func:`torch.squeeze`), resulting in the output tensor having one
+fewer dimension.
+
+The boolean tensor :attr:`mask` defines the "validity" of
+:attr:`input` tensor elements: if a :attr:`mask` element is True
+then the corresponding element of the :attr:`input` tensor will be
+included in the argmin computation; otherwise the element is
+ignored.
+
+When all elements of :attr:`input` along the given dimension
+:attr:`dim` are ignored (fully masked-out), the corresponding element
+of the output tensor will have an undefined value: it may or may not
+correspond to the identity value of the argmin operation; the
+choice may correspond to the value that leads to the most efficient
+storage of the :attr:`output` tensor.
+
+The mask of the output tensor can be computed as
+``torch.any(torch.broadcast_to(mask, input.shape), dim,
+keepdim=keepdim)``.
+
+The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
+don't need to match, but they must be :ref:`broadcastable
+<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
+tensor must not be greater than that of the :attr:`input` tensor.
+
+Args:
+    input (Tensor): the input tensor
+    dim (int): the dimension along which argmin is computed.
+
+Keyword args:
+    keepdim (bool, optional): whether the output tensor has
+      :attr:`dim` retained or not. Default: False.
+    dtype (:class:`torch.dtype`, optional): the desired data type
+      of returned tensor.  If specified, the input tensor is
+      cast to :attr:`dtype` before the operation is
+      performed. Default: None.
+    mask (:class:`torch.Tensor`, optional): the boolean tensor
+      containing the binary mask of validity of input tensor
+      elements.
+      Default: None, which is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
+Example::
+
+    >>> input = torch.tensor([[-3, -2, -1], [ 0, 1, 2]])
+    >>> input
+    tensor([[-3, -2, -1],
+            [ 0,  1,  2]])
+    >>> mask = torch.tensor([[ True, False, True], [False, False, False]])
+    >>> mask
+    tensor([[ True, False,  True],
+            [False, False, False]])
+    >>> torch._masked.argmin(input, 1, mask=mask)
+    tensor([0, 0])
+"""
+
 log_softmax_docstring = """log_softmax(input, dim, *, dtype=None, mask=None) -> Tensor
 
 Returns log_softmax of all the slices in the :attr:`input` tensor
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a2bca3c..f4bc230 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15561,6 +15561,44 @@
         gradcheck_wrapper=gradcheck_wrapper_masked_operation
     ),
     ReductionOpInfo(
+        '_masked.argmax',
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
+        skips=(
+            # FIXME (from torch.argmax): keepdim parameter is ignored when dim=None
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+            # initial is not a keyword for argmax
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_reference_masked'),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation
+    ),
+    ReductionOpInfo(
+        '_masked.argmin',
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
+        skips=(
+            # FIXME (from torch.argmin): keepdim parameter is ignored when dim=None
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+            # initial is not a keyword for argmin
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_reference_masked'),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation
+    ),
+    ReductionOpInfo(
         '_masked.mean',
         ref=reference_reduction_numpy(np.mean) if np.lib.NumpyVersion(np.__version__) >= '1.20.2' else None,
         method_variant=None,