Implement correction argument in torch.masked.{std,var} (#87118)

This makes the signatures of `torch.masked.std` and `torch.masked.var` more consistent with their global-namespace variants. It also updates the sample inputs to repurpose the existing `sample_inputs_std_var` inputs, which fully exercise the `correction` argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87118
Approved by: https://github.com/cpuhrsch
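
For illustration, a minimal usage sketch of the new keyword (the tensor values and mask are invented for this example):

```python
import torch

t = torch.tensor([[1.0, 2.0, 4.0], [3.0, 5.0, 7.0]])
mask = torch.tensor([[True, True, False], [True, True, True]])

# `correction` is the delta degrees of freedom, as in torch.var/torch.std;
# it is mutually exclusive with the legacy boolean `unbiased` flag.
v_biased = torch.masked.var(t, dim=1, correction=0, mask=mask)
v_unbiased = torch.masked.var(t, dim=1, correction=1, mask=mask)  # same as unbiased=True
```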
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 9157713..1fff486 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -198,6 +198,7 @@
     "linalg.pinv.singular": {f32, f64},
     "masked.norm": {f16},
     "masked.normalize": {f16},
+    "masked.var": {f16},
     "masked_fill": {f16},
     "masked_scatter": {f16, f32, f64},
     "masked_select": {b8, f16, f32, f64, i32, i64},
diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py
index a933f06..c69d8db 100644
--- a/torch/masked/_ops.py
+++ b/torch/masked/_ops.py
@@ -1538,14 +1538,22 @@
 
 def _std_var(
     input: Union[Tensor, MaskedTensor],
-    dim: DimOrDims = None,
-    unbiased: Optional[bool] = False,
+    dim: DimOrDims,
+    unbiased: Optional[bool],
     *,
-    keepdim: Optional[bool] = False,
-    dtype: Optional[DType] = None,
-    mask: Optional[Tensor] = None,
-    take_sqrt: Optional[bool] = False,
+    correction: Optional[int],
+    keepdim: Optional[bool],
+    dtype: Optional[DType],
+    mask: Optional[Tensor],
+    take_sqrt: Optional[bool],
 ) -> Tensor:
+    assert (unbiased is None or correction is None), "Only one of unbiased and correction may be given"
+    correction_int = 1
+    if unbiased is not None:
+        correction_int = 1 if unbiased else 0
+    if correction is not None:
+        correction_int = correction
+
     if dtype is None:
         dtype = input.dtype
         if not (dtype.is_floating_point or dtype.is_complex):
@@ -1584,8 +1592,8 @@
             )
         if not keepdim:
             count = count.reshape(total.shape)
-        if unbiased:
-            count = torch.subtract(count, 1)
+        if correction_int != 0:
+            count = torch.subtract(count, correction_int)
             count = torch.maximum(count, count.new_zeros([]))
         output = torch.divide(total, count).to(dtype=dtype)
         if take_sqrt:
@@ -1601,8 +1609,9 @@
 def var(
     input: Union[Tensor, MaskedTensor],
     dim: DimOrDims = None,
-    unbiased: Optional[bool] = False,
+    unbiased: Optional[bool] = None,
     *,
+    correction: Optional[int] = None,
     keepdim: Optional[bool] = False,
     dtype: Optional[DType] = None,
     mask: Optional[Tensor] = None,
@@ -1619,6 +1628,7 @@
         input=input,
         dim=dim,
         unbiased=unbiased,
+        correction=correction,
         keepdim=keepdim,
         dtype=dtype,
         mask=mask,
@@ -1630,8 +1640,9 @@
 def std(
     input: Union[Tensor, MaskedTensor],
     dim: DimOrDims = None,
-    unbiased: Optional[bool] = False,
+    unbiased: Optional[bool] = None,
     *,
+    correction: Optional[int] = None,
     keepdim: Optional[bool] = False,
     dtype: Optional[DType] = None,
     mask: Optional[Tensor] = None,
@@ -1648,6 +1659,7 @@
         input=input,
         dim=dim,
         unbiased=unbiased,
+        correction=correction,
         keepdim=keepdim,
         dtype=dtype,
         mask=mask,
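
The resolution of `unbiased`/`correction` into a single integer in `_std_var` can be summarized with this standalone sketch (`resolve_correction` is illustrative only, not a helper introduced by the PR):

```python
from typing import Optional

def resolve_correction(unbiased: Optional[bool], correction: Optional[int]) -> int:
    # Mirrors _std_var: at most one of the two may be passed; the default
    # (both None) behaves like unbiased=True, i.e. a correction of 1.
    assert unbiased is None or correction is None
    if correction is not None:
        return correction
    if unbiased is not None:
        return 1 if unbiased else 0
    return 1

assert resolve_correction(None, None) == 1   # default: unbiased estimator
assert resolve_correction(False, None) == 0  # unbiased=False -> no correction
assert resolve_correction(None, 2) == 2      # arbitrary ddof now expressible
```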
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index 20025b9..57a413b 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -1,4 +1,5 @@
 import unittest
+from collections.abc import Sequence
 from functools import partial
 from typing import List
 
@@ -223,51 +224,101 @@
             )
 
 
+def reference_masked_std_var(
+    numpy_fn,
+):
+    ref = reference_reduction_numpy(numpy_fn)
+
+    # Translate unbiased or correction arguments into ddof
+    def func(
+        input,
+        dim=None,
+        unbiased=None,
+        *,
+        correction=None,
+        **kwargs,
+    ):
+        ddof = 1
+        if unbiased is not None:
+            ddof = 1 if unbiased else 0
+        if correction is not None:
+            ddof = correction
+
+        if isinstance(dim, Sequence):
+            dim = tuple(dim)
+
+        return ref(input, dim, ddof=ddof, **kwargs)
+
+    return func
+
+
 def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
     """Sample inputs for masked std/var."""
-    for unbiased in [False, True]:
-        for sample_input in sample_inputs_masked_reduction(
+    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+    from torch.testing._internal.common_methods_invocations import sample_inputs_std_var
+
+    def masked_samples():
+        for sample_input in sample_inputs_std_var(
             op_info, device, dtype, requires_grad, **kwargs
         ):
-            if sample_input.args:
-                dim = sample_input.args[0]
-                sample_input_args = (
-                    sample_input.args[:1] + (unbiased,) + sample_input.args[1:]
+            if len(sample_input.args) and isinstance(sample_input.args[0], bool):
+                continue  # masked.{std, var} doesn't support `.var(unbiased)`
+
+            for mask in _generate_masked_op_mask(
+                sample_input.input.shape, device, **kwargs
+            ):
+                sample_input_args, sample_input_kwargs = sample_input.args, dict(
+                    mask=mask, **sample_input.kwargs
                 )
-                sample_input_kwargs = sample_input.kwargs.copy()
-            else:
-                dim = sample_input.kwargs.get("dim")
-                sample_input_args = sample_input.args
-                sample_input_kwargs = dict(sample_input.kwargs, unbiased=unbiased)
-            if requires_grad:
-                if sample_input_kwargs.get("mask") is None:
-                    orig_count = torch.masked.sum(
-                        torch.ones(sample_input.input.shape, dtype=torch.int64),
-                        dim,
-                        keepdim=True,
-                    )
-                else:
-                    inmask = torch.masked._input_mask(
-                        sample_input.input, *sample_input_args, **sample_input_kwargs
-                    )
-                    orig_count = torch.masked.sum(
-                        inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
-                        dim,
-                        keepdim=True,
-                        mask=inmask,
-                    )
-                if orig_count.min() <= int(unbiased) + 1:
-                    # Skip samples that lead to singularities in var
-                    # computation resulting nan values both in var and
-                    # autograd output that test_grad_fn cannot handle
-                    # correctly. Also, skip samples when the autograd output
-                    # for std could not be handled correctly due to torch.sqrt
-                    continue
-            yield SampleInput(
-                sample_input.input.detach().requires_grad_(requires_grad),
-                args=sample_input_args,
-                kwargs=sample_input_kwargs,
+                yield SampleInput(
+                    sample_input.input.detach().requires_grad_(requires_grad),
+                    args=sample_input_args,
+                    kwargs=sample_input_kwargs,
+                )
+                if (
+                    not requires_grad
+                    and dtype.is_floating_point
+                    and sample_input.input.ndim == 2
+                    and mask is not None
+                    and mask.shape == sample_input.input.shape
+                ):
+                    for v in [torch.inf, -torch.inf, torch.nan]:
+                        t = sample_input.input.detach()
+                        t.diagonal(0, -2, -1).fill_(v)
+                        yield SampleInput(
+                            t.requires_grad_(requires_grad),
+                            args=sample_input_args,
+                            kwargs=sample_input_kwargs,
+                        )
+
+    for sample_input in masked_samples():
+        correction = sample_input.kwargs.get("correction")
+        if correction is None:
+            correction = int(sample_input.kwargs.get("unbiased", True))
+
+        dim = sample_input.kwargs.get("dim", None)
+
+        if sample_input.kwargs.get("mask") is None:
+            orig_count = torch.masked.sum(
+                torch.ones(sample_input.input.shape, dtype=torch.int64),
+                dim,
+                keepdim=True,
             )
+        else:
+            inmask = torch.masked._input_mask(
+                sample_input.input, *sample_input.args, **sample_input.kwargs
+            )
+            orig_count = torch.masked.sum(
+                inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
+                dim,
+                keepdim=True,
+                mask=inmask,
+            )
+        if orig_count.min() <= correction + 1:
+            # Skip samples that lead to nans in var computation
+            continue
+
+        yield sample_input
 
 
 def sample_inputs_masked_softmax(
@@ -860,7 +911,7 @@
     ),
     ReductionOpInfo(
         "masked.var",
-        ref=reference_reduction_numpy(np.var)
+        ref=reference_masked_std_var(np.var)
         if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
         else None,
         method_variant=None,
@@ -938,7 +989,7 @@
     ),
     ReductionOpInfo(
         "masked.std",
-        ref=reference_reduction_numpy(np.std)
+        ref=reference_masked_std_var(np.std)
         if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
         else None,
         method_variant=None,
diff --git a/torch/testing/_internal/opinfo/utils.py b/torch/testing/_internal/opinfo/utils.py
index 0bbba7c..017f26f 100644
--- a/torch/testing/_internal/opinfo/utils.py
+++ b/torch/testing/_internal/opinfo/utils.py
@@ -243,11 +243,6 @@
                     identity = identity.cpu()
                 kwargs["initial"] = identity.numpy()
 
-        if "unbiased" in keys:
-            unbiased = kwargs.pop("unbiased")
-            if unbiased is not None:
-                kwargs["ddof"] = int(unbiased)
-
         result = f(x, *args, **kwargs)
 
         # Unsqueeze reduced dimensions if NumPy does not support keepdims
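
The `unbiased` → `ddof` translation removed here is now handled by `reference_masked_std_var` in `_masked.py`, which maps either `unbiased` or `correction` onto NumPy's `ddof`. A quick sanity check of that mapping (values arbitrary):

```python
import numpy as np

x = np.array([1.0, 2.0, 4.0, 8.0])
# unbiased=True / correction=1  ->  ddof=1
assert np.isclose(np.var(x, ddof=1), np.var(x) * len(x) / (len(x) - 1))
# unbiased=False / correction=0 ->  ddof=0 (NumPy's default)
assert np.var(x, ddof=0) == np.var(x)
```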