blob: d45ddc791134ffec32f75ad2980f56accde28d8d [file] [log] [blame]
"""Tests for masked operations.
"""
import itertools
import torch
from torch.testing._internal.common_utils import \
(TestCase, suppress_warnings)
from torch.testing._internal.common_methods_invocations import \
(op_db,)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, onlyNativeDeviceTypes)
def apply_masked_normalization_along_dim(op, x, dim, dtype=None, mask=None):
    """Apply a normalization ``op`` along ``dim`` to the mask-selected elements of ``x``.

    ``x`` is walked one 1-D lane at a time along ``dim``. Within each lane,
    ``op`` is applied (along the lane's dimension 0) only to the elements whose
    entry in ``torch._masked._input_mask(x, mask=mask)`` is True, and the
    results are written back to those same positions. Every other position of
    the returned tensor is zero.

    Args:
        op: normalization callable invoked as ``op(t, dim, dtype=dtype)``,
            e.g. ``torch.softmax``, ``torch.log_softmax`` or
            ``torch.nn.functional.softmin``.
        x: input tensor.
        dim: dimension to normalize along; negative values count from the
            last dimension.
        dtype: optional dtype forwarded to ``op`` and used for the output.
        mask: optional mask forwarded to ``torch._masked._input_mask``.

    Returns:
        A tensor with the shape of ``x``. For a 0-d ``x`` the result is just
        ``op(x, dim, dtype=dtype)``; the mask is not consulted in that case.

    Raises:
        IndexError: if ``x`` is not 0-d and ``dim`` lies outside
            ``[-x.ndim, x.ndim - 1]``.
    """
    if x.ndim == 0:  # scalar input: delegate entirely to op (mask not consulted)
        return op(x, dim, dtype=dtype)
    if not -x.ndim <= dim < x.ndim:
        # ``dim % x.ndim`` below would otherwise silently wrap an out-of-range
        # dim onto a different dimension; raise IndexError as torch ops do.
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-x.ndim}, {x.ndim - 1}], but got {dim})")
    out = torch.zeros_like(x, dtype=dtype)
    keep = torch._masked._input_mask(x, mask=mask)
    axis = dim % x.ndim
    # One index choice per dimension: every position for the other dims and a
    # full slice for ``axis``; each element of the product selects one lane.
    per_dim = [range(size) for size in x.shape]
    per_dim[axis] = (slice(None),)
    for lane in itertools.product(*per_dim):
        # Positions inside this lane that the input mask keeps (assumes the
        # input mask has x's shape, so the lane is 1-D and this is (k, 1)).
        valid = keep[lane].argwhere()
        # ``out[lane]`` uses only ints and one slice (basic indexing), so it is
        # a view and the advanced-index assignment writes into ``out``.
        out[lane][valid] = op(x[lane][valid], 0, dtype=dtype)
    return out
def _masked_reference(normalization_op):
    """Build a reference callable that runs ``normalization_op`` under a mask.

    The returned function forwards all of its arguments to
    ``apply_masked_normalization_along_dim`` with ``normalization_op`` as the
    op to apply.
    """
    def reference(*args, **kwargs):
        return apply_masked_normalization_along_dim(normalization_op, *args, **kwargs)
    return reference


# Maps the short name of a masked operator (the text after the last '.') to
# its masked reference implementation.
reference_functions = {
    'softmax': _masked_reference(torch.softmax),
    'log_softmax': _masked_reference(torch.log_softmax),
    'softmin': _masked_reference(torch.nn.functional.softmin),
}
# Every entry of op_db whose name lives in the ``_masked.`` namespace.
masked_ops = [opinfo for opinfo in op_db if opinfo.name.startswith('_masked.')]
# The subset of masked ops whose short name (text after the last '.') has an
# entry in ``reference_functions``.
masked_ops_with_references = [
    opinfo for opinfo in masked_ops
    if opinfo.name.rpartition('.')[2] in reference_functions
]
class TestMasked(TestCase):
    """Checks masked operators against their reference implementations."""

    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(masked_ops_with_references)
    def test_reference_masked(self, device, dtype, op):
        """Compare ``op`` with its ``reference_functions`` entry on each sample input.

        Only positions selected by ``torch._masked._output_mask`` are compared:
        every other position is replaced by zero in both results beforehand.
        """
        def zero_unmasked(keep, tensor):
            # Replace positions outside ``keep`` with a zero of tensor's dtype.
            return torch.where(keep, tensor, tensor.new_zeros([]))

        reference = reference_functions[op.name.rsplit('.', 1)[-1]]
        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            inp, args, kwargs = sample.input, sample.args, sample.kwargs
            result = op.op(inp, *args, **kwargs)
            ref_result = reference(inp, *args, **kwargs)
            keep = torch._masked._output_mask(op.op, inp, *args, **kwargs)
            self.assertEqual(zero_unmasked(keep, result),
                             zero_unmasked(keep, ref_result),
                             exact_device=False)
# Presumably expands the generic TestMasked class into per-device-type test
# classes placed in this module's namespace (hence globals()) — see
# instantiate_device_type_tests in common_device_type to confirm.
instantiate_device_type_tests(TestMasked, globals())