masked logsumexp/logaddexp
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78291
Approved by: https://github.com/cpuhrsch
diff --git a/torch/_masked/__init__.py b/torch/_masked/__init__.py
index 250a927..e4ea3cd 100644
--- a/torch/_masked/__init__.py
+++ b/torch/_masked/__init__.py
@@ -173,6 +173,7 @@
norm=(('ord', 'dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
var=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
std=(('dim', 'unbiased'), ('keepdim=False', 'dtype=None', 'mask=None')),
+ logsumexp=(('dim',), ('keepdim=False', 'dtype=None', 'mask=None')),
softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
log_softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
softmin=(('dim__as_int',), ('dtype=None', 'mask=None')),
@@ -247,7 +248,8 @@
median='median',
norm='norm',
var='variance',
- std='standard_deviation')
+ std='standard_deviation',
+ logsumexp='logsumexp')
normalization_names = dict(
softmax='softmax',
@@ -367,7 +369,7 @@
return torch.tensor(0, dtype=dtype, device=device)
elif op_name in {'prod', 'cumprod'}:
return torch.tensor(1, dtype=dtype, device=device)
- elif op_name in {'amax', 'argmax'}:
+ elif op_name in {'amax', 'argmax', 'logsumexp'}:
if torch.is_floating_point(input):
return torch.tensor(-torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
@@ -763,7 +765,7 @@
"""
if callable(op):
is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin',
- 'argmax', 'argmin', 'mean', 'median', 'norm', 'var', 'std'}
+ 'argmax', 'argmin', 'mean', 'median', 'norm', 'var', 'std', 'logsumexp'}
is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize', 'cumsum', 'cumprod'}
if is_reduction:
if op.__name__ == 'norm':
@@ -1103,6 +1105,42 @@
@_apply_docstring_templates
+def logsumexp(input: Tensor,
+ dim: DimOrDims = None,
+ *,
+ keepdim: bool = False,
+ dtype: Optional[DType] = None,
+ mask: Optional[Tensor] = None) -> Tensor:
+ if dtype is None:
+ dtype = input.dtype
+ dim_ = _canonical_dim(dim, input.ndim)
+ mask_input = _combine_input_and_mask(logsumexp, input, mask)
+ if input.layout == torch.strided:
+ return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
+ else:
+ raise ValueError(f'masked logsumexp expects strided tensor (got {input.layout} tensor)')
+
+
+# TODO: Add docstring; the docstring templates currently only cover reductions and normalizations
+# @_apply_docstring_templates
+def logaddexp(input: Tensor,
+ other: Tensor,
+ *,
+ dtype: Optional[DType] = None,
+ input_mask: Optional[Tensor] = None,
+ other_mask: Optional[Tensor] = None) -> Tensor:
+ if dtype is None:
+ dtype = input.dtype
+ if input.layout == torch.strided and other.layout == torch.strided:
+ mask_input = _combine_input_and_mask(logsumexp, input, input_mask)
+ mask_other = _combine_input_and_mask(logsumexp, other, other_mask)
+ return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
+ else:
+ raise ValueError(
+ f'masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)')
+
+
+@_apply_docstring_templates
def norm(input: Tensor,
ord: Optional[float] = 2.0,
dim: DimOrDims = None,
diff --git a/torch/_masked/_docs.py b/torch/_masked/_docs.py
index da2fa68..36961ed 100644
--- a/torch/_masked/_docs.py
+++ b/torch/_masked/_docs.py
@@ -450,6 +450,74 @@
[ nan, nan, nan]])
"""
+logsumexp_docstring = """logsumexp(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
+
+Returns logsumexp of all the elements in the :attr:`input`
+tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
+elements are masked out according to the boolean tensor
+:attr:`mask`.
+
+The identity value of the logsumexp operation, which is used to start the reduction, is ``-inf`` for floating-point inputs (for integer inputs, the minimum value of the dtype is used).
+
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
+size 1. Otherwise, :attr:`dim` is squeezed (see
+:func:`torch.squeeze`), resulting in the output tensor having 1 (or
+``len(dim)``) fewer dimension(s).
+
+The boolean tensor :attr:`mask` defines the "validity" of
+:attr:`input` tensor elements: if :attr:`mask` element is True
+then the corresponding element in :attr:`input` tensor will be
+included in logsumexp computation, otherwise the element is
+ignored.
+
+When all elements of :attr:`input` along the given dimension
+:attr:`dim` are ignored (fully masked-out), the corresponding element
+of the output tensor will have an undefined value: it may or may not
+correspond to the identity value of the logsumexp operation; the
+choice may correspond to the value that leads to the most efficient
+storage of the :attr:`output` tensor.
+
+The mask of the output tensor can be computed as
+``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
+dtype=torch.bool)``.
+
+The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
+don't need to match, but they must be :ref:`broadcastable
+<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
+tensor must not be greater than that of the :attr:`input` tensor.
+
+Args:
+ input (Tensor): the input tensor
+ dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
+ Default: None that is equivalent to ``tuple(range(input.ndim))``.
+
+Keyword args:
+ keepdim (bool, optional): whether the output tensor has
+ :attr:`dim` retained or not. Default: False.
+ dtype (:class:`torch.dtype`, optional): the desired data type
+ of returned tensor. If specified, the input tensor is
+ casted to :attr:`dtype` before the operation is
+ performed. Default: None.
+ mask (:class:`torch.Tensor`, optional): the boolean tensor
+ containing the binary mask of validity of input tensor
+ elements.
+ Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
+
+Example::
+
+ >>> input = tensor([[-3, -2, -1], [ 0, 1, 2]])
+ >>> input
+ tensor([[-3, -2, -1],
+ [ 0, 1, 2]])
+ >>> mask = tensor([[ True, False, True], [False, False, False]])
+ >>> mask
+ tensor([[ True, False, True],
+ [False, False, False]])
+ >>> torch._masked.logsumexp(input, 1, mask=mask)
+ tensor([ 0, -9223372036854775808])
+"""
+
mean_docstring = """mean(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor
Returns mean of all the elements in the :attr:`input`
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ce3f208..5f72e3c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7740,6 +7740,23 @@
return inputs
+def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
+ """Sample inputs for masked logaddexp.
+ """
+ inputs: List[SampleInput] = []
+ shapes = [(S,), (S, S), (S, M, S)]
+ input_mask_lists = [list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes]
+ other_mask_lists = [list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes]
+
+ for shape, input_masks, other_masks in zip(shapes, input_mask_lists, other_mask_lists):
+ for input_mask, other_mask in zip(input_masks, other_masks):
+ input = make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+ other = make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+ inputs.append(SampleInput(input.clone().requires_grad_(requires_grad),
+ args=(other.clone().requires_grad_(requires_grad),),
+ kwargs=dict(input_mask=input_mask, other_mask=other_mask)))
+ return inputs
+
def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
"""Sample inputs for masked normalize.
"""
@@ -10173,6 +10190,26 @@
return output
+def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs):
+    """Gradcheck wrapper for masked pointwise operations. Assumes that the result
+    is valid (unmasked) at a given index iff both input tensors are valid there.
+
+    When both masks are specified, replaces masked-out elements with zeros.
+
+    Use for pointwise operations that produce non-finite values in masked-out
+    elements, for instance, logaddexp.
+    """
+ output = op(input, *args, **kwargs)
+ input_mask = kwargs.get('input_mask')
+ other_mask = kwargs.get('other_mask')
+ if input_mask is not None and other_mask is not None:
+ combined_mask = torch.logical_and(input_mask, other_mask)
+ new_kwargs = dict(mask=combined_mask, **kwargs)
+ output_mask = torch._masked._input_mask(input, *args, **new_kwargs)
+ output = torch.where(output_mask, output, output.new_zeros([]))
+ return output
+
+
def reference_reduction_numpy(f, supports_keepdims=True):
"""Wraps a NumPy reduction operator.
@@ -19109,6 +19146,46 @@
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo(
+ '_masked.logaddexp',
+ dtypes=floating_types_and(torch.bfloat16),
+ supports_out=False,
+ supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
+ check_batched_forward_grad=False,
+ skips=(
+ DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+ # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
+ ),
+ sample_inputs_func=sample_inputs_masked_logaddexp,
+ gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation
+ ),
+ ReductionOpInfo(
+ '_masked.logsumexp',
+ dtypes=all_types_and(torch.bfloat16),
+ dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+ method_variant=None,
+ nan_policy='propagate',
+ supports_out=False,
+ supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
+ skips=(
+ DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+ # FIXME: reduces all dimensions when dim=[]
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+ # Identity can't be -torch.inf without overflow
+ DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_empty_tensor_empty_slice'),
+ # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+ # all the values are the same except for -inf vs nan
+ DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
+ ),
+ sample_inputs_func=sample_inputs_masked_reduction,
+ gradcheck_wrapper=gradcheck_wrapper_masked_operation
+ ),
+ OpInfo(
"nn.functional.ctc_loss",
ref=_NOTHING,
dtypes=floating_types(),