Implement correction argument in torch.masked.{std,var} (#87118)
This makes the signature of `torch.masked.std` and `var` more consistent with the global namespace variant and also updates the sample inputs to repurpose the existing `sample_inputs_std_var` inputs which fully exercise the `correction` argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87118
Approved by: https://github.com/cpuhrsch
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 9157713..1fff486 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -198,6 +198,7 @@
"linalg.pinv.singular": {f32, f64},
"masked.norm": {f16},
"masked.normalize": {f16},
+ "masked.var": {f16},
"masked_fill": {f16},
"masked_scatter": {f16, f32, f64},
"masked_select": {b8, f16, f32, f64, i32, i64},
diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py
index a933f06..c69d8db 100644
--- a/torch/masked/_ops.py
+++ b/torch/masked/_ops.py
@@ -1538,14 +1538,22 @@
def _std_var(
input: Union[Tensor, MaskedTensor],
- dim: DimOrDims = None,
- unbiased: Optional[bool] = False,
+ dim: DimOrDims,
+ unbiased: Optional[bool],
*,
- keepdim: Optional[bool] = False,
- dtype: Optional[DType] = None,
- mask: Optional[Tensor] = None,
- take_sqrt: Optional[bool] = False,
+ correction: Optional[int],
+ keepdim: Optional[bool],
+ dtype: Optional[DType],
+ mask: Optional[Tensor],
+ take_sqrt: Optional[bool],
) -> Tensor:
+ assert (unbiased is None or correction is None), "Only one of unbiased and correction may be given"
+ correction_int = 1
+ if unbiased is not None:
+ correction_int = 1 if unbiased else 0
+ if correction is not None:
+ correction_int = correction
+
if dtype is None:
dtype = input.dtype
if not (dtype.is_floating_point or dtype.is_complex):
@@ -1584,8 +1592,8 @@
)
if not keepdim:
count = count.reshape(total.shape)
- if unbiased:
- count = torch.subtract(count, 1)
+ if correction_int != 0:
+ count = torch.subtract(count, correction_int)
count = torch.maximum(count, count.new_zeros([]))
output = torch.divide(total, count).to(dtype=dtype)
if take_sqrt:
@@ -1601,8 +1609,9 @@
def var(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
- unbiased: Optional[bool] = False,
+ unbiased: Optional[bool] = None,
*,
+ correction: Optional[int] = None,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
@@ -1619,6 +1628,7 @@
input=input,
dim=dim,
unbiased=unbiased,
+ correction=correction,
keepdim=keepdim,
dtype=dtype,
mask=mask,
@@ -1630,8 +1640,9 @@
def std(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
- unbiased: Optional[bool] = False,
+ unbiased: Optional[bool] = None,
*,
+ correction: Optional[int] = None,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
@@ -1648,6 +1659,7 @@
input=input,
dim=dim,
unbiased=unbiased,
+ correction=correction,
keepdim=keepdim,
dtype=dtype,
mask=mask,
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index 20025b9..57a413b 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -1,4 +1,5 @@
import unittest
+from collections.abc import Sequence
from functools import partial
from typing import List
@@ -223,51 +224,101 @@
)
+def reference_masked_std_var(
+ numpy_fn,
+):
+ ref = reference_reduction_numpy(numpy_fn)
+
+ # Translate unbiased or correction arguments into ddof
+ def func(
+ input,
+ dim=None,
+ unbiased=None,
+ *,
+ correction=None,
+ **kwargs,
+ ):
+ ddof = 1
+ if unbiased is not None:
+ ddof = 1 if unbiased else 0
+ if correction is not None:
+ ddof = correction
+
+ if isinstance(dim, Sequence):
+ dim = tuple(dim)
+
+ return ref(input, dim, ddof=ddof, **kwargs)
+
+ return func
+
+
def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
"""Sample inputs for masked std/var."""
- for unbiased in [False, True]:
- for sample_input in sample_inputs_masked_reduction(
+ kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+ from torch.testing._internal.common_methods_invocations import sample_inputs_std_var
+
+ def masked_samples():
+ for sample_input in sample_inputs_std_var(
op_info, device, dtype, requires_grad, **kwargs
):
- if sample_input.args:
- dim = sample_input.args[0]
- sample_input_args = (
- sample_input.args[:1] + (unbiased,) + sample_input.args[1:]
+ if len(sample_input.args) and isinstance(sample_input.args[0], bool):
+ continue # masked.{std, var} doesn't support `.var(unbiased)`
+
+ for mask in _generate_masked_op_mask(
+ sample_input.input.shape, device, **kwargs
+ ):
+ sample_input_args, sample_input_kwargs = sample_input.args, dict(
+ mask=mask, **sample_input.kwargs
)
- sample_input_kwargs = sample_input.kwargs.copy()
- else:
- dim = sample_input.kwargs.get("dim")
- sample_input_args = sample_input.args
- sample_input_kwargs = dict(sample_input.kwargs, unbiased=unbiased)
- if requires_grad:
- if sample_input_kwargs.get("mask") is None:
- orig_count = torch.masked.sum(
- torch.ones(sample_input.input.shape, dtype=torch.int64),
- dim,
- keepdim=True,
- )
- else:
- inmask = torch.masked._input_mask(
- sample_input.input, *sample_input_args, **sample_input_kwargs
- )
- orig_count = torch.masked.sum(
- inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
- dim,
- keepdim=True,
- mask=inmask,
- )
- if orig_count.min() <= int(unbiased) + 1:
- # Skip samples that lead to singularities in var
- # computation resulting nan values both in var and
- # autograd output that test_grad_fn cannot handle
- # correctly. Also, skip samples when the autograd output
- # for std could not be handled correctly due to torch.sqrt
- continue
- yield SampleInput(
- sample_input.input.detach().requires_grad_(requires_grad),
- args=sample_input_args,
- kwargs=sample_input_kwargs,
+ yield SampleInput(
+ sample_input.input.detach().requires_grad_(requires_grad),
+ args=sample_input_args,
+ kwargs=sample_input_kwargs,
+ )
+ if (
+ not requires_grad
+ and dtype.is_floating_point
+ and sample_input.input.ndim == 2
+ and mask is not None
+ and mask.shape == sample_input.input.shape
+ ):
+ for v in [torch.inf, -torch.inf, torch.nan]:
+ t = sample_input.input.detach()
+ t.diagonal(0, -2, -1).fill_(v)
+ yield SampleInput(
+ t.requires_grad_(requires_grad),
+ args=sample_input_args,
+ kwargs=sample_input_kwargs,
+ )
+
+ for sample_input in masked_samples():
+ correction = sample_input.kwargs.get("correction")
+ if correction is None:
+ correction = int(sample_input.kwargs.get("unbiased", True))
+
+ dim = sample_input.kwargs.get("dim", None)
+
+ if sample_input.kwargs.get("mask") is None:
+ orig_count = torch.masked.sum(
+ torch.ones(sample_input.input.shape, dtype=torch.int64),
+ dim,
+ keepdim=True,
)
+ else:
+ inmask = torch.masked._input_mask(
+ sample_input.input, *sample_input.args, **sample_input.kwargs
+ )
+ orig_count = torch.masked.sum(
+ inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
+ dim,
+ keepdim=True,
+ mask=inmask,
+ )
+ if orig_count.min() <= correction + 1:
+ # Skip samples that lead to nans in var computation
+ continue
+
+ yield sample_input
def sample_inputs_masked_softmax(
@@ -860,7 +911,7 @@
),
ReductionOpInfo(
"masked.var",
- ref=reference_reduction_numpy(np.var)
+ ref=reference_masked_std_var(np.var)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
@@ -938,7 +989,7 @@
),
ReductionOpInfo(
"masked.std",
- ref=reference_reduction_numpy(np.std)
+ ref=reference_masked_std_var(np.std)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
diff --git a/torch/testing/_internal/opinfo/utils.py b/torch/testing/_internal/opinfo/utils.py
index 0bbba7c..017f26f 100644
--- a/torch/testing/_internal/opinfo/utils.py
+++ b/torch/testing/_internal/opinfo/utils.py
@@ -243,11 +243,6 @@
identity = identity.cpu()
kwargs["initial"] = identity.numpy()
- if "unbiased" in keys:
- unbiased = kwargs.pop("unbiased")
- if unbiased is not None:
- kwargs["ddof"] = int(unbiased)
-
result = f(x, *args, **kwargs)
# Unsqueeze reduced dimensions if NumPy does not support keepdims