masked logsumexp/logaddexp
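
A minimal usage sketch (not part of the patch), based on the signatures this PR adds in `torch/_masked/__init__.py`. The behavior noted in the comments follows the masked semantics documented below, where masked-out elements are filled with the reduction identity (`-inf` for floating-point input):

```python
import torch

# Masked reduction: only elements where mask is True participate in the
# logsumexp; a fully masked-out slice is documented as undefined (the
# strided path here yields the identity, -inf).
x = torch.tensor([[-3., -2., -1.],
                  [ 0.,  1.,  2.]])
mask = torch.tensor([[True, False, True],
                     [False, False, False]])
row_lse = torch._masked.logsumexp(x, 1, mask=mask)

# Masked pointwise op: each operand carries its own mask; masked-out
# elements are filled with the identity before torch.logaddexp is applied.
a = torch.tensor([0.5, 1.5, 2.5])
b = torch.tensor([1.0, 2.0, 3.0])
out = torch._masked.logaddexp(
    a, b,
    input_mask=torch.tensor([True, True, False]),
    other_mask=torch.tensor([True, False, True]),
)
```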

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78291

Approved by: https://github.com/cpuhrsch
diff --git a/torch/_masked/__init__.py b/torch/_masked/__init__.py
index 250a927..e4ea3cd 100644
--- a/torch/_masked/__init__.py
+++ b/torch/_masked/__init__.py
@@ -173,6 +173,7 @@
         norm=(('ord', 'dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
         var=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
         std=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
+        logsumexp=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
         softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
         log_softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
         softmin=(('dim__as_int',), ('dtype=None', 'mask=None')),
@@ -247,7 +248,8 @@
         median='median',
         norm='norm',
         var='variance',
-        std='standard_deviation')
+        std='standard_deviation',
+        logsumexp='logsumexp')
 
     normalization_names = dict(
         softmax='softmax',
@@ -367,7 +369,7 @@
         return torch.tensor(0, dtype=dtype, device=device)
     elif op_name in {'prod', 'cumprod'}:
         return torch.tensor(1, dtype=dtype, device=device)
-    elif op_name in {'amax', 'argmax'}:
+    elif op_name in {'amax', 'argmax', 'logsumexp'}:
         if torch.is_floating_point(input):
             return torch.tensor(-torch.inf, dtype=dtype, device=device)
         elif torch.is_signed(input) or dtype == torch.uint8:
@@ -763,7 +765,7 @@
     """
     if callable(op):
         is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin',
-                                       'argmax', 'argmin', 'mean', 'median', 'norm', 'var', 'std'}
+                                       'argmax', 'argmin', 'mean', 'median', 'norm', 'var', 'std', 'logsumexp'}
         is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize', 'cumsum', 'cumprod'}
         if is_reduction:
             if op.__name__ == 'norm':
@@ -1103,6 +1105,42 @@
 
 
 @_apply_docstring_templates
+def logsumexp(input: Tensor,
+              dim: DimOrDims = None,
+              *,
+              keepdim: bool = False,
+              dtype: Optional[DType] = None,
+              mask: Optional[Tensor] = None) -> Tensor:
+    if dtype is None:
+        dtype = input.dtype
+    dim_ = _canonical_dim(dim, input.ndim)
+    mask_input = _combine_input_and_mask(logsumexp, input, mask)
+    if input.layout == torch.strided:
+        return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
+    else:
+        raise ValueError(f'masked logsumexp expects strided tensor (got {input.layout} tensor)')
+
+
+# TODO: Add docstring; currently the docstring templates only cover reductions and normalizations
+# @_apply_docstring_templates
+def logaddexp(input: Tensor,
+              other: Tensor,
+              *,
+              dtype: Optional[DType] = None,
+              input_mask: Optional[Tensor] = None,
+              other_mask: Optional[Tensor] = None) -> Tensor:
+    if dtype is None:
+        dtype = input.dtype
+    if input.layout == torch.strided and other.layout == torch.strided:
+        mask_input = _combine_input_and_mask(logsumexp, input, input_mask)
+        mask_other = _combine_input_and_mask(logsumexp, other, other_mask)
+        return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
+    else:
+        raise ValueError(
+            f'masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)')
+
+
+@_apply_docstring_templates
 def norm(input: Tensor,
          ord: Optional[float] = 2.0,
          dim: DimOrDims = None,
diff --git a/torch/_masked/_docs.py b/torch/_masked/_docs.py
index da2fa68..36961ed 100644
--- a/torch/_masked/_docs.py
+++ b/torch/_masked/_docs.py
@@ -450,6 +450,74 @@
             [    nan,     nan,     nan]])
 """
 
+logsumexp_docstring = """logsumexp(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
+
+Returns logsumexp of all the elements in the :attr:`input`
+tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
+elements are masked out according to the boolean tensor
+:attr:`mask`.
+
+The identity value of the logsumexp operation, which is used to start the reduction, is ``-2147483648``.
+
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
+size 1. Otherwise, :attr:`dim` is squeezed (see
+:func:`torch.squeeze`), resulting in the output tensor having 1 (or
+``len(dim)``) fewer dimension(s).
+
+The boolean tensor :attr:`mask` defines the "validity" of
+:attr:`input` tensor elements: if :attr:`mask` element is True
+then the corresponding element in :attr:`input` tensor will be
+included in logsumexp computation, otherwise the element is
+ignored.
+
+When all elements of :attr:`input` along the given dimension
+:attr:`dim` are ignored (fully masked-out), the corresponding element
+of the output tensor will have an undefined value: it may or may not
+correspond to the identity value of the logsumexp operation; the
+choice may correspond to the value that leads to the most efficient
+storage of the :attr:`output` tensor.
+
+The mask of the output tensor can be computed as
+``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
+dtype=torch.bool)``.
+
+The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
+don't need to match, but they must be :ref:`broadcastable
+<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
+tensor must not be greater than that of the :attr:`input` tensor.
+
+Args:
+    input (Tensor): the input tensor
+    dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
+      Default: None that is equivalent to ``tuple(range(input.ndim))``.
+
+Keyword args:
+    keepdim (bool, optional): whether the output tensor has
+      :attr:`dim` retained or not. Default: False.
+    dtype (:class:`torch.dtype`, optional): the desired data type
+      of returned tensor.  If specified, the input tensor is
+      casted to :attr:`dtype` before the operation is
+      performed. Default: None.
+    mask (:class:`torch.Tensor`, optional): the boolean tensor
+      containing the binary mask of validity of input tensor
+      elements.
+      Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
+
+Example::
+
+    >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
+    >>> input
+    tensor([[-3, -2, -1],
+            [ 0,  1,  2]])
+    >>> mask = tensor([[ True, False, True], [False, False, False]])
+    >>> mask
+    tensor([[ True, False,  True],
+            [False, False, False]])
+    >>> torch._masked.logsumexp(input, 1, mask=mask)
+    tensor([                   0, -9223372036854775808])
+"""
+
 mean_docstring = """mean(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
 
 Returns mean of all the elements in the :attr:`input`
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ce3f208..5f72e3c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7740,6 +7740,23 @@
 
     return inputs
 
+def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked logaddexp.
+    """
+    inputs: List[SampleInput] = []
+    shapes = [(S,), (S, S), (S, M, S)]
+    input_mask_lists = [list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes]
+    other_mask_lists = [list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes]
+
+    for shape, input_masks, other_masks in zip(shapes, input_mask_lists, other_mask_lists):
+        for input_mask, other_mask in zip(input_masks, other_masks):
+            input = make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+            other = make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+            inputs.append(SampleInput(input.clone().requires_grad_(requires_grad),
+                                      args=(other.clone().requires_grad_(requires_grad),),
+                                      kwargs=dict(input_mask=input_mask, other_mask=other_mask)))
+    return inputs
+
 def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
     """Sample inputs for masked normalize.
     """
@@ -10173,6 +10190,26 @@
     return output
 
 
+def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs):
+    """Gradcheck wrapper for masked pointwise operations. Assumes that the result
+    will be masked iff both tensors are masked at a specific index
+
+    When mask is specified, replaces masked-out elements with zeros.
+
+    Use for operations that produce non-finite masked-out elements,
+    for instance, for minimum and maximum reductions.
+    """
+    output = op(input, *args, **kwargs)
+    input_mask = kwargs.get('input_mask')
+    other_mask = kwargs.get('other_mask')
+    if input_mask is not None and other_mask is not None:
+        combined_mask = torch.logical_and(input_mask, other_mask)
+        new_kwargs = dict(mask=combined_mask, **kwargs)
+        output_mask = torch._masked._input_mask(input, *args, **new_kwargs)
+        output = torch.where(output_mask, output, output.new_zeros([]))
+    return output
+
+
 def reference_reduction_numpy(f, supports_keepdims=True):
     """Wraps a NumPy reduction operator.
 
@@ -19109,6 +19146,46 @@
         supports_fwgrad_bwgrad=True,
         supports_out=False),
     OpInfo(
+        '_masked.logaddexp',
+        dtypes=floating_types_and(torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
+        ),
+        sample_inputs_func=sample_inputs_masked_logaddexp,
+        gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation
+    ),
+    ReductionOpInfo(
+        '_masked.logsumexp',
+        dtypes=all_types_and(torch.bfloat16),
+        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        method_variant=None,
+        nan_policy='propagate',
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # Identity can't be -torch.inf without overflow
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_empty_tensor_empty_slice'),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+            # all the values are the same except for -inf vs nan
+            DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation
+    ),
+    OpInfo(
         "nn.functional.ctc_loss",
         ref=_NOTHING,
         dtypes=floating_types(),