Add MaskedTensor support to *_like API (#128637)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128637
Approved by: https://github.com/cpuhrsch
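
For context, a minimal usage sketch of what this change enables (the example
values below are illustrative, not taken from the PR's tests):

    import torch
    from torch.masked import masked_tensor

    data = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    mask = torch.tensor([[True, False, True], [False, True, True]])
    mt = masked_tensor(data, mask)

    # *_like creation ops now dispatch through MaskedTensor and reuse
    # the input's mask for the result.
    ones = torch.ones_like(mt)    # ones in the data, same mask as `mt`
    zeros = torch.zeros_like(mt)  # zeros in the data, same mask as `mt`
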
diff --git a/docs/source/masked.rst b/docs/source/masked.rst
index 60dd67f..ddaf127 100644
--- a/docs/source/masked.rst
+++ b/docs/source/masked.rst
@@ -277,12 +277,18 @@
     chunk
     column_stack
     dsplit
+    empty_like
     flatten
+    full_like
     hsplit
     hstack
     kron
     meshgrid
     narrow
+    ones_like
+    rand_like
+    randint_like
+    randn_like
     ravel
     select
     split
@@ -290,6 +296,7 @@
     transpose
     vsplit
     vstack
+    zeros_like
     Tensor.expand
     Tensor.expand_as
     Tensor.reshape
@@ -301,6 +308,7 @@
 .. py:module:: torch.masked.maskedtensor.binary
 .. py:module:: torch.masked.maskedtensor.core
 .. py:module:: torch.masked.maskedtensor.creation
+.. py:module:: torch.masked.maskedtensor.like
 .. py:module:: torch.masked.maskedtensor.passthrough
 .. py:module:: torch.masked.maskedtensor.reductions
 .. py:module:: torch.masked.maskedtensor.unary
\ No newline at end of file
diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py
index bf42cf0..0817822 100644
--- a/test/test_maskedtensor.py
+++ b/test/test_maskedtensor.py
@@ -19,6 +19,7 @@
     binary_ufuncs,
     reduction_ops,
     unary_ufuncs,
+    ops_and_refs,
 )
 
 from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask
@@ -26,6 +27,7 @@
 from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES
 from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES
 from torch.masked.maskedtensor.reductions import REDUCE_NAMES
+from torch.masked.maskedtensor.like import LIKE_NAMES
 
 
 def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05):
@@ -812,9 +814,13 @@
 def is_reduction(op):
     return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"}
 
+def is_like(op):
+    return op.name in LIKE_NAMES
+
 mt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)]
 mt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)]
 mt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)]
+mt_like_funcs = [op for op in ops_and_refs if is_like(op)]
 
 MASKEDTENSOR_FLOAT_TYPES = {
     torch.float16,
@@ -822,6 +828,16 @@
     torch.float64,
 }
 
+MASKEDTENSOR_SUPPORTED_TYPES = [
+    *MASKEDTENSOR_FLOAT_TYPES,
+    torch.bool,
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.int64,
+]
+
+
 class TestOperators(TestCase):
     def _convert_mt_args(self, args, mask, layout):
         return [
@@ -967,6 +983,43 @@
 
         self._test_reduction_equality(device, dtype, op, layout)
 
+    @ops(mt_like_funcs, allowed_dtypes=MASKEDTENSOR_SUPPORTED_TYPES)  # type: ignore[arg-type]
+    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
+    def test_like(self, device, dtype, op, layout):
+        samples = op.sample_inputs(device, dtype, requires_grad=dtype in MASKEDTENSOR_FLOAT_TYPES)
+
+        for sample in samples:
+            input = sample.input
+            sample_args, sample_kwargs = sample.args, sample.kwargs
+            mask = _create_random_mask(input.shape, device)
+
+            if layout == torch.sparse_coo:
+                mask = mask.to_sparse_coo().coalesce()
+                input = input.sparse_mask(mask)
+            elif layout == torch.sparse_csr:
+                if input.ndim != 2 or mask.ndim != 2:
+                    continue
+                mask = mask.to_sparse_csr()
+                input = input.sparse_mask(mask)
+
+            mt = masked_tensor(input, mask)
+
+            try:
+                t_result = op(input, *sample_args, **sample_kwargs)
+            except NotImplementedError:
+                self.skipTest(f"{op.name} is not supported for {layout}")
+
+            mt_result = op(mt, *sample_args, **sample_kwargs)
+
+            if 'rand' not in op.name and op.name != 'empty_like':
+                _compare_mt_t(mt_result, t_result.to_dense())
+
+            self.assertEqual(mt_result.shape, t_result.shape)
+            self.assertEqual(mt_result.dtype, t_result.dtype)
+            self.assertEqual(mt_result.layout, t_result.layout)
+            self.assertEqual(mt_result.device, t_result.device)
+            self.assertEqual(mt_result.requires_grad, t_result.requires_grad)
+
 
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
diff --git a/torch/masked/maskedtensor/__init__.py b/torch/masked/maskedtensor/__init__.py
index e38e03c..9c908e9 100644
--- a/torch/masked/maskedtensor/__init__.py
+++ b/torch/masked/maskedtensor/__init__.py
@@ -3,6 +3,7 @@
 
 from .binary import _apply_native_binary, _is_native_binary
 from .core import is_masked_tensor, MaskedTensor
+from .like import _apply_like_fn, _is_like_fn
 from .passthrough import _apply_pass_through_fn, _is_pass_through_fn
 from .reductions import _apply_reduction, _is_reduction
 from .unary import _apply_native_unary, _is_native_unary
diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py
index 7344a6e..febb0a9 100644
--- a/torch/masked/maskedtensor/_ops_refs.py
+++ b/torch/masked/maskedtensor/_ops_refs.py
@@ -14,6 +14,7 @@
     is_masked_tensor,
     MaskedTensor,
 )
+from .like import _apply_like_fn, LIKE_FNS
 from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
 from .reductions import (
     _apply_reduction,
@@ -270,6 +271,11 @@
     return _apply_native_binary(func, *args, **kwargs)
 
 
+@register_dispatch_func(LIKE_FNS)
+def _general_like(func, *args, **kwargs):
+    return _apply_like_fn(func, *args, **kwargs)
+
+
 @register_dispatch_func([torch.ops.aten.stride])
 def stride(func, *args, **kwargs):
     return None
@@ -365,13 +371,6 @@
     return MaskedTensor(result_data, mask)
 
 
-@register_dispatch_func([torch.ops.aten.ones_like])
-def ones_like(func, *args, **kwargs):
-    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
-    result_data = func(_get_data(args[0]), **kwargs)
-    return MaskedTensor(result_data, _maybe_get_mask(args[0]))
-
-
 @register_dispatch_func([torch.ops.aten._softmax_backward_data])
 def _softmax_backward_data(func, *args, **kwargs):
     _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
diff --git a/torch/masked/maskedtensor/like.py b/torch/masked/maskedtensor/like.py
new file mode 100644
index 0000000..3a410be
--- /dev/null
+++ b/torch/masked/maskedtensor/like.py
@@ -0,0 +1,31 @@
+# mypy: allow-untyped-defs
+"""
+These are the creation ops in the ``*_like`` API that need to support
+``MaskedTensor``. The wrapper applies the function to the masked data and then
+wraps the result in a masked tensor using the mask from the input tensor.
+"""
+
+import torch
+from torch.masked.maskedtensor.core import _get_data, _maybe_get_mask, MaskedTensor
+
+
+LIKE_NAMES = [
+    "empty_like",
+    "full_like",
+    "ones_like",
+    "rand_like",
+    "randint_like",
+    "randn_like",
+    "zeros_like",
+]
+
+LIKE_FNS = [getattr(torch.ops.aten, name) for name in LIKE_NAMES]
+
+
+def _is_like_fn(fn):
+    return fn in LIKE_FNS
+
+
+def _apply_like_fn(func, *args, **kwargs):
+    result_data = func(_get_data(args[0]), *args[1:], **kwargs)
+    return MaskedTensor(result_data, _maybe_get_mask(args[0]))