Add MaskedTensor support to *_like API (#128637)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128637
Approved by: https://github.com/cpuhrsch
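
This registers the ``*_like`` creation ops (empty_like, full_like, ones_like,
rand_like, randint_like, randn_like, zeros_like) for MaskedTensor: each op is
applied to the masked data and the result is rewrapped with the input's mask.
A minimal sketch of the resulting behavior (illustrative only, using the
public torch.masked constructor):

    import torch
    from torch.masked import masked_tensor

    data = torch.arange(4.0)
    mask = torch.tensor([True, False, True, False])
    mt = masked_tensor(data, mask)

    # zeros_like now returns a MaskedTensor with zeroed data and mt's mask
    zl = torch.zeros_like(mt)

    # rand_like and empty_like rewrap random/uninitialized data the same way
    rl = torch.rand_like(mt)
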
diff --git a/docs/source/masked.rst b/docs/source/masked.rst
index 60dd67f..ddaf127 100644
--- a/docs/source/masked.rst
+++ b/docs/source/masked.rst
@@ -277,12 +277,18 @@
chunk
column_stack
dsplit
+ empty_like
flatten
+ full_like
hsplit
hstack
kron
meshgrid
narrow
+ ones_like
+ rand_like
+ randint_like
+ randn_like
ravel
select
split
@@ -290,6 +296,7 @@
transpose
vsplit
vstack
+ zeros_like
Tensor.expand
Tensor.expand_as
Tensor.reshape
@@ -301,6 +308,7 @@
.. py:module:: torch.masked.maskedtensor.binary
.. py:module:: torch.masked.maskedtensor.core
.. py:module:: torch.masked.maskedtensor.creation
+.. py:module:: torch.masked.maskedtensor.like
.. py:module:: torch.masked.maskedtensor.passthrough
.. py:module:: torch.masked.maskedtensor.reductions
.. py:module:: torch.masked.maskedtensor.unary
\ No newline at end of file
diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py
index bf42cf0..0817822 100644
--- a/test/test_maskedtensor.py
+++ b/test/test_maskedtensor.py
@@ -19,6 +19,7 @@
binary_ufuncs,
reduction_ops,
unary_ufuncs,
+ ops_and_refs,
)
from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask
@@ -26,6 +27,7 @@
from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES
from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES
from torch.masked.maskedtensor.reductions import REDUCE_NAMES
+from torch.masked.maskedtensor.like import LIKE_NAMES
def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05):
@@ -812,9 +814,13 @@
def is_reduction(op):
return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"}
+def is_like(op):
+ return op.name in LIKE_NAMES
+
mt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)]
mt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)]
mt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)]
+mt_like_funcs = [op for op in ops_and_refs if is_like(op)]
MASKEDTENSOR_FLOAT_TYPES = {
torch.float16,
@@ -822,6 +828,16 @@
torch.float64,
}
+MASKEDTENSOR_SUPPORTED_TYPES = [
+ *MASKEDTENSOR_FLOAT_TYPES,
+ torch.bool,
+ torch.int8,
+ torch.int16,
+ torch.int32,
+ torch.int64,
+]
+
+
class TestOperators(TestCase):
def _convert_mt_args(self, args, mask, layout):
return [
@@ -967,6 +983,46 @@
self._test_reduction_equality(device, dtype, op, layout)
+ @ops(mt_like_funcs, allowed_dtypes=MASKEDTENSOR_SUPPORTED_TYPES) # type: ignore[arg-type]
+ @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
+ def test_like(self, device, dtype, op, layout):
+ samples = op.sample_inputs(device, dtype, requires_grad=dtype in MASKEDTENSOR_FLOAT_TYPES)
+
+ for sample in samples:
+ input = sample.input
+ sample_args, sample_kwargs = sample.args, sample.kwargs
+ mask = _create_random_mask(input.shape, device)
+
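+ # Match the input/mask to the layout under test; CSR needs 2-D inputs.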
+ if layout == torch.sparse_coo:
+ mask = mask.to_sparse_coo().coalesce()
+ input = input.sparse_mask(mask)
+ elif layout == torch.sparse_csr:
+ if input.ndim != 2 or mask.ndim != 2:
+ continue
+ mask = mask.to_sparse_csr()
+ input = input.sparse_mask(mask)
+
+ mt = masked_tensor(input, mask)
+
+ try:
+ t_result = op(input, *sample_args, **sample_kwargs)
+ except NotImplementedError:
+ self.skipTest(f"{op.name} is not supported for {layout}")
+
+ mt_result = op(mt, *sample_args, **sample_kwargs)
+
+ # rand* and empty_like produce random or uninitialized values, so only
+ # the deterministic ops get a value comparison; metadata is checked below.
+ if "rand" not in op.name and op.name != "empty_like":
+ _compare_mt_t(mt_result, t_result.to_dense())
+
+ self.assertEqual(mt_result.shape, t_result.shape)
+ self.assertEqual(mt_result.dtype, t_result.dtype)
+ self.assertEqual(mt_result.layout, t_result.layout)
+ self.assertEqual(mt_result.device, t_result.device)
+ self.assertEqual(mt_result.requires_grad, t_result.requires_grad)
+
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
diff --git a/torch/masked/maskedtensor/__init__.py b/torch/masked/maskedtensor/__init__.py
index e38e03c..9c908e9 100644
--- a/torch/masked/maskedtensor/__init__.py
+++ b/torch/masked/maskedtensor/__init__.py
@@ -3,6 +3,7 @@
from .binary import _apply_native_binary, _is_native_binary
from .core import is_masked_tensor, MaskedTensor
+from .like import _apply_like_fn, _is_like_fn
from .passthrough import _apply_pass_through_fn, _is_pass_through_fn
from .reductions import _apply_reduction, _is_reduction
from .unary import _apply_native_unary, _is_native_unary
diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py
index 7344a6e..febb0a9 100644
--- a/torch/masked/maskedtensor/_ops_refs.py
+++ b/torch/masked/maskedtensor/_ops_refs.py
@@ -14,6 +14,7 @@
is_masked_tensor,
MaskedTensor,
)
+from .like import _apply_like_fn, LIKE_FNS
from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
from .reductions import (
_apply_reduction,
@@ -270,6 +271,11 @@
return _apply_native_binary(func, *args, **kwargs)
+@register_dispatch_func(LIKE_FNS)
+def _general_like(func, *args, **kwargs):
+ return _apply_like_fn(func, *args, **kwargs)
+
+
@register_dispatch_func([torch.ops.aten.stride])
def stride(func, *args, **kwargs):
return None
@@ -365,13 +371,6 @@
return MaskedTensor(result_data, mask)
-@register_dispatch_func([torch.ops.aten.ones_like])
-def ones_like(func, *args, **kwargs):
- _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
- result_data = func(_get_data(args[0]), **kwargs)
- return MaskedTensor(result_data, _maybe_get_mask(args[0]))
-
-
@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
diff --git a/torch/masked/maskedtensor/like.py b/torch/masked/maskedtensor/like.py
new file mode 100644
index 0000000..3a410be
--- /dev/null
+++ b/torch/masked/maskedtensor/like.py
@@ -0,0 +1,41 @@
+# mypy: allow-untyped-defs
+"""
+These are the creation ops using the ``*_like`` API that need to support
+``MaskedTensor``. The wrapper applies the function to the masked data and then
+wraps the result in a masked tensor using the mask of the given tensor (see
+the illustrative sketch at the end of this module).
+"""
+
+import torch
+from torch.masked.maskedtensor.core import _get_data, _maybe_get_mask, MaskedTensor
+
+
+LIKE_NAMES = [
+ "empty_like",
+ "full_like",
+ "ones_like",
+ "rand_like",
+ "randint_like",
+ "randn_like",
+ "zeros_like",
+]
+
+LIKE_FNS = [getattr(torch.ops.aten, name) for name in LIKE_NAMES]
+
+
+def _is_like_fn(fn):
+ return fn in LIKE_FNS
+
+
+def _apply_like_fn(func, *args, **kwargs):
+ result_data = func(_get_data(args[0]), *args[1:], **kwargs)
+ return MaskedTensor(result_data, _maybe_get_mask(args[0]))
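+
+
+# Illustrative sketch (hypothetical session; ``MaskedTensor`` is imported above
+# and dispatch to ``_apply_like_fn`` is registered in _ops_refs.py):
+#
+#   >>> data = torch.arange(4.0)
+#   >>> mask = torch.tensor([True, False, True, False])
+#   >>> mt = MaskedTensor(data, mask)
+#   >>> torch.zeros_like(mt)  # MaskedTensor: zeroed data, same mask as mt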