[dtensor] PART 7: move remaining DTensor tests to core distributed (#88179)

This PR moves the remaining DTensor tests, i.e. the tensor_ops and op db tests, to core distributed.

part of https://github.com/pytorch/pytorch/issues/88838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88179
Approved by: https://github.com/aazzolini
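
As a quick orientation for reviewers, below is a minimal sketch of the cross-ref pattern the new op db test uses (the helper name `crossref_one_op` is made up for illustration and is not part of this diff): run an op on plain tensors, then on every sharding combination `DTensorConverter` produces, and compare the replicated DTensor output against the eager result.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter


def crossref_one_op(mesh: DeviceMesh, op, *args, **kwargs):
    # eager baseline on plain tensors
    expected = op(*args, **kwargs)
    # DTensorConverter yields one (args, kwargs) pair per sharding combination
    for d_args, d_kwargs in DTensorConverter(mesh, args, kwargs):
        d_out = op(*d_args, **d_kwargs)
        # all-gather the sharded result before comparing with the baseline
        local = d_out.redistribute(mesh, [Replicate()] * mesh.ndim).to_local()
        torch.testing.assert_close(local, expected)
```
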
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
new file mode 100644
index 0000000..22ae580
--- /dev/null
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -0,0 +1,704 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+import sys
+import unittest
+import warnings
+
+from torch.overrides import resolve_name
+from torch.utils._pytree import tree_flatten, tree_map
+from torch.testing._internal.common_utils import (
+    suppress_warnings,
+    TEST_WITH_ASAN,
+    run_tests,
+)
+import torch.distributed as dist
+from torch.testing._internal.common_device_type import (
+    ops,
+    instantiate_device_type_tests,
+)
+import torch.testing._internal.common_methods_invocations as common_ops
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+
+from torch.distributed._tensor import DTensor, DeviceMesh, Replicate
+from torch.testing._internal.distributed._tensor.dtensor_lagging_op_db import dtensor_lagging_op_db
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    TEST_SKIPS,
+    DTensorConverter,
+    DEVICE_TYPE,
+)
+
+# rewrite common size variables to something that can be sharded evenly
+# we can enable uneven shards later, but that needs more adjustment to the
+# sample inputs (i.e. view/reshape would need to adjust shape sizes as well)
+common_ops.L = 24
+common_ops.M = 12
+common_ops.S = 4
+common_ops.XS = 2
+
+
+def assert_ref_dtensor_equal(test_case, dtensor_rs, rs):
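+    # Flatten both result pytrees and check that every DTensor result matches the
+    # reference tensor in shape, requires_grad, and (via to_local()) values.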
+    flat_dtensor_rs, _ = tree_flatten(dtensor_rs)
+    flat_rs, _ = tree_flatten(rs)
+    test_case.assertEqual(len(flat_dtensor_rs), len(flat_rs))
+    for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):
+
+        if not isinstance(r, torch.Tensor):
+            continue
+
+        test_case.assertIsInstance(dtensor_r, torch.Tensor)
+        test_case.assertEqual(
+            dtensor_r.shape,
+            r.shape,
+            f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}",
+        )
+        test_case.assertEqual(
+            dtensor_r.requires_grad,
+            r.requires_grad,
+            "op result requires_grad mismatch!"
+            f"original requires_grad: {r.requires_grad}, "
+            f"dtensor requires_grad: {dtensor_r.requires_grad}",
+        )
+
+        test_case.assertEqual(dtensor_r.to_local(), r)
+
+
+# Copied from functorch
+def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
+    return (op_name, variant_name, device_type, dtypes, True)
+
+
+def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
+    return (op_name, variant_name, device_type, dtypes, False)
+
+
+def skipOps(test_case_name, base_test_name, to_skip):
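+    # For every (op, variant) entry in to_skip, attach an expectedFailure (xfail)
+    # or unittest.skip decorator to the matching OpInfos in the lagging op db,
+    # scoped to the given test case / test name.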
+    all_opinfos = dtensor_lagging_op_db
+    for xfail in to_skip:
+        op_name, variant_name, device_type, dtypes, expected_failure = xfail
+        matching_opinfos = [
+            o
+            for o in all_opinfos
+            if o.name == op_name and o.variant_test_name == variant_name
+        ]
+        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
+        for opinfo in matching_opinfos:
+            decorators = list(opinfo.decorators)
+            if expected_failure:
+                decorator = DecorateInfo(
+                    unittest.expectedFailure,
+                    test_case_name,
+                    base_test_name,
+                    device_type=device_type,
+                    dtypes=dtypes,
+                )
+                decorators.append(decorator)
+            else:
+                decorator = DecorateInfo(
+                    unittest.skip("Skipped!"),
+                    test_case_name,
+                    base_test_name,
+                    device_type=device_type,
+                    dtypes=dtypes,
+                )
+                decorators.append(decorator)
+            opinfo.decorators = tuple(decorators)
+
+    # This decorator doesn't modify fn in any way
+    def wrapped(fn):
+        return fn
+
+    return wrapped
+
+
+# To re-generate this failing-ops list, turn on dry_run in the call to
+# check_dtensor_func(self, test, op, dry_run=True) below, then run something
+# like: python test/distributed/_tensor/test_dtensor_ops.py > failed.expect
+dtensor_fails = {
+    # these sometimes pass and sometimes fail
+    # we need to remove many of them from this list once the ops
+    # gain full support for varying sharding specs
+    xfail("__getitem__"),
+    xfail("__rsub__"),
+    xfail("masked.amax"),
+    xfail("masked.amin"),
+    xfail("masked.argmax"),
+    xfail("masked.argmin"),
+    xfail("masked.cumprod"),
+    xfail("masked.cumsum"),
+    xfail("masked.log_softmax"),
+    xfail("masked.logaddexp"),
+    xfail("masked.logsumexp"),
+    xfail("masked.median"),
+    xfail("masked.norm"),
+    xfail("masked.prod"),
+    xfail("masked.softmin"),
+    xfail("masked.softmax"),
+    xfail("masked.sum"),
+    xfail("addbmm"),
+    xfail("addmv"),
+    xfail("addr"),
+    xfail("all"),
+    xfail("allclose"),
+    xfail("amax"),
+    xfail("amin"),
+    xfail("aminmax"),
+    xfail("any"),
+    xfail("arange"),
+    xfail("argmax"),
+    xfail("argmin"),
+    xfail("argsort"),
+    xfail("as_strided"),
+    xfail("as_strided_scatter"),
+    xfail("baddbmm"),
+    xfail("bernoulli"),
+    xfail("block_diag"),
+    xfail("broadcast_shapes"),
+    xfail("cat"),
+    xfail("cartesian_prod"),
+    xfail("cdist"),
+    xfail("cholesky"),
+    xfail("cholesky_inverse"),
+    xfail("cholesky_solve"),
+    xfail("chunk"),
+    xfail("clamp"),
+    xfail("clamp_max"),
+    xfail("clamp_min"),
+    xfail("column_stack"),
+    xfail("combinations"),
+    xfail("complex"),
+    xfail("constant_pad_nd"),
+    xfail("copysign"),
+    xfail("corrcoef"),
+    xfail("count_nonzero"),
+    xfail("cov"),
+    xfail("cross"),
+    xfail("cummax"),
+    xfail("cummin"),
+    xfail("cumsum"),
+    xfail("cumulative_trapezoid"),
+    xfail("diag"),
+    xfail("diag_embed"),
+    xfail("diagflat"),
+    xfail("diagonal"),
+    xfail("diagonal_copy"),
+    xfail("diagonal_scatter"),
+    xfail("diff"),
+    xfail("dist"),
+    xfail("dot"),
+    xfail("dstack"),
+    xfail("einsum"),
+    xfail("empty"),
+    xfail("empty_like"),
+    xfail("eq"),
+    xfail("eye"),
+    xfail("fft.fft2"),
+    xfail("fft.fft"),
+    xfail("fft.fftn"),
+    xfail("fft.fftshift"),
+    xfail("fft.ifft2"),
+    xfail("fft.ifft"),
+    xfail("fft.ifftshift"),
+    xfail("fft.ihfft2"),
+    xfail("fft.ihfft"),
+    xfail("fft.ihfftn"),
+    xfail("fft.irfft2"),
+    xfail("fft.irfftn"),
+    xfail("fft.rfft2"),
+    xfail("fft.rfft"),
+    xfail("fft.rfftn"),
+    xfail("flip"),
+    xfail("fliplr"),
+    xfail("flipud"),
+    xfail("floor_divide"),
+    xfail("fmax"),
+    xfail("fmin"),
+    xfail("frexp"),
+    xfail("full"),
+    xfail("gather"),
+    xfail("geqrf"),
+    xfail("gradient"),
+    xfail("heaviside"),
+    xfail("histc"),
+    xfail("histogram"),
+    xfail("histogramdd"),
+    xfail("hstack"),
+    xfail("index_add"),
+    xfail("index_copy"),
+    xfail("index_fill"),
+    xfail("index_put"),
+    xfail("index_reduce"),
+    xfail("index_select"),
+    xfail("isfinite"),
+    xfail("isin"),
+    xfail("isinf"),
+    xfail("isnan"),
+    xfail("isneginf"),
+    xfail("isposinf"),
+    xfail("kthvalue"),
+    xfail("linalg.cholesky"),
+    xfail("linalg.cholesky_ex"),
+    xfail("linalg.cond"),
+    xfail("linalg.cross"),
+    xfail("linalg.det"),
+    xfail("linalg.det", "singular"),
+    xfail("linalg.eig"),
+    xfail("linalg.eigh"),
+    xfail("linalg.eigvals"),
+    xfail("linalg.eigvalsh"),
+    xfail("linalg.householder_product"),
+    xfail("linalg.inv"),
+    xfail("linalg.inv_ex"),
+    xfail("linalg.ldl_factor"),
+    xfail("linalg.ldl_factor_ex"),
+    xfail("linalg.ldl_solve"),
+    xfail("linalg.lstsq"),
+    xfail("linalg.lstsq", "grad_oriented"),
+    xfail("linalg.lu"),
+    xfail("linalg.lu_factor"),
+    xfail("linalg.lu_factor_ex"),
+    xfail("linalg.lu_solve"),
+    xfail("linalg.matrix_norm"),
+    xfail("linalg.matrix_power"),
+    xfail("linalg.matrix_rank"),
+    xfail("linalg.matrix_rank", "hermitian"),
+    xfail("linalg.multi_dot"),
+    xfail("linalg.norm"),
+    xfail("linalg.norm", "subgradients_at_zero"),
+    xfail("linalg.pinv"),
+    xfail("linalg.pinv", "hermitian"),
+    xfail("linalg.qr"),
+    xfail("linalg.slogdet"),
+    xfail("linalg.solve"),
+    xfail("linalg.solve_ex"),
+    xfail("linalg.solve_triangular"),
+    xfail("linalg.svd"),
+    xfail("linalg.svdvals"),
+    xfail("linalg.tensorinv"),
+    xfail("linalg.tensorsolve"),
+    xfail("linalg.vander"),
+    xfail("linalg.vecdot"),
+    xfail("linalg.vector_norm"),
+    xfail("linspace"),
+    xfail("log_softmax"),
+    xfail("log_softmax", "with_dtype"),
+    xfail("logcumsumexp"),
+    xfail("logdet"),
+    xfail("logical_not"),
+    xfail("logspace"),
+    xfail("logsumexp"),
+    xfail("lt"),
+    xfail("lu"),
+    xfail("lu_solve"),
+    xfail("lu_unpack"),
+    xfail("masked_fill"),
+    xfail("masked_scatter"),
+    xfail("masked_select"),
+    xfail("matrix_exp"),
+    xfail("max", "binary"),
+    xfail("max", "reduction_no_dim"),
+    xfail("max", "reduction_with_dim"),
+    xfail("maximum"),
+    xfail("median"),
+    xfail("min", "binary"),
+    xfail("min", "reduction_no_dim"),
+    xfail("min", "reduction_with_dim"),
+    xfail("minimum"),
+    xfail("mode"),
+    xfail("msort"),
+    xfail("multinomial"),
+    xfail("mv"),
+    xfail("max_pool2d_with_indices_backward", ""),
+    xfail("nanmean"),
+    xfail("nanmedian"),
+    xfail("nanquantile"),
+    xfail("nansum"),
+    xfail("native_batch_norm"),
+    xfail("native_layer_norm"),
+    xfail("narrow_copy"),
+    xfail("ne"),
+    xfail("new_empty"),
+    xfail("new_empty_strided"),
+    xfail("transpose"),
+    xfail("nn.functional.adaptive_avg_pool1d"),
+    xfail("nn.functional.adaptive_avg_pool2d"),
+    xfail("nn.functional.adaptive_avg_pool3d"),
+    xfail("nn.functional.adaptive_max_pool1d"),
+    xfail("nn.functional.adaptive_max_pool2d"),
+    xfail("nn.functional.adaptive_max_pool3d"),
+    xfail("nn.functional.alpha_dropout"),
+    xfail("nn.functional.avg_pool1d"),
+    xfail("nn.functional.avg_pool2d"),
+    xfail("nn.functional.avg_pool3d"),
+    xfail("nn.functional.batch_norm"),
+    xfail("nn.functional.batch_norm", "without_cudnn"),
+    xfail("nn.functional.bilinear"),
+    xfail("nn.functional.binary_cross_entropy"),
+    xfail("nn.functional.binary_cross_entropy_with_logits"),
+    xfail("nn.functional.celu"),
+    xfail("nn.functional.conv1d"),
+    xfail("nn.functional.conv2d"),
+    xfail("nn.functional.conv_transpose1d"),
+    xfail("nn.functional.conv_transpose2d"),
+    xfail("nn.functional.conv_transpose3d"),
+    xfail("nn.functional.cosine_similarity"),
+    xfail("nn.functional.cross_entropy"),
+    xfail("nn.functional.ctc_loss"),
+    xfail("nn.functional.dropout"),
+    xfail("nn.functional.dropout2d"),
+    xfail("nn.functional.dropout3d"),
+    xfail("nn.functional.elu"),
+    xfail("nn.functional.fractional_max_pool2d"),
+    xfail("nn.functional.fractional_max_pool3d"),
+    xfail("nn.functional.gaussian_nll_loss"),
+    xfail("nn.functional.glu"),
+    xfail("nn.functional.grid_sample"),
+    xfail("nn.functional.group_norm"),
+    xfail("nn.functional.hardshrink"),
+    xfail("nn.functional.hardsigmoid"),
+    xfail("nn.functional.hardswish"),
+    xfail("nn.functional.hardtanh"),
+    xfail("nn.functional.huber_loss"),
+    xfail("nn.functional.instance_norm"),
+    xfail("nn.functional.interpolate", "area"),
+    xfail("nn.functional.interpolate", "bicubic"),
+    xfail("nn.functional.interpolate", "bilinear"),
+    xfail("nn.functional.interpolate", "linear"),
+    xfail("nn.functional.interpolate", "nearest"),
+    xfail("nn.functional.interpolate", "trilinear"),
+    xfail("nn.functional.layer_norm"),
+    xfail("nn.functional.leaky_relu"),
+    xfail("nn.functional.linear"),
+    xfail("nn.functional.local_response_norm"),
+    xfail("nn.functional.logsigmoid"),
+    xfail("nn.functional.margin_ranking_loss"),
+    xfail("nn.functional.max_pool1d"),
+    xfail("nn.functional.max_pool2d"),
+    xfail("nn.functional.max_pool3d"),
+    xfail("nn.functional.max_unpool1d"),
+    xfail("nn.functional.max_unpool1d", "grad"),
+    xfail("nn.functional.max_unpool2d"),
+    xfail("nn.functional.max_unpool2d", "grad"),
+    xfail("nn.functional.max_unpool3d"),
+    xfail("nn.functional.max_unpool3d", "grad"),
+    xfail("nn.functional.mish"),
+    xfail("nn.functional.mse_loss"),
+    xfail("nn.functional.multi_margin_loss"),
+    xfail("nn.functional.multilabel_margin_loss"),
+    xfail("nn.functional.multilabel_soft_margin_loss"),
+    xfail("nn.functional.nll_loss"),
+    xfail("nn.functional.normalize"),
+    xfail("nn.functional.pad", "circular"),
+    xfail("nn.functional.pad", "constant"),
+    xfail("nn.functional.pad", "reflect"),
+    xfail("nn.functional.pad", "replicate"),
+    xfail("nn.functional.pairwise_distance"),
+    xfail("nn.functional.pdist"),
+    xfail("nn.functional.pixel_shuffle"),
+    xfail("nn.functional.pixel_unshuffle"),
+    xfail("nn.functional.poisson_nll_loss"),
+    xfail("nn.functional.prelu"),
+    xfail("nn.functional.relu6"),
+    xfail("nn.functional.rrelu"),
+    xfail("nn.functional.selu"),
+    xfail("nn.functional.silu"),
+    xfail("nn.functional.smooth_l1_loss"),
+    xfail("nn.functional.soft_margin_loss"),
+    xfail("nn.functional.softplus"),
+    xfail("nn.functional.softshrink"),
+    xfail("nn.functional.threshold"),
+    xfail("nn.functional.triplet_margin_loss"),
+    xfail("nn.functional.triplet_margin_with_distance_loss"),
+    xfail("nn.functional.unfold"),
+    xfail("nn.functional.upsample_bilinear"),
+    xfail("nn.functional.upsample_nearest"),
+    xfail("nonzero"),
+    xfail("norm"),
+    xfail("norm", "fro"),
+    xfail("norm", "inf"),
+    xfail("norm", "nuc"),
+    xfail("normal"),
+    xfail("normal", "number_mean"),
+    xfail("ormqr"),
+    xfail("ones"),
+    xfail("pca_lowrank"),
+    xfail("pinverse"),
+    xfail("polar"),
+    xfail("put"),
+    xfail("qr"),
+    xfail("quantile"),
+    xfail("rad2deg"),
+    xfail("rand_like"),
+    xfail("randint_like"),
+    xfail("randint"),
+    xfail("randn"),
+    xfail("randn_like"),
+    xfail("renorm"),
+    xfail("repeat_interleave"),
+    xfail("resize_"),
+    xfail("resize_as_"),
+    xfail("roll"),
+    xfail("rot90"),
+    xfail("rsub"),
+    xfail("scalar_tensor"),
+    xfail("scatter_add"),
+    xfail("scatter"),
+    xfail("scatter_reduce", "amax"),
+    xfail("scatter_reduce", "amin"),
+    xfail("scatter_reduce", "mean"),
+    xfail("scatter_reduce", "prod"),
+    xfail("scatter_reduce", "sum"),
+    xfail("searchsorted"),
+    xfail("select"),
+    xfail("select_scatter"),
+    xfail("signbit"),
+    xfail("sort"),
+    xfail("sparse.sampled_addmm"),
+    xfail("special.airy_ai"),
+    xfail("special.bessel_j0"),
+    xfail("special.bessel_j1"),
+    xfail("special.bessel_y0"),
+    xfail("special.bessel_y1"),
+    xfail("special.chebyshev_polynomial_t"),
+    xfail("special.chebyshev_polynomial_u"),
+    xfail("special.entr"),
+    xfail("special.erfcx"),
+    xfail("special.hermite_polynomial_h"),
+    xfail("special.hermite_polynomial_he"),
+    xfail("special.i0e"),
+    xfail("special.i1"),
+    xfail("special.i1e"),
+    xfail("special.laguerre_polynomial_l"),
+    xfail("special.log_ndtr"),
+    xfail("special.modified_bessel_i0"),
+    xfail("special.modified_bessel_i1"),
+    xfail("special.modified_bessel_k0"),
+    xfail("special.modified_bessel_k1"),
+    xfail("special.ndtri"),
+    xfail("special.scaled_modified_bessel_k0"),
+    xfail("special.scaled_modified_bessel_k1"),
+    xfail("special.spherical_bessel_j0"),
+    xfail("special.xlog1py"),
+    xfail("special.zeta"),
+    xfail("split"),
+    xfail("split", "list_args"),
+    xfail("split_with_sizes"),
+    xfail("signal.windows.cosine"),
+    xfail("signal.windows.exponential"),
+    xfail("signal.windows.gaussian"),
+    xfail("signal.windows.kaiser"),
+    xfail("squeeze"),
+    xfail("stack"),
+    xfail("std"),
+    xfail("std_mean"),
+    xfail("stft"),
+    xfail("svd"),
+    xfail("svd_lowrank"),
+    xfail("symeig"),
+    xfail("t"),
+    xfail("take_along_dim"),
+    xfail("take"),
+    xfail("tensor_split"),
+    xfail("to_sparse"),
+    xfail("topk"),
+    xfail("trace"),
+    xfail("trapezoid"),
+    xfail("trapz"),
+    xfail("triangular_solve"),
+    xfail("tril"),
+    xfail("triu"),
+    xfail("unbind"),
+    xfail("unfold"),
+    xfail("unfold_copy"),
+    xfail("uniform"),
+    xfail("unflatten"),
+    xfail("unique_consecutive"),
+    xfail("unique"),
+    xfail("var_mean"),
+    xfail("vdot"),
+    xfail("view_as_complex"),
+    xfail("vstack"),
+    xfail("zeros"),
+    # ops in this list might fail even without dtensor involvement,
+    # because we rescale the op db common test size factors (i.e. L, M, S)
+    # and that can break the original ops' input generation.
+    # We skip them for now but should enable them later.
+    # TODO: need to clean up this list and remove all cases
+    skip("argwhere"),
+    skip("cumprod"),
+    skip("__rmatmul__"),
+    skip("meshgrid", "list_of_tensors"),
+    skip("meshgrid", "variadic_tensors"),
+    skip("nn.functional._scaled_dot_product_attention"),
+    skip("nn.functional.softmin"),
+    skip("nn.functional.embedding"),
+    skip("nn.functional.embedding_bag"),
+    skip("nn.functional.feature_alpha_dropout", "with_train"),
+    skip("nn.functional.feature_alpha_dropout", "without_train"),
+    skip("nn.functional.hinge_embedding_loss"),
+    skip("nn.functional.cosine_embedding_loss"),
+    skip("fft.hfft"),
+    skip("fft.hfft2"),
+    skip("fft.hfft2"),
+    skip("fft.hfftn"),
+    skip("fft.ifftn"),
+    skip("fft.irfft"),
+    skip("istft"),
+    skip("isclose"),
+    skip("isreal"),
+    skip("matmul"),
+    skip("masked.mean"),
+    skip("masked.var"),
+    skip("masked.std"),
+    skip("masked.normalize"),
+    skip("prod"),
+    skip("segment_reduce", "lengths"),
+    skip("segment_reduce", "offsets"),
+}
+
+
+# A list of ops that are currently failing the BW (backward) pass
+skip_bw = [
+    None,  # corresponds to the transpose ops 'H' and 'T'
+    "torch.bucketize",
+    "torch.conj_physical",
+    "torch.eq",
+    "torch.isfinite",
+    "torch.isnan",
+]
+
+
+def run_dtensor_crossref(test_case, func, args, kwargs):
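+    # Cross-ref test: run the op on regular tensors first, then on every DTensor
+    # sharding combination produced by DTensorConverter; replicate the DTensor
+    # results, attempt a backward pass (unless the op is in skip_bw), and compare
+    # against the regular tensor output.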
+    to_dtensor = DTensorConverter(test_case.mesh, args, kwargs)
+
+    # TODO: also handle cases where func raises an exception
+    rs = func(*args, **kwargs)
+
+    def to_replicate(e: object) -> object:
+        return (
+            e.redistribute(test_case.mesh, test_case.mesh.ndim * [Replicate()])
+            if isinstance(e, DTensor)
+            else e
+        )
+
+    try:
+        # Suppress warnings, this doesn't matter for test_meta.py
+        # but it does matter if you want to use this decorator
+        # for cross-ref testing, as some tests may be looking at
+        # errors
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            # for every combination of sharding choices, we test if it works
+            for dtensor_args, dtensor_kwargs in to_dtensor:
+                # Only attempt if we managed to convert all tensors to DTensor
+                # (if any of them failed, we're in a mixed tensor situation and
+                # this is not allowed in DTensor)
+                if to_dtensor.successful():
+                    # Handle special cases first, if there are any
+                    dtensor_rs = func(*dtensor_args, **dtensor_kwargs)
+
+                    # we need to skip tests containing tensors of zero elements for now.
+                    # see issue: https://github.com/pytorch/tau/issues/470
+                    # TODO remove this once issue above fixed.
+                    flat_args, _ = tree_flatten(dtensor_rs)
+                    if any(
+                        isinstance(e, torch.Tensor) and e.numel() == 0
+                        for e in flat_args
+                    ):
+                        continue
+
+                    # redistribute/all_gather the results to compare with normal output
+                    dtensor_rs = tree_map(to_replicate, dtensor_rs)
+                    try:
+                        if resolve_name(func) not in skip_bw:
+                            if isinstance(dtensor_rs, DTensor):
+                                dtensor_rs.to_local().sum().backward()
+                            elif isinstance(dtensor_rs, tuple):
+                                dtensor_rs[0].to_local().sum().backward()
+
+                    except Exception as e:
+                        # TODO(anj): Remove this guard exception after gaining more confidence.
+                        if torch.distributed.get_rank() == 0:
+                            print(
+                                f"failed to run BW: {resolve_name(func)}, {func}, {str(e)}"
+                            )
+                    assert_ref_dtensor_equal(test_case, dtensor_rs, rs)
+                else:
+                    raise RuntimeError(
+                        f"failed to convert args to DTensor; "
+                        f"originally (*{args}, **{kwargs})"
+                    )
+    except Exception as e:
+        raise RuntimeError(
+            f"failed to run: {resolve_name(func)}, with (*{args}, **{kwargs})"
+        ) from e
+
+    return rs
+
+
+def check_dtensor_func(test_case, test_func, opinfo, dry_run=False):
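+    # Run the test body; on failure either re-raise (default), or in dry_run mode
+    # print an xfail(...) entry for the op so the failing-ops list above can be
+    # regenerated.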
+    try:
+        test_func()
+    except Exception:
+        test_case.destroy_pg()
+        if not dry_run:
+            raise
+        if dist.get_rank() == 0:
+            if opinfo.variant_test_name:
+                print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
+            else:
+                print(f"xfail('{opinfo.name}'),")
+    else:
+        test_case.destroy_pg()
+
+
+class TestDTensorOps(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    # only allow float dtype for now, we can relax this constraint
+    # when we feel it necessary later (e.g. when adding quantization support).
+    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
+    @suppress_warnings
+    @ops(dtensor_lagging_op_db, allowed_dtypes=(torch.float,))
+    @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
+    def test_dtensor_op_db(self, dtype, op):
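+        # use nccl when running on cuda and gloo otherwise; skip the test if
+        # there are fewer GPUs than the world size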
+        pg_backend = "nccl" if DEVICE_TYPE == "cuda" else "gloo"
+        if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        self.init_pg(backend=pg_backend)
+        self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))
+
+        # test each op with dist tensor inputs and normal inputs
+        def test():
+            samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True)
+            for sample_input in samples:
+                args = [sample_input.input] + list(sample_input.args)
+                kwargs = sample_input.kwargs
+
+                run_dtensor_crossref(self, op.op, args, kwargs)
+                # we need to figure out a way to test the out variant; out variant testing
+                # is tricky, as we need to pre-allocate the dtensor out, and some ops rely
+                # on sharding placements being known in advance (i.e. mm.out)
+                # if isinstance(expected, torch.Tensor) and op.supports_out:
+                #     func(*args, **kwargs, out=expected)
+
+        check_dtensor_func(self, test, op)
+
+
+# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
+instantiate_device_type_tests(
+    TestDTensorOps, globals(), only_for=(DEVICE_TYPE,)
+)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py
new file mode 100644
index 0000000..1ba3f6d
--- /dev/null
+++ b/test/distributed/_tensor/test_tensor_ops.py
@@ -0,0 +1,365 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorConverter,
+    DTensorTestBase,
+    with_comms,
+)
+from torch.distributed._tensor import distribute_tensor, DeviceMesh, DTensor
+from torch.distributed._tensor.placement_types import Shard, Replicate, _Partial
+
+
+class DistTensorOpsTest(DTensorTestBase):
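+    # Hand-written DTensor tests for basic tensor ops (detach, clone, fill,
+    # *_like factories, inplace/out variants, and indexing) that complement
+    # the generated op db tests in test_dtensor_ops.py.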
+    @with_comms
+    def test_aten_contiguous(self):
+        # this op is not covered by dtensor_ops
+        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        self._test_op(
+            mesh,
+            lambda x: torch.ops.aten.contiguous(x),
+            torch.randn(16, 32),
+        )
+
+    @with_comms
+    def test_detach(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        tensor_to_detach = torch.randn(12, 8, requires_grad=True)
+        mat = distribute_tensor(tensor_to_detach, device_mesh, shard_spec)
+        detached_mat = mat.detach()
+        self.assertFalse(detached_mat is mat)
+
+    @with_comms
+    def test_clone(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        specs = [[Replicate()], [Shard(0)]]
+        tensor_to_clone = torch.randn(12, 8, requires_grad=True)
+        for spec in specs:
+            mat = distribute_tensor(tensor_to_clone, device_mesh, spec)
+            cloned_mat = mat.clone()
+            self.assertFalse(cloned_mat is mat)
+            self.assertEqual(cloned_mat.to_local(), mat.to_local())
+
+    @with_comms
+    def test_contiguous(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        tensor = torch.rand(3, 5, 6, requires_grad=True)
+        sharding = [Shard(0)]
+        dist_tensor = DTensor.from_local(tensor, device_mesh, sharding)
+        self.assertTrue(dist_tensor.is_contiguous())
+        # shard on dim 0 should not change stride (30, 6, 1)
+        self.assertEqual(dist_tensor.stride(), tensor.stride())
+
+        new_dt = dist_tensor.transpose(0, 2)
+        self.assertFalse(new_dt.is_contiguous())
+        self.assertFalse(new_dt.to_local().is_contiguous())
+        # check stride
+        self.assertEqual(new_dt.stride(), (1, 6, 30))
+
+        new_dt = new_dt.contiguous()
+        self.assertTrue(new_dt.is_contiguous())
+        self.assertTrue(new_dt.to_local().is_contiguous())
+        # the original dist_tensor's stride should be unchanged
+        self.assertEqual(dist_tensor.stride(), tensor.stride())
+
+        # check backward
+        new_dt.to_local().sum().backward()
+        self.assertEqual(tensor.grad, torch.ones(3, 5, 6))
+
+    @with_comms
+    def test_inplace_op(self):
+        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        input_tensor = torch.randn((12, 3), device=self.device_type)
+        dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)])
+        dt_to_mul = dt_to_add.clone()
+        expected_add_dt = dt_to_add.clone() + 3
+        add_res = dt_to_add.add_(3)
+        expected_mul_dt = dt_to_mul.clone() * 3
+        mul_res = dt_to_mul.mul_(3)
+        # inplace op should be the same instance before and after
+        self.assertTrue(add_res is dt_to_add)
+        self.assertEqual(add_res.to_local(), expected_add_dt.to_local())
+
+        self.assertTrue(mul_res is dt_to_mul)
+        self.assertEqual(mul_res.to_local(), expected_mul_dt.to_local())
+
+        # test an inplace op where self and the other dtensor have different specs,
+        # and make sure the output spec does not change
+        shard_spec = [Shard(0)]
+        partial_spec = [_Partial()]
+        dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec)
+        partial_grad = DTensor.from_local(
+            torch.randn(12, 3), mesh, partial_spec
+        )
+        res = dt_to_inplace_add.add_(partial_grad)
+        self.assertTrue(res is dt_to_inplace_add)
+        self.assertTrue(res.placements == shard_spec)
+
+    @with_comms
+    def test_op_out_variant(self):
+        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        input_tensor = torch.randn((12, 3), device=self.device_type)
+        sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)])
+        expected_dt = sharded_dt_input.clone() + 3
+        sharded_dt_out = sharded_dt_input.clone()
+        res = torch.add(sharded_dt_input, 3, out=sharded_dt_out)
+        # op out variant should be the same instance before and after
+        self.assertTrue(res is sharded_dt_out)
+        self.assertEqual(sharded_dt_out.to_local(), expected_dt.to_local())
+
+        # test the op out variant with a differently-placed out and make sure the out spec does not change
+        replica_spec = [Replicate()]
+        replicate_out = distribute_tensor(input_tensor, mesh, replica_spec)
+        expected_dt = replicate_out.clone() + 3
+        res = torch.add(sharded_dt_input, 3, out=replicate_out)
+        self.assertTrue(res is replicate_out)
+        self.assertTrue(res.placements == replica_spec)
+        self.assertEqual(replicate_out.to_local(), expected_dt.to_local())
+
+    @with_comms
+    def test_empty_like(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        empty_like_dt = torch.empty_like(dist_tensor)
+        # empty is not deterministic, so we only check that the shard propagation worked
+        self.assertEqual((4, 8), empty_like_dt.to_local().shape)
+
+    @with_comms
+    def test_fill_inplace(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        full_like_dt = torch.fill_(dist_tensor, 42.0)
+        full_expected = torch.full((4, 8), 42.0)
+        self.assertEqual(full_expected, full_like_dt.to_local())
+        self.assertEqual(full_expected, dist_tensor.to_local())
+
+    @with_comms
+    def test_full_like(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        full_like_dt = torch.full_like(dist_tensor, 42.0)
+        full_expected = torch.full((4, 8), 42.0)
+        self.assertEqual(full_expected, full_like_dt.to_local())
+
+    @with_comms
+    def test_ones_like(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        ones_like_dt = torch.ones_like(dist_tensor)
+        ones_expected = torch.ones(4, 8)
+        self.assertEqual(ones_expected, ones_like_dt.to_local())
+
+    @with_comms
+    def test_ones_like_partial_sum(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [_Partial()]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        assert dist_tensor.shape == (4, 8)
+
+        ones_like_dt = torch.ones_like(dist_tensor)
+        ones_expected = torch.ones(dist_tensor.shape)
+        self.assertEqual(
+            ones_expected,
+            ones_like_dt.redistribute(device_mesh, [Replicate()]).to_local(),
+        )
+
+    @with_comms
+    def test_fill_inplace_partial_sum(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [_Partial()]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        assert dist_tensor.shape == (4, 8)
+
+        torch.fill_(dist_tensor, 42)
+        fill_expected = torch.full(
+            dist_tensor.shape, 42, dtype=input_tensor.dtype
+        )
+        self.assertEqual(
+            fill_expected,
+            dist_tensor.redistribute(device_mesh, [Replicate()]).to_local(),
+        )
+
+    @with_comms
+    def test_zeros_like_partial_sum(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [_Partial()]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        assert dist_tensor.shape == (4, 8)
+
+        zeros_like_dt = torch.zeros_like(dist_tensor)
+        zeros_expected = torch.zeros(dist_tensor.shape)
+        self.assertEqual(
+            zeros_expected,
+            zeros_like_dt.redistribute(device_mesh, [Replicate()]).to_local(),
+        )
+
+    @with_comms
+    def test_zero_inplace(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        zeros_like_dt = torch.zero_(dist_tensor)
+        zeros_expected = torch.zeros(4, 8)
+        self.assertEqual(zeros_expected, zeros_like_dt.to_local())
+        self.assertEqual(zeros_expected, dist_tensor.to_local())
+
+    @with_comms
+    def test_zeros_like(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        input_tensor = torch.randn(4, 8, requires_grad=True)
+        dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        zeros_like_dt = torch.zeros_like(dist_tensor)
+        zeros_expected = torch.zeros(4, 8)
+        self.assertEqual(zeros_expected, zeros_like_dt.to_local())
+
+    def _test_op(self, mesh, op_call, *args, **kwargs):
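+        # Compare op_call on regular tensors against every DTensor sharding
+        # combination that DTensorConverter generates, replicating the DTensor
+        # output before the comparison.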
+        out = op_call(*args, **kwargs)
+        dtc = DTensorConverter(mesh, args, kwargs)
+        for d_args, d_kwargs in dtc:
+            self.assertTrue(dtc.successful())
+            d_out = op_call(*d_args, **d_kwargs)
+            self.assertEqual(
+                d_out.redistribute(mesh, [Replicate()] * mesh.ndim).to_local(),
+                out,
+            )
+
+    @with_comms
+    def test_index(self):
+        meshes = [
+            DeviceMesh(
+                self.device_type, list(range(self.world_size))
+            ),  # 1D mesh
+            # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh
+            # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh
+        ]
+        for mesh in meshes:
+            self._test_op(
+                mesh,
+                lambda x, y: x[y],
+                torch.randn(16, 32, 16),
+                torch.randint(5, (4, 8)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y: x.index_select(1, y),
+                torch.randn(16, 32, 16),
+                torch.randint(5, (4,)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y: x.index_select(0, y),
+                torch.randn(16, 32, 16),
+                torch.randint(5, (4,)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y: x[y],
+                torch.randn(16, 32, 16),
+                torch.randint(5, (12,)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y: x[:, y],
+                torch.randn(16, 32, 16),
+                torch.randint(5, (4, 8)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y: x[..., y],
+                torch.randn(16, 32, 16),
+                torch.randint(5, (4, 12)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y: x[..., y],
+                torch.randn(16, 32, 16),
+                torch.randint(5, (4, 8, 16)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[z, y],
+                torch.randn(16, 32, 16),
+                torch.randint(5, (12, 8, 12)),
+                torch.randint(2, (12, 8, 12)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[z, :, y],
+                torch.randn(16, 32, 16),
+                torch.randint(5, (12, 8, 12)),
+                torch.randint(2, (12, 8, 12)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[:, z, :, y],
+                torch.randn(16, 32, 16, 12),
+                torch.randint(5, (12, 8, 12)),
+                torch.randint(2, (12, 8, 12)),
+            )
+            # broadcast in inner dimensions
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[:, z, :, y],
+                torch.randn(16, 32, 16, 12),
+                torch.randint(5, (12, 8, 12)),
+                torch.randint(2, (12, 1, 12)),
+            )
+            # implicit (left-padded) broadcast
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[:, z, :, y],
+                torch.randn(16, 32, 16, 12),
+                torch.randint(5, (12, 8, 12)),
+                torch.randint(2, (8, 12)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[z, y, :, :],
+                torch.randn(16, 32, 16, 12),
+                torch.randint(2, (8, 12)),
+                torch.randint(5, (12, 8, 12)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[z, :, y, :],
+                torch.randn(16, 32, 16, 12),
+                torch.randint(2, (8, 12)),
+                torch.randint(5, (12, 8, 12)),
+            )
+            self._test_op(
+                mesh,
+                lambda x, y, z: x[z, :, :, y],
+                torch.randn(16, 32, 16, 12),
+                torch.randint(2, (8, 1)),
+                torch.randint(5, (12, 8, 12)),
+            )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py b/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
new file mode 100644
index 0000000..abd0ccf
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
@@ -0,0 +1,661 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import List
+from torch.testing._internal.common_methods_invocations import op_db, OpInfo
+
+# Generated from gen_dtensor_lagging_op_db.py via
+# python torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py
+#     > torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
+#
+# This approach is copied from functorch:
+# People add new OpInfos to PyTorch all the time.
+# We want them to be able to add OpInfos without breaking our CI.
+# To achieve this, we keep our OpInfo library behind PyTorch's and
+# periodically update it by regenerating this file.
+_dtensor_lagging_meta = {
+    ("H", ""),
+    ("T", ""),
+    ("__getitem__", ""),
+    ("__radd__", ""),
+    ("__rand__", ""),
+    ("__rdiv__", ""),
+    ("__rmatmul__", ""),
+    ("__rmod__", ""),
+    ("__rmul__", ""),
+    ("__ror__", ""),
+    ("__rpow__", ""),
+    ("__rsub__", ""),
+    ("__rxor__", ""),
+    ("abs", ""),
+    ("acos", ""),
+    ("acosh", ""),
+    ("add", ""),
+    ("addbmm", ""),
+    ("addcdiv", ""),
+    ("addcmul", ""),
+    ("addmm", ""),
+    ("addmm", "decomposed"),
+    ("addmv", ""),
+    ("addr", ""),
+    ("all", ""),
+    ("allclose", ""),
+    ("amax", ""),
+    ("amin", ""),
+    ("aminmax", ""),
+    ("angle", ""),
+    ("any", ""),
+    ("arange", ""),
+    ("argmax", ""),
+    ("argmin", ""),
+    ("argsort", ""),
+    ("argwhere", ""),
+    ("as_strided", ""),
+    ("as_strided_scatter", ""),
+    ("asin", ""),
+    ("asinh", ""),
+    ("atan", ""),
+    ("atan2", ""),
+    ("atanh", ""),
+    ("atleast_1d", ""),
+    ("atleast_2d", ""),
+    ("atleast_3d", ""),
+    ("baddbmm", ""),
+    ("bernoulli", ""),
+    ("bfloat16", ""),
+    ("bincount", ""),
+    ("bitwise_and", ""),
+    ("bitwise_left_shift", ""),
+    ("bitwise_not", ""),
+    ("bitwise_or", ""),
+    ("bitwise_right_shift", ""),
+    ("bitwise_xor", ""),
+    ("block_diag", ""),
+    ("bmm", ""),
+    ("bool", ""),
+    ("broadcast_shapes", ""),
+    ("broadcast_tensors", ""),
+    ("broadcast_to", ""),
+    ("bucketize", ""),
+    ("byte", ""),
+    ("cartesian_prod", ""),
+    ("cat", ""),
+    ("cdist", ""),
+    ("cdouble", ""),
+    ("ceil", ""),
+    ("cfloat", ""),
+    ("chalf", ""),
+    ("char", ""),
+    ("cholesky", ""),
+    ("cholesky_inverse", ""),
+    ("cholesky_solve", ""),
+    ("chunk", ""),
+    ("clamp", ""),
+    ("clamp_max", ""),
+    ("clamp_min", ""),
+    ("clone", ""),
+    ("column_stack", ""),
+    ("combinations", ""),
+    ("complex", ""),
+    ("conj", ""),
+    ("conj_physical", ""),
+    ("constant_pad_nd", ""),
+    ("contiguous", ""),
+    ("copysign", ""),
+    ("corrcoef", ""),
+    ("cos", ""),
+    ("cosh", ""),
+    ("count_nonzero", ""),
+    ("cov", ""),
+    ("cross", ""),
+    ("cummax", ""),
+    ("cummin", ""),
+    ("cumprod", ""),
+    ("cumsum", ""),
+    ("cumulative_trapezoid", ""),
+    ("deg2rad", ""),
+    ("diag", ""),
+    ("diag_embed", ""),
+    ("diagflat", ""),
+    ("diagonal", ""),
+    ("diagonal_copy", ""),
+    ("diagonal_scatter", ""),
+    ("diff", ""),
+    ("digamma", ""),
+    ("dist", ""),
+    ("div", "floor_rounding"),
+    ("div", "no_rounding_mode"),
+    ("div", "trunc_rounding"),
+    ("dot", ""),
+    ("double", ""),
+    ("dsplit", ""),
+    ("dstack", ""),
+    ("einsum", ""),
+    ("empty", ""),
+    ("empty_like", ""),
+    ("eq", ""),
+    ("equal", ""),
+    ("erf", ""),
+    ("erfc", ""),
+    ("erfinv", ""),
+    ("exp", ""),
+    ("exp2", ""),
+    ("expand", ""),
+    ("expand_as", ""),
+    ("expm1", ""),
+    ("eye", ""),
+    ("fft.fft", ""),
+    ("fft.fft2", ""),
+    ("fft.fftn", ""),
+    ("fft.fftshift", ""),
+    ("fft.hfft", ""),
+    ("fft.hfft2", ""),
+    ("fft.hfftn", ""),
+    ("fft.ifft", ""),
+    ("fft.ifft2", ""),
+    ("fft.ifftn", ""),
+    ("fft.ifftshift", ""),
+    ("fft.ihfft", ""),
+    ("fft.ihfft2", ""),
+    ("fft.ihfftn", ""),
+    ("fft.irfft", ""),
+    ("fft.irfft2", ""),
+    ("fft.irfftn", ""),
+    ("fft.rfft", ""),
+    ("fft.rfft2", ""),
+    ("fft.rfftn", ""),
+    ("fill", ""),
+    ("flatten", ""),
+    ("flip", ""),
+    ("fliplr", ""),
+    ("flipud", ""),
+    ("float", ""),
+    ("float_power", ""),
+    ("floor", ""),
+    ("floor_divide", ""),
+    ("fmax", ""),
+    ("fmin", ""),
+    ("fmod", ""),
+    ("frac", ""),
+    ("frexp", ""),
+    ("full", ""),
+    ("full_like", ""),
+    ("gather", ""),
+    ("gcd", ""),
+    ("ge", ""),
+    ("geqrf", ""),
+    ("gradient", ""),
+    ("gt", ""),
+    ("half", ""),
+    ("heaviside", ""),
+    ("histc", ""),
+    ("histogram", ""),
+    ("histogramdd", ""),
+    ("hsplit", ""),
+    ("hstack", ""),
+    ("hypot", ""),
+    ("i0", ""),
+    ("igamma", ""),
+    ("igammac", ""),
+    ("imag", ""),
+    ("index_add", ""),
+    ("index_copy", ""),
+    ("index_fill", ""),
+    ("index_put", ""),
+    ("index_reduce", ""),
+    ("index_select", ""),
+    ("inner", ""),
+    ("int", ""),
+    ("isclose", ""),
+    ("isfinite", ""),
+    ("isin", ""),
+    ("isinf", ""),
+    ("isnan", ""),
+    ("isneginf", ""),
+    ("isposinf", ""),
+    ("isreal", ""),
+    ("istft", ""),
+    ("jiterator_2inputs_2outputs", ""),
+    ("jiterator_4inputs_with_extra_args", ""),
+    ("jiterator_binary", ""),
+    ("jiterator_binary_return_by_ref", ""),
+    ("jiterator_unary", ""),
+    ("kron", ""),
+    ("kthvalue", ""),
+    ("lcm", ""),
+    ("ldexp", ""),
+    ("le", ""),
+    ("lerp", ""),
+    ("lgamma", ""),
+    ("linalg.cholesky", ""),
+    ("linalg.cholesky_ex", ""),
+    ("linalg.cond", ""),
+    ("linalg.cross", ""),
+    ("linalg.det", ""),
+    ("linalg.det", "singular"),
+    ("linalg.eig", ""),
+    ("linalg.eigh", ""),
+    ("linalg.eigvals", ""),
+    ("linalg.eigvalsh", ""),
+    ("linalg.householder_product", ""),
+    ("linalg.inv", ""),
+    ("linalg.inv_ex", ""),
+    ("linalg.ldl_factor", ""),
+    ("linalg.ldl_factor_ex", ""),
+    ("linalg.ldl_solve", ""),
+    ("linalg.lstsq", ""),
+    ("linalg.lstsq", "grad_oriented"),
+    ("linalg.lu", ""),
+    ("linalg.lu_factor", ""),
+    ("linalg.lu_factor_ex", ""),
+    ("linalg.lu_solve", ""),
+    ("linalg.matrix_norm", ""),
+    ("linalg.matrix_power", ""),
+    ("linalg.matrix_rank", ""),
+    ("linalg.matrix_rank", "hermitian"),
+    ("linalg.multi_dot", ""),
+    ("linalg.norm", ""),
+    ("linalg.norm", "subgradients_at_zero"),
+    ("linalg.pinv", ""),
+    ("linalg.pinv", "hermitian"),
+    ("linalg.pinv", "singular"),
+    ("linalg.qr", ""),
+    ("linalg.slogdet", ""),
+    ("linalg.solve", ""),
+    ("linalg.solve_ex", ""),
+    ("linalg.solve_triangular", ""),
+    ("linalg.svd", ""),
+    ("linalg.svdvals", ""),
+    ("linalg.tensorinv", ""),
+    ("linalg.tensorsolve", ""),
+    ("linalg.vander", ""),
+    ("linalg.vecdot", ""),
+    ("linalg.vector_norm", ""),
+    ("linspace", ""),
+    ("log", ""),
+    ("log10", ""),
+    ("log1p", ""),
+    ("log2", ""),
+    ("log_softmax", ""),
+    ("log_softmax", "with_dtype"),
+    ("logaddexp", ""),
+    ("logaddexp2", ""),
+    ("logcumsumexp", ""),
+    ("logdet", ""),
+    ("logical_and", ""),
+    ("logical_not", ""),
+    ("logical_or", ""),
+    ("logical_xor", ""),
+    ("logit", ""),
+    ("logspace", ""),
+    ("logsumexp", ""),
+    ("long", ""),
+    ("lt", ""),
+    ("lu", ""),
+    ("lu_solve", ""),
+    ("lu_unpack", ""),
+    ("mH", ""),
+    ("mT", ""),
+    ("masked.amax", ""),
+    ("masked.amin", ""),
+    ("masked.argmax", ""),
+    ("masked.argmin", ""),
+    ("masked.cumprod", ""),
+    ("masked.cumsum", ""),
+    ("masked.log_softmax", ""),
+    ("masked.logaddexp", ""),
+    ("masked.logsumexp", ""),
+    ("masked.mean", ""),
+    ("masked.median", ""),
+    ("masked.norm", ""),
+    ("masked.normalize", ""),
+    ("masked.prod", ""),
+    ("masked.softmax", ""),
+    ("masked.softmin", ""),
+    ("masked.std", ""),
+    ("masked.sum", ""),
+    ("masked.var", ""),
+    ("masked_fill", ""),
+    ("masked_scatter", ""),
+    ("masked_select", ""),
+    ("matmul", ""),
+    ("matrix_exp", ""),
+    ("max", "binary"),
+    ("max", "reduction_no_dim"),
+    ("max", "reduction_with_dim"),
+    ("max_pool2d_with_indices_backward", ""),
+    ("maximum", ""),
+    ("mean", ""),
+    ("median", ""),
+    ("meshgrid", "list_of_tensors"),
+    ("meshgrid", "variadic_tensors"),
+    ("min", "binary"),
+    ("min", "reduction_no_dim"),
+    ("min", "reduction_with_dim"),
+    ("minimum", ""),
+    ("mm", ""),
+    ("mode", ""),
+    ("movedim", ""),
+    ("msort", ""),
+    ("mul", ""),
+    ("multinomial", ""),
+    ("mv", ""),
+    ("mvlgamma", "mvlgamma_p_1"),
+    ("mvlgamma", "mvlgamma_p_3"),
+    ("mvlgamma", "mvlgamma_p_5"),
+    ("nan_to_num", ""),
+    ("nanmean", ""),
+    ("nanmedian", ""),
+    ("nanquantile", ""),
+    ("nansum", ""),
+    ("narrow", ""),
+    ("narrow_copy", ""),
+    ("native_batch_norm", ""),
+    ("native_layer_norm", ""),
+    ("ne", ""),
+    ("neg", ""),
+    ("new_empty", ""),
+    ("new_empty_strided", ""),
+    ("new_full", ""),
+    ("new_ones", ""),
+    ("new_zeros", ""),
+    ("nextafter", ""),
+    ("nn.functional._scaled_dot_product_attention", ""),
+    ("nn.functional.adaptive_avg_pool1d", ""),
+    ("nn.functional.adaptive_avg_pool2d", ""),
+    ("nn.functional.adaptive_avg_pool3d", ""),
+    ("nn.functional.adaptive_max_pool1d", ""),
+    ("nn.functional.adaptive_max_pool2d", ""),
+    ("nn.functional.adaptive_max_pool3d", ""),
+    ("nn.functional.alpha_dropout", ""),
+    ("nn.functional.avg_pool1d", ""),
+    ("nn.functional.avg_pool2d", ""),
+    ("nn.functional.avg_pool3d", ""),
+    ("nn.functional.batch_norm", ""),
+    ("nn.functional.batch_norm", "without_cudnn"),
+    ("nn.functional.bilinear", ""),
+    ("nn.functional.binary_cross_entropy", ""),
+    ("nn.functional.binary_cross_entropy_with_logits", ""),
+    ("nn.functional.celu", ""),
+    ("nn.functional.conv1d", ""),
+    ("nn.functional.conv2d", ""),
+    ("nn.functional.conv_transpose1d", ""),
+    ("nn.functional.conv_transpose2d", ""),
+    ("nn.functional.conv_transpose3d", ""),
+    ("nn.functional.cosine_embedding_loss", ""),
+    ("nn.functional.cosine_similarity", ""),
+    ("nn.functional.cross_entropy", ""),
+    ("nn.functional.ctc_loss", ""),
+    ("nn.functional.dropout", ""),
+    ("nn.functional.dropout2d", ""),
+    ("nn.functional.dropout3d", ""),
+    ("nn.functional.elu", ""),
+    ("nn.functional.embedding", ""),
+    ("nn.functional.embedding_bag", ""),
+    ("nn.functional.feature_alpha_dropout", "with_train"),
+    ("nn.functional.feature_alpha_dropout", "without_train"),
+    ("nn.functional.fractional_max_pool2d", ""),
+    ("nn.functional.fractional_max_pool3d", ""),
+    ("nn.functional.gaussian_nll_loss", ""),
+    ("nn.functional.gelu", ""),
+    ("nn.functional.glu", ""),
+    ("nn.functional.grid_sample", ""),
+    ("nn.functional.group_norm", ""),
+    ("nn.functional.hardshrink", ""),
+    ("nn.functional.hardsigmoid", ""),
+    ("nn.functional.hardswish", ""),
+    ("nn.functional.hardtanh", ""),
+    ("nn.functional.hinge_embedding_loss", ""),
+    ("nn.functional.huber_loss", ""),
+    ("nn.functional.instance_norm", ""),
+    ("nn.functional.interpolate", "area"),
+    ("nn.functional.interpolate", "bicubic"),
+    ("nn.functional.interpolate", "bilinear"),
+    ("nn.functional.interpolate", "linear"),
+    ("nn.functional.interpolate", "nearest"),
+    ("nn.functional.interpolate", "trilinear"),
+    ("nn.functional.kl_div", ""),
+    ("nn.functional.l1_loss", ""),
+    ("nn.functional.layer_norm", ""),
+    ("nn.functional.leaky_relu", ""),
+    ("nn.functional.linear", ""),
+    ("nn.functional.local_response_norm", ""),
+    ("nn.functional.logsigmoid", ""),
+    ("nn.functional.margin_ranking_loss", ""),
+    ("nn.functional.max_pool1d", ""),
+    ("nn.functional.max_pool2d", ""),
+    ("nn.functional.max_pool3d", ""),
+    ("nn.functional.max_unpool1d", ""),
+    ("nn.functional.max_unpool1d", "grad"),
+    ("nn.functional.max_unpool2d", ""),
+    ("nn.functional.max_unpool2d", "grad"),
+    ("nn.functional.max_unpool3d", ""),
+    ("nn.functional.max_unpool3d", "grad"),
+    ("nn.functional.mish", ""),
+    ("nn.functional.mse_loss", ""),
+    ("nn.functional.multi_margin_loss", ""),
+    ("nn.functional.multilabel_margin_loss", ""),
+    ("nn.functional.multilabel_soft_margin_loss", ""),
+    ("nn.functional.nll_loss", ""),
+    ("nn.functional.normalize", ""),
+    ("nn.functional.one_hot", ""),
+    ("nn.functional.pad", "circular"),
+    ("nn.functional.pad", "constant"),
+    ("nn.functional.pad", "reflect"),
+    ("nn.functional.pad", "replicate"),
+    ("nn.functional.pairwise_distance", ""),
+    ("nn.functional.pdist", ""),
+    ("nn.functional.pixel_shuffle", ""),
+    ("nn.functional.pixel_unshuffle", ""),
+    ("nn.functional.poisson_nll_loss", ""),
+    ("nn.functional.prelu", ""),
+    ("nn.functional.relu", ""),
+    ("nn.functional.relu6", ""),
+    ("nn.functional.rrelu", ""),
+    ("nn.functional.selu", ""),
+    ("nn.functional.silu", ""),
+    ("nn.functional.silu", "complex"),
+    ("nn.functional.smooth_l1_loss", ""),
+    ("nn.functional.soft_margin_loss", ""),
+    ("nn.functional.softmin", ""),
+    ("nn.functional.softmin", "with_dtype"),
+    ("nn.functional.softplus", ""),
+    ("nn.functional.softshrink", ""),
+    ("nn.functional.softsign", ""),
+    ("nn.functional.tanhshrink", ""),
+    ("nn.functional.threshold", ""),
+    ("nn.functional.triplet_margin_loss", ""),
+    ("nn.functional.triplet_margin_with_distance_loss", ""),
+    ("nn.functional.unfold", ""),
+    ("nn.functional.upsample_bilinear", ""),
+    ("nn.functional.upsample_nearest", ""),
+    ("nonzero", ""),
+    ("norm", ""),
+    ("norm", "fro"),
+    ("norm", "inf"),
+    ("norm", "nuc"),
+    ("normal", ""),
+    ("normal", "number_mean"),
+    ("ones", ""),
+    ("ones_like", ""),
+    ("ormqr", ""),
+    ("outer", ""),
+    ("pca_lowrank", ""),
+    ("permute", ""),
+    ("pinverse", ""),
+    ("polar", ""),
+    ("polygamma", "polygamma_n_0"),
+    ("polygamma", "polygamma_n_1"),
+    ("polygamma", "polygamma_n_2"),
+    ("polygamma", "polygamma_n_3"),
+    ("polygamma", "polygamma_n_4"),
+    ("positive", ""),
+    ("pow", ""),
+    ("prod", ""),
+    ("put", ""),
+    ("qr", ""),
+    ("quantile", ""),
+    ("rad2deg", ""),
+    ("rand_like", ""),
+    ("randint", ""),
+    ("randint_like", ""),
+    ("randn", ""),
+    ("randn_like", ""),
+    ("ravel", ""),
+    ("real", ""),
+    ("reciprocal", ""),
+    ("remainder", ""),
+    ("renorm", ""),
+    ("repeat", ""),
+    ("repeat_interleave", ""),
+    ("reshape", ""),
+    ("reshape_as", ""),
+    ("resize_", ""),
+    ("resize_as_", ""),
+    ("resolve_conj", ""),
+    ("resolve_neg", ""),
+    ("roll", ""),
+    ("rot90", ""),
+    ("round", ""),
+    ("round", "decimals_0"),
+    ("round", "decimals_3"),
+    ("round", "decimals_neg_3"),
+    ("rsqrt", ""),
+    ("rsub", ""),
+    ("scalar_tensor", ""),
+    ("scatter", ""),
+    ("scatter_add", ""),
+    ("scatter_reduce", "amax"),
+    ("scatter_reduce", "amin"),
+    ("scatter_reduce", "mean"),
+    ("scatter_reduce", "prod"),
+    ("scatter_reduce", "sum"),
+    ("searchsorted", ""),
+    ("segment_reduce", "lengths"),
+    ("segment_reduce", "offsets"),
+    ("select", ""),
+    ("select_scatter", ""),
+    ("sgn", ""),
+    ("short", ""),
+    ("sigmoid", ""),
+    ("sign", ""),
+    ("signal.windows.cosine", ""),
+    ("signal.windows.exponential", ""),
+    ("signal.windows.gaussian", ""),
+    ("signal.windows.kaiser", ""),
+    ("signbit", ""),
+    ("sin", ""),
+    ("sinc", ""),
+    ("sinh", ""),
+    ("slice", ""),
+    ("slice_scatter", ""),
+    ("softmax", ""),
+    ("softmax", "with_dtype"),
+    ("sort", ""),
+    ("sparse.sampled_addmm", ""),
+    ("special.airy_ai", ""),
+    ("special.bessel_j0", ""),
+    ("special.bessel_j1", ""),
+    ("special.bessel_y0", ""),
+    ("special.bessel_y1", ""),
+    ("special.chebyshev_polynomial_t", ""),
+    ("special.chebyshev_polynomial_u", ""),
+    ("special.chebyshev_polynomial_v", ""),
+    ("special.chebyshev_polynomial_w", ""),
+    ("special.entr", ""),
+    ("special.erfcx", ""),
+    ("special.hermite_polynomial_h", ""),
+    ("special.hermite_polynomial_he", ""),
+    ("special.i0e", ""),
+    ("special.i1", ""),
+    ("special.i1e", ""),
+    ("special.laguerre_polynomial_l", ""),
+    ("special.legendre_polynomial_p", ""),
+    ("special.log_ndtr", ""),
+    ("special.modified_bessel_i0", ""),
+    ("special.modified_bessel_i1", ""),
+    ("special.modified_bessel_k0", ""),
+    ("special.modified_bessel_k1", ""),
+    ("special.ndtr", ""),
+    ("special.ndtri", ""),
+    ("special.polygamma", "special_polygamma_n_0"),
+    ("special.scaled_modified_bessel_k0", ""),
+    ("special.scaled_modified_bessel_k1", ""),
+    ("special.shifted_chebyshev_polynomial_t", ""),
+    ("special.shifted_chebyshev_polynomial_u", ""),
+    ("special.shifted_chebyshev_polynomial_v", ""),
+    ("special.shifted_chebyshev_polynomial_w", ""),
+    ("special.spherical_bessel_j0", ""),
+    ("special.xlog1py", ""),
+    ("special.zeta", ""),
+    ("split", ""),
+    ("split", "list_args"),
+    ("split_with_sizes", ""),
+    ("sqrt", ""),
+    ("square", ""),
+    ("squeeze", ""),
+    ("stack", ""),
+    ("std", ""),
+    ("std_mean", ""),
+    ("stft", ""),
+    ("sub", ""),
+    ("sum", ""),
+    ("sum_to_size", ""),
+    ("svd", ""),
+    ("svd_lowrank", ""),
+    ("symeig", ""),
+    ("t", ""),
+    ("take", ""),
+    ("take_along_dim", ""),
+    ("tan", ""),
+    ("tanh", ""),
+    ("tensor_split", ""),
+    ("tensordot", ""),
+    ("tile", ""),
+    ("to", ""),
+    ("to_sparse", ""),
+    ("topk", ""),
+    ("trace", ""),
+    ("transpose", ""),
+    ("trapezoid", ""),
+    ("trapz", ""),
+    ("triangular_solve", ""),
+    ("tril", ""),
+    ("tril_indices", ""),
+    ("triu", ""),
+    ("triu_indices", ""),
+    ("true_divide", ""),
+    ("trunc", ""),
+    ("unbind", ""),
+    ("unflatten", ""),
+    ("unfold", ""),
+    ("unfold_copy", ""),
+    ("uniform", ""),
+    ("unique", ""),
+    ("unique_consecutive", ""),
+    ("unsqueeze", ""),
+    ("var", ""),
+    ("var_mean", ""),
+    ("vdot", ""),
+    ("view", ""),
+    ("view_as", ""),
+    ("view_as_complex", ""),
+    ("view_as_real", ""),
+    ("vsplit", ""),
+    ("vstack", ""),
+    ("where", ""),
+    ("xlogy", ""),
+    ("zero_", ""),
+    ("zeros", ""),
+    ("zeros_like", ""),
+}
+
+
+def in_dtensor_lagging_op_db(opinfo: OpInfo) -> bool:
+    return (opinfo.name, opinfo.variant_test_name) in _dtensor_lagging_meta
+
+
+dtensor_lagging_op_db: List[OpInfo] = [
+    opinfo for opinfo in op_db if in_dtensor_lagging_op_db(opinfo)
+]
diff --git a/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py b/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py
new file mode 100644
index 0000000..f684f77
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py
@@ -0,0 +1,67 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple
+from torch.testing._internal.common_methods_invocations import op_db
+
+
+def num_leading_spaces(line: str) -> int:
+    result = len(line) - len(line.lstrip())
+    # Empty (or unindented) lines should not affect the minimum indentation
+    if result == 0:
+        return 999999
+    return result
+
+
+def deindent(code: str) -> str:
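+    # Strip the common leading indentation so the generated file contains
+    # top-level (unindented) code.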
+    lines = code.split("\n")
+    min_leading_spaces = min(map(num_leading_spaces, lines))
+    lines = [line[min_leading_spaces:] for line in lines]
+    return "\n".join(lines)
+
+
+if __name__ == "__main__":
+    supported: List[Tuple[str, str]] = [
+        (opinfo.name, opinfo.variant_test_name) for opinfo in op_db
+    ]
+    supported = sorted(supported)
+    print(
+        deindent(
+            """\
+    # Copyright (c) Facebook, Inc. and its affiliates.
+    # All rights reserved.
+    #
+    # This source code is licensed under the BSD-style license found in the
+    # LICENSE file in the root directory of this source tree.
+    from typing import List
+    from torch.testing._internal.common_methods_invocations import op_db, OpInfo
+    # Generated from gen_dtensor_lagging_op_db.py via
+    # python torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py
+    #     > torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
+    #
+    # This approach is copied from functorch:
+    # People add new OpInfos to PyTorch all the time.
+    # We want them to be able to add OpInfos without breaking our CI.
+    # To achieve this, we keep our OpInfo library behind PyTorch's and
+    # periodically update it by regenerating this file."""
+        )
+    )
+
+    print("_dtensor_lagging_meta = {")
+    for name, variant in supported:
+        print(f"    {(name, variant)},")
+    print("}")
+
+    print(
+        deindent(
+            """\
+    def in_dtensor_lagging_op_db(opinfo: OpInfo) -> bool:
+        return (opinfo.name, opinfo.variant_test_name) in _dtensor_lagging_meta
+
+    dtensor_lagging_op_db: List[OpInfo] = [
+        opinfo for opinfo in op_db if in_dtensor_lagging_op_db(opinfo)
+    ]"""
+        )
+    )