Add masked argmin and argmax
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74525
Approved by: https://github.com/cpuhrsch
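This PR extends `torch._masked` with masked variants of `argmin` and `argmax`: elements where `mask` is False are excluded from the reduction, and a fully masked-out slice produces an undefined index. A minimal usage sketch, mirroring the docstring examples added below:

```python
import torch

input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
mask = torch.tensor([[True, False, True], [False, False, False]])

# Only elements where mask is True participate in the reduction.
torch._masked.argmax(input, 1, mask=mask)  # tensor([2, 0]); row 1 is fully masked out, so its 0 is undefined
torch._masked.argmin(input, 1, mask=mask)  # tensor([0, 0]); row 1 is fully masked out, so its 0 is undefined
```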
diff --git a/torch/_masked/__init__.py b/torch/_masked/__init__.py
index 24f984f..6f962d0 100644
--- a/torch/_masked/__init__.py
+++ b/torch/_masked/__init__.py
@@ -163,6 +163,8 @@
prod=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
amin=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
amax=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
+ argmin=(('dim__as_int',), ('keepdim=False', 'dtype=None', 'mask=None')),
+ argmax=(('dim__as_int',), ('keepdim=False', 'dtype=None', 'mask=None')),
mean=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
norm=(('ord', 'dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
var=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
@@ -227,6 +229,8 @@
prod='product',
amax='maximum',
amin='minimum',
+ argmax='argmax',
+ argmin='argmin',
mean='mean',
norm='norm',
var='variance',
@@ -345,12 +349,12 @@
return torch.tensor(0, dtype=dtype, device=device)
elif op_name == 'prod':
return torch.tensor(1, dtype=dtype, device=device)
- elif op_name == 'amax':
+ elif op_name in {'amax', 'argmax'}:
if torch.is_floating_point(input):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
- elif op_name == 'amin':
+ elif op_name in {'amin', 'argmin'}:
if torch.is_floating_point(input):
return torch.tensor(torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
@@ -621,7 +625,7 @@
"""Return output mask of masked operation applied to given arguments.
"""
if callable(op):
- is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm', 'var', 'std'}
+ is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'argmax', 'argmin', 'mean', 'norm', 'var', 'std'}
is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize'}
if is_reduction:
if op.__name__ == 'norm':
@@ -780,6 +784,50 @@
@_apply_docstring_templates
+def argmax(input: Tensor,
+           dim: Optional[int] = None,
+ *,
+ keepdim: Optional[bool] = False,
+ dtype: Optional[DType] = None,
+ mask: Optional[Tensor] = None) -> Tensor:
+ """\
+{reduction_signature}
+{reduction_descr}
+{reduction_identity_dtype}
+{reduction_args}
+{reduction_example}"""
+ if dtype is None:
+ dtype = input.dtype
+ mask_input = _combine_input_and_mask(argmax, input, mask)
+ if input.layout == torch.strided:
+ return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
+ else:
+ raise ValueError(f'masked argmax expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
+def argmin(input: Tensor,
+           dim: Optional[int] = None,
+ *,
+ keepdim: Optional[bool] = False,
+ dtype: Optional[DType] = None,
+ mask: Optional[Tensor] = None) -> Tensor:
+ """\
+{reduction_signature}
+{reduction_descr}
+{reduction_identity_dtype}
+{reduction_args}
+{reduction_example}"""
+ if dtype is None:
+ dtype = input.dtype
+ mask_input = _combine_input_and_mask(argmin, input, mask)
+ if input.layout == torch.strided:
+ return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
+ else:
+ raise ValueError(f'masked argmin expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
def mean(input: Tensor,
dim: DimOrDims = None,
*,
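For orientation, the implementation reuses the existing masked-reduction machinery: `_reduction_identity` now maps `argmax`/`argmin` to the same identity values as `amax`/`amin`, and `_combine_input_and_mask` fills masked-out elements with that identity before the plain `torch.argmax`/`torch.argmin` call. A self-contained sketch of the idea, with a hypothetical helper name and the identity logic inlined from the hunk above:

```python
import torch

def masked_argmax_sketch(input, dim, *, keepdim=False, mask=None):
    # Hypothetical standalone illustration; the PR itself routes through
    # _combine_input_and_mask and _reduction_identity.
    if mask is None:
        return torch.argmax(input, dim, keepdim)
    if torch.is_floating_point(input):
        # Identity for argmax on floats: -inf can never be the maximum.
        identity = torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)
    else:
        # Signed integer dtypes and uint8: the dtype's minimum value.
        identity = torch.tensor(torch.iinfo(input.dtype).min,
                                dtype=input.dtype, device=input.device)
    # Masked-out positions receive the identity, so argmax cannot select
    # them unless the whole slice is masked out (result then undefined).
    return torch.argmax(torch.where(mask, input, identity), dim, keepdim)
```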
diff --git a/torch/_masked/_docs.py b/torch/_masked/_docs.py
index f3b3885..40b58ed 100644
--- a/torch/_masked/_docs.py
+++ b/torch/_masked/_docs.py
@@ -149,6 +149,136 @@
tensor([ -3, 9223372036854775807])
"""
+argmax_docstring = """argmax(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
+Returns the argmax of all the elements in the :attr:`input`
+tensor along the given dimension :attr:`dim`, while the :attr:`input`
+elements are masked out according to the boolean tensor
+:attr:`mask`.
+The identity value of the argmax operation, which is used to start the
+reduction, depends on the input dtype. For instance, for float32, uint8,
+and int32 dtypes, the identity values are ``-inf``, ``0``, and ``-2147483648``, respectively.
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension :attr:`dim` where it is of
+size 1. Otherwise, :attr:`dim` is squeezed (see
+:func:`torch.squeeze`), resulting in the output tensor having one
+fewer dimension.
+
+The boolean tensor :attr:`mask` defines the "validity" of the
+:attr:`input` tensor elements: if a :attr:`mask` element is True,
+the corresponding element in the :attr:`input` tensor is
+included in the argmax computation; otherwise the element is
+ignored.
+
+When all elements of :attr:`input` along the given dimension
+:attr:`dim` are ignored (fully masked out), the corresponding element
+of the output tensor has an undefined value: it may or may not
+correspond to the identity value of the argmax operation; the
+choice may be whatever leads to the most efficient storage of the
+output tensor.
+
+The mask of the output tensor can be computed as
+``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
+dtype=torch.bool)``.
+
+The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
+don't need to match, but they must be :ref:`broadcastable
+<broadcasting-semantics>`, and the dimensionality of the :attr:`mask`
+tensor must not be greater than that of the :attr:`input` tensor.
+
+Args:
+ input (Tensor): the input tensor
+ dim (int): the dimension along which argmax is computed.
+
+Keyword args:
+ keepdim (bool, optional): whether the output tensor has
+ :attr:`dim` retained or not. Default: False.
+    dtype (:class:`torch.dtype`, optional): the desired data type
+        of the returned tensor. If specified, the input tensor is
+        cast to :attr:`dtype` before the operation is
+        performed. Default: None.
+    mask (:class:`torch.Tensor`, optional): the boolean tensor
+        containing the binary mask of validity of input tensor
+        elements.
+        Default: None, which is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
+Example::
+
+ >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
+ >>> input
+ tensor([[-3, -2, -1],
+ [ 0, 1, 2]])
+ >>> mask = tensor([[ True, False, True], [False, False, False]])
+ >>> mask
+ tensor([[ True, False, True],
+ [False, False, False]])
+ >>> torch._masked.argmax(input, 1, mask=mask)
+ tensor([2, 0])
+"""
+
+argmin_docstring = """argmin(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
+Returns the argmin of all the elements in the :attr:`input`
+tensor along the given dimension :attr:`dim`, while the :attr:`input`
+elements are masked out according to the boolean tensor
+:attr:`mask`.
+The identity value of the argmin operation, which is used to start the
+reduction, depends on the input dtype. For instance, for float32, uint8,
+and int32 dtypes, the identity values are ``inf``, ``255``, and ``2147483647``, respectively.
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension :attr:`dim` where it is of
+size 1. Otherwise, :attr:`dim` is squeezed (see
+:func:`torch.squeeze`), resulting in the output tensor having one
+fewer dimension.
+
+The boolean tensor :attr:`mask` defines the "validity" of the
+:attr:`input` tensor elements: if a :attr:`mask` element is True,
+the corresponding element in the :attr:`input` tensor is
+included in the argmin computation; otherwise the element is
+ignored.
+
+When all elements of :attr:`input` along the given dimension
+:attr:`dim` are ignored (fully masked out), the corresponding element
+of the output tensor has an undefined value: it may or may not
+correspond to the identity value of the argmin operation; the
+choice may be whatever leads to the most efficient storage of the
+output tensor.
+
+The mask of the output tensor can be computed as
+``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
+dtype=torch.bool)``.
+
+The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
+don't need to match, but they must be :ref:`broadcastable
+<broadcasting-semantics>`, and the dimensionality of the :attr:`mask`
+tensor must not be greater than that of the :attr:`input` tensor.
+
+Args:
+ input (Tensor): the input tensor
+ dim (int): the dimension along which argmin is computed.
+
+Keyword args:
+ keepdim (bool, optional): whether the output tensor has
+ :attr:`dim` retained or not. Default: False.
+    dtype (:class:`torch.dtype`, optional): the desired data type
+        of the returned tensor. If specified, the input tensor is
+        cast to :attr:`dtype` before the operation is
+        performed. Default: None.
+    mask (:class:`torch.Tensor`, optional): the boolean tensor
+        containing the binary mask of validity of input tensor
+        elements.
+        Default: None, which is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
+Example::
+
+ >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
+ >>> input
+ tensor([[-3, -2, -1],
+ [ 0, 1, 2]])
+ >>> mask = tensor([[ True, False, True], [False, False, False]])
+ >>> mask
+ tensor([[ True, False, True],
+ [False, False, False]])
+ >>> torch._masked.argmin(input, 1, mask=mask)
+ tensor([0, 0])
+"""
+
log_softmax_docstring = """log_softmax(input, dim, *, dtype=None, mask=None) -> Tensor
Returns log_softmax of all the slices in the :attr:`input` tensor
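The docstrings above define the output mask of these reductions via `torch.any`; checking that formula on the docstring example shows why the fully masked-out row has an undefined result:

```python
import torch

input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
mask = torch.tensor([[True, False, True], [False, False, False]])

# A result slice is valid iff at least one of its input elements was valid.
out_mask = torch.any(torch.broadcast_to(mask, input.shape), 1)
print(out_mask)  # tensor([ True, False]) -> row 1's argmax/argmin is undefined
```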
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a2bca3c..f4bc230 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15561,6 +15561,44 @@
gradcheck_wrapper=gradcheck_wrapper_masked_operation
),
ReductionOpInfo(
+ '_masked.argmax',
+ supports_out=False,
+ supports_multiple_dims=False,
+ supports_autograd=False,
+ dtypes=all_types_and(torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
+ skips=(
+ # FIXME (from torch.argmax): keepdim parameter is ignored when dim=None
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+ # initial is not a keyword for argmax
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_reference_masked'),
+ # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ ),
+ sample_inputs_func=sample_inputs_masked_reduction,
+ gradcheck_wrapper=gradcheck_wrapper_masked_operation
+ ),
+ ReductionOpInfo(
+ '_masked.argmin',
+ supports_out=False,
+ supports_multiple_dims=False,
+ supports_autograd=False,
+ dtypes=all_types_and(torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
+ skips=(
+ # FIXME (from torch.argmin): keepdim parameter is ignored when dim=None
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+ # initial is not a keyword for argmin
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_reference_masked'),
+ # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ ),
+ sample_inputs_func=sample_inputs_masked_reduction,
+ gradcheck_wrapper=gradcheck_wrapper_masked_operation
+ ),
+ ReductionOpInfo(
'_masked.mean',
ref=reference_reduction_numpy(np.mean) if np.lib.NumpyVersion(np.__version__) >= '1.20.2' else None,
method_variant=None,
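A note on the `test_reference_masked` skips above: per the inline comments, the masked-reduction reference machinery seeds the NumPy reference with an `initial` keyword, which `np.amax`/`np.amin` accept but `np.argmax`/`np.argmin` do not, for example:

```python
import numpy as np

x = np.array([[-3, -2, -1], [0, 1, 2]])
np.amax(x, axis=1, initial=-100)  # OK: `initial` seeds the reduction
try:
    np.argmax(x, axis=1, initial=-100)  # np.argmax has no `initial` keyword
except TypeError as e:
    print(e)  # TypeError: unexpected keyword argument 'initial'
```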