[dtensor] PART 7: move remaining DTensor tests to core distributed (#88179)
This PR moves the remaining tests, i.e. the tensor_ops and op db tests, to core distributed.
Part of https://github.com/pytorch/pytorch/issues/88838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88179
Approved by: https://github.com/aazzolini
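Both moved files use the standard PyTorch test entry point (each calls run_tests() under __main__), so after this move they can be run directly, e.g.:

    python test/distributed/_tensor/test_dtensor_ops.py
    python test/distributed/_tensor/test_tensor_ops.py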
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
new file mode 100644
index 0000000..22ae580
--- /dev/null
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -0,0 +1,704 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+import sys
+import unittest
+import warnings
+
+from torch.overrides import resolve_name
+from torch.utils._pytree import tree_flatten, tree_map
+from torch.testing._internal.common_utils import (
+ suppress_warnings,
+ TEST_WITH_ASAN,
+ run_tests,
+)
+import torch.distributed as dist
+from torch.testing._internal.common_device_type import (
+ ops,
+ instantiate_device_type_tests,
+)
+import torch.testing._internal.common_methods_invocations as common_ops
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+
+from torch.distributed._tensor import DTensor, DeviceMesh, Replicate
+from torch.testing._internal.distributed._tensor.dtensor_lagging_op_db import dtensor_lagging_op_db
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorTestBase,
+ TEST_SKIPS,
+ DTensorConverter,
+ DEVICE_TYPE,
+)
+
+# rewrite common size variables to something that can be sharded evenly.
+# we can enable uneven shards later, but that needs more adjustments to the
+# sample inputs (i.e. view/reshape would need their shape sizes adjusted too)
+common_ops.L = 24
+common_ops.M = 12
+common_ops.S = 4
+common_ops.XS = 2
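+# e.g. with the world_size of 4 used by TestDTensorOps below, a dim of
+# size 24 (the new L) shards evenly into 6 elements per rank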
+
+
+def assert_ref_dtensor_equal(test_case, dtensor_rs, rs):
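+ # flatten both result pytrees and compare each dtensor result with its
+ # plain-tensor reference: shape, requires_grad, and local values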
+ flat_dtensor_rs, _ = tree_flatten(dtensor_rs)
+ flat_rs, _ = tree_flatten(rs)
+ test_case.assertEqual(len(flat_dtensor_rs), len(flat_rs))
+ for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):
+
+ if not isinstance(r, torch.Tensor):
+ continue
+
+ test_case.assertIsInstance(dtensor_r, torch.Tensor)
+ test_case.assertEqual(
+ dtensor_r.shape,
+ r.shape,
+ f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}",
+ )
+ test_case.assertEqual(
+ dtensor_r.requires_grad,
+ r.requires_grad,
+ "op result requires_grad mismatch!"
+ f"original requires_grad: {r.requires_grad}, "
+ f"dtensor requires_grad: {dtensor_r.requires_grad}",
+ )
+
+ test_case.assertEqual(dtensor_r.to_local(), r)
+
+
+# Copied from functorch
+def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
+ return (op_name, variant_name, device_type, dtypes, True)
+
+
+def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
+ return (op_name, variant_name, device_type, dtypes, False)
+
+
+def skipOps(test_case_name, base_test_name, to_skip):
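+ # attach an expectedFailure or skip DecorateInfo to every matching OpInfo
+ # in the lagging op db; returns a no-op decorator (same scheme as functorch)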
+ all_opinfos = dtensor_lagging_op_db
+ for xfail in to_skip:
+ op_name, variant_name, device_type, dtypes, expected_failure = xfail
+ matching_opinfos = [
+ o
+ for o in all_opinfos
+ if o.name == op_name and o.variant_test_name == variant_name
+ ]
+ assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
+ for opinfo in matching_opinfos:
+ decorators = list(opinfo.decorators)
+ if expected_failure:
+ decorator = DecorateInfo(
+ unittest.expectedFailure,
+ test_case_name,
+ base_test_name,
+ device_type=device_type,
+ dtypes=dtypes,
+ )
+ decorators.append(decorator)
+ else:
+ decorator = DecorateInfo(
+ unittest.skip("Skipped!"),
+ test_case_name,
+ base_test_name,
+ device_type=device_type,
+ dtypes=dtypes,
+ )
+ decorators.append(decorator)
+ opinfo.decorators = tuple(decorators)
+
+ # This decorator doesn't modify fn in any way
+ def wrapped(fn):
+ return fn
+
+ return wrapped
+
+
+# To re-generate this failure list, turn on the dry_run flag of the function
+# below, i.e. check_dtensor_func(self, test, op, dry_run=True), then run
+# something like python test/distributed/_tensor/test_dtensor_ops.py > failed.expect
+dtensor_fails = {
+ # these sometimes pass and sometimes fail;
+ # we should remove many of them from this list once the ops
+ # get full support with varying sharding specs
+ xfail("__getitem__"),
+ xfail("__rsub__"),
+ xfail("masked.amax"),
+ xfail("masked.amin"),
+ xfail("masked.argmax"),
+ xfail("masked.argmin"),
+ xfail("masked.cumprod"),
+ xfail("masked.cumsum"),
+ xfail("masked.log_softmax"),
+ xfail("masked.logaddexp"),
+ xfail("masked.logsumexp"),
+ xfail("masked.median"),
+ xfail("masked.norm"),
+ xfail("masked.prod"),
+ xfail("masked.softmin"),
+ xfail("masked.softmax"),
+ xfail("masked.sum"),
+ xfail("addbmm"),
+ xfail("addmv"),
+ xfail("addr"),
+ xfail("all"),
+ xfail("allclose"),
+ xfail("amax"),
+ xfail("amin"),
+ xfail("aminmax"),
+ xfail("any"),
+ xfail("arange"),
+ xfail("argmax"),
+ xfail("argmin"),
+ xfail("argsort"),
+ xfail("as_strided"),
+ xfail("as_strided_scatter"),
+ xfail("baddbmm"),
+ xfail("bernoulli"),
+ xfail("block_diag"),
+ xfail("broadcast_shapes"),
+ xfail("cat"),
+ xfail("cartesian_prod"),
+ xfail("cdist"),
+ xfail("cholesky"),
+ xfail("cholesky_inverse"),
+ xfail("cholesky_solve"),
+ xfail("chunk"),
+ xfail("clamp"),
+ xfail("clamp_max"),
+ xfail("clamp_min"),
+ xfail("column_stack"),
+ xfail("combinations"),
+ xfail("complex"),
+ xfail("constant_pad_nd"),
+ xfail("copysign"),
+ xfail("corrcoef"),
+ xfail("count_nonzero"),
+ xfail("cov"),
+ xfail("cross"),
+ xfail("cummax"),
+ xfail("cummin"),
+ xfail("cumsum"),
+ xfail("cumulative_trapezoid"),
+ xfail("diag"),
+ xfail("diag_embed"),
+ xfail("diagflat"),
+ xfail("diagonal"),
+ xfail("diagonal_copy"),
+ xfail("diagonal_scatter"),
+ xfail("diff"),
+ xfail("dist"),
+ xfail("dot"),
+ xfail("dstack"),
+ xfail("einsum"),
+ xfail("empty"),
+ xfail("empty_like"),
+ xfail("eq"),
+ xfail("eye"),
+ xfail("fft.fft2"),
+ xfail("fft.fft"),
+ xfail("fft.fftn"),
+ xfail("fft.fftshift"),
+ xfail("fft.ifft2"),
+ xfail("fft.ifft"),
+ xfail("fft.ifftshift"),
+ xfail("fft.ihfft2"),
+ xfail("fft.ihfft"),
+ xfail("fft.ihfftn"),
+ xfail("fft.irfft2"),
+ xfail("fft.irfftn"),
+ xfail("fft.rfft2"),
+ xfail("fft.rfft"),
+ xfail("fft.rfftn"),
+ xfail("flip"),
+ xfail("fliplr"),
+ xfail("flipud"),
+ xfail("floor_divide"),
+ xfail("fmax"),
+ xfail("fmin"),
+ xfail("frexp"),
+ xfail("full"),
+ xfail("gather"),
+ xfail("geqrf"),
+ xfail("gradient"),
+ xfail("heaviside"),
+ xfail("histc"),
+ xfail("histogram"),
+ xfail("histogramdd"),
+ xfail("hstack"),
+ xfail("index_add"),
+ xfail("index_copy"),
+ xfail("index_fill"),
+ xfail("index_put"),
+ xfail("index_reduce"),
+ xfail("index_select"),
+ xfail("isfinite"),
+ xfail("isin"),
+ xfail("isinf"),
+ xfail("isnan"),
+ xfail("isneginf"),
+ xfail("isposinf"),
+ xfail("kthvalue"),
+ xfail("linalg.cholesky"),
+ xfail("linalg.cholesky_ex"),
+ xfail("linalg.cond"),
+ xfail("linalg.cross"),
+ xfail("linalg.det"),
+ xfail("linalg.det", "singular"),
+ xfail("linalg.eig"),
+ xfail("linalg.eigh"),
+ xfail("linalg.eigvals"),
+ xfail("linalg.eigvalsh"),
+ xfail("linalg.householder_product"),
+ xfail("linalg.inv"),
+ xfail("linalg.inv_ex"),
+ xfail("linalg.ldl_factor"),
+ xfail("linalg.ldl_factor_ex"),
+ xfail("linalg.ldl_solve"),
+ xfail("linalg.lstsq"),
+ xfail("linalg.lstsq", "grad_oriented"),
+ xfail("linalg.lu"),
+ xfail("linalg.lu_factor"),
+ xfail("linalg.lu_factor_ex"),
+ xfail("linalg.lu_solve"),
+ xfail("linalg.matrix_norm"),
+ xfail("linalg.matrix_power"),
+ xfail("linalg.matrix_rank"),
+ xfail("linalg.matrix_rank", "hermitian"),
+ xfail("linalg.multi_dot"),
+ xfail("linalg.norm"),
+ xfail("linalg.norm", "subgradients_at_zero"),
+ xfail("linalg.pinv"),
+ xfail("linalg.pinv", "hermitian"),
+ xfail("linalg.qr"),
+ xfail("linalg.slogdet"),
+ xfail("linalg.solve"),
+ xfail("linalg.solve_ex"),
+ xfail("linalg.solve_triangular"),
+ xfail("linalg.svd"),
+ xfail("linalg.svdvals"),
+ xfail("linalg.tensorinv"),
+ xfail("linalg.tensorsolve"),
+ xfail("linalg.vander"),
+ xfail("linalg.vecdot"),
+ xfail("linalg.vector_norm"),
+ xfail("linspace"),
+ xfail("log_softmax"),
+ xfail("log_softmax", "with_dtype"),
+ xfail("logcumsumexp"),
+ xfail("logdet"),
+ xfail("logical_not"),
+ xfail("logspace"),
+ xfail("logsumexp"),
+ xfail("lt"),
+ xfail("lu"),
+ xfail("lu_solve"),
+ xfail("lu_unpack"),
+ xfail("masked_fill"),
+ xfail("masked_scatter"),
+ xfail("masked_select"),
+ xfail("matrix_exp"),
+ xfail("max", "binary"),
+ xfail("max", "reduction_no_dim"),
+ xfail("max", "reduction_with_dim"),
+ xfail("maximum"),
+ xfail("median"),
+ xfail("min", "binary"),
+ xfail("min", "reduction_no_dim"),
+ xfail("min", "reduction_with_dim"),
+ xfail("minimum"),
+ xfail("mode"),
+ xfail("msort"),
+ xfail("multinomial"),
+ xfail("mv"),
+ xfail("max_pool2d_with_indices_backward", ""),
+ xfail("nanmean"),
+ xfail("nanmedian"),
+ xfail("nanquantile"),
+ xfail("nansum"),
+ xfail("native_batch_norm"),
+ xfail("native_layer_norm"),
+ xfail("narrow_copy"),
+ xfail("ne"),
+ xfail("new_empty"),
+ xfail("new_empty_strided"),
+ xfail("transpose"),
+ xfail("nn.functional.adaptive_avg_pool1d"),
+ xfail("nn.functional.adaptive_avg_pool2d"),
+ xfail("nn.functional.adaptive_avg_pool3d"),
+ xfail("nn.functional.adaptive_max_pool1d"),
+ xfail("nn.functional.adaptive_max_pool2d"),
+ xfail("nn.functional.adaptive_max_pool3d"),
+ xfail("nn.functional.alpha_dropout"),
+ xfail("nn.functional.avg_pool1d"),
+ xfail("nn.functional.avg_pool2d"),
+ xfail("nn.functional.avg_pool3d"),
+ xfail("nn.functional.batch_norm"),
+ xfail("nn.functional.batch_norm", "without_cudnn"),
+ xfail("nn.functional.bilinear"),
+ xfail("nn.functional.binary_cross_entropy"),
+ xfail("nn.functional.binary_cross_entropy_with_logits"),
+ xfail("nn.functional.celu"),
+ xfail("nn.functional.conv1d"),
+ xfail("nn.functional.conv2d"),
+ xfail("nn.functional.conv_transpose1d"),
+ xfail("nn.functional.conv_transpose2d"),
+ xfail("nn.functional.conv_transpose3d"),
+ xfail("nn.functional.cosine_similarity"),
+ xfail("nn.functional.cross_entropy"),
+ xfail("nn.functional.ctc_loss"),
+ xfail("nn.functional.dropout"),
+ xfail("nn.functional.dropout2d"),
+ xfail("nn.functional.dropout3d"),
+ xfail("nn.functional.elu"),
+ xfail("nn.functional.fractional_max_pool2d"),
+ xfail("nn.functional.fractional_max_pool3d"),
+ xfail("nn.functional.gaussian_nll_loss"),
+ xfail("nn.functional.glu"),
+ xfail("nn.functional.grid_sample"),
+ xfail("nn.functional.group_norm"),
+ xfail("nn.functional.hardshrink"),
+ xfail("nn.functional.hardsigmoid"),
+ xfail("nn.functional.hardswish"),
+ xfail("nn.functional.hardtanh"),
+ xfail("nn.functional.huber_loss"),
+ xfail("nn.functional.instance_norm"),
+ xfail("nn.functional.interpolate", "area"),
+ xfail("nn.functional.interpolate", "bicubic"),
+ xfail("nn.functional.interpolate", "bilinear"),
+ xfail("nn.functional.interpolate", "linear"),
+ xfail("nn.functional.interpolate", "nearest"),
+ xfail("nn.functional.interpolate", "trilinear"),
+ xfail("nn.functional.layer_norm"),
+ xfail("nn.functional.leaky_relu"),
+ xfail("nn.functional.linear"),
+ xfail("nn.functional.local_response_norm"),
+ xfail("nn.functional.logsigmoid"),
+ xfail("nn.functional.margin_ranking_loss"),
+ xfail("nn.functional.max_pool1d"),
+ xfail("nn.functional.max_pool2d"),
+ xfail("nn.functional.max_pool3d"),
+ xfail("nn.functional.max_unpool1d"),
+ xfail("nn.functional.max_unpool1d", "grad"),
+ xfail("nn.functional.max_unpool2d"),
+ xfail("nn.functional.max_unpool2d", "grad"),
+ xfail("nn.functional.max_unpool3d"),
+ xfail("nn.functional.max_unpool3d", "grad"),
+ xfail("nn.functional.mish"),
+ xfail("nn.functional.mse_loss"),
+ xfail("nn.functional.multi_margin_loss"),
+ xfail("nn.functional.multilabel_margin_loss"),
+ xfail("nn.functional.multilabel_soft_margin_loss"),
+ xfail("nn.functional.nll_loss"),
+ xfail("nn.functional.normalize"),
+ xfail("nn.functional.pad", "circular"),
+ xfail("nn.functional.pad", "constant"),
+ xfail("nn.functional.pad", "reflect"),
+ xfail("nn.functional.pad", "replicate"),
+ xfail("nn.functional.pairwise_distance"),
+ xfail("nn.functional.pdist"),
+ xfail("nn.functional.pixel_shuffle"),
+ xfail("nn.functional.pixel_unshuffle"),
+ xfail("nn.functional.poisson_nll_loss"),
+ xfail("nn.functional.prelu"),
+ xfail("nn.functional.relu6"),
+ xfail("nn.functional.rrelu"),
+ xfail("nn.functional.selu"),
+ xfail("nn.functional.silu"),
+ xfail("nn.functional.smooth_l1_loss"),
+ xfail("nn.functional.soft_margin_loss"),
+ xfail("nn.functional.softplus"),
+ xfail("nn.functional.softshrink"),
+ xfail("nn.functional.threshold"),
+ xfail("nn.functional.triplet_margin_loss"),
+ xfail("nn.functional.triplet_margin_with_distance_loss"),
+ xfail("nn.functional.unfold"),
+ xfail("nn.functional.upsample_bilinear"),
+ xfail("nn.functional.upsample_nearest"),
+ xfail("nonzero"),
+ xfail("norm"),
+ xfail("norm", "fro"),
+ xfail("norm", "inf"),
+ xfail("norm", "nuc"),
+ xfail("normal"),
+ xfail("normal", "number_mean"),
+ xfail("ormqr"),
+ xfail("ones"),
+ xfail("pca_lowrank"),
+ xfail("pinverse"),
+ xfail("polar"),
+ xfail("put"),
+ xfail("qr"),
+ xfail("quantile"),
+ xfail("rad2deg"),
+ xfail("rand_like"),
+ xfail("randint_like"),
+ xfail("randint"),
+ xfail("randn"),
+ xfail("randn_like"),
+ xfail("renorm"),
+ xfail("repeat_interleave"),
+ xfail("resize_"),
+ xfail("resize_as_"),
+ xfail("roll"),
+ xfail("rot90"),
+ xfail("rsub"),
+ xfail("scalar_tensor"),
+ xfail("scatter_add"),
+ xfail("scatter"),
+ xfail("scatter_reduce", "amax"),
+ xfail("scatter_reduce", "amin"),
+ xfail("scatter_reduce", "mean"),
+ xfail("scatter_reduce", "prod"),
+ xfail("scatter_reduce", "sum"),
+ xfail("searchsorted"),
+ xfail("select"),
+ xfail("select_scatter"),
+ xfail("signbit"),
+ xfail("sort"),
+ xfail("sparse.sampled_addmm"),
+ xfail("special.airy_ai"),
+ xfail("special.bessel_j0"),
+ xfail("special.bessel_j1"),
+ xfail("special.bessel_y0"),
+ xfail("special.bessel_y1"),
+ xfail("special.chebyshev_polynomial_t"),
+ xfail("special.chebyshev_polynomial_u"),
+ xfail("special.entr"),
+ xfail("special.erfcx"),
+ xfail("special.hermite_polynomial_h"),
+ xfail("special.hermite_polynomial_he"),
+ xfail("special.i0e"),
+ xfail("special.i1"),
+ xfail("special.i1e"),
+ xfail("special.laguerre_polynomial_l"),
+ xfail("special.log_ndtr"),
+ xfail("special.modified_bessel_i0"),
+ xfail("special.modified_bessel_i1"),
+ xfail("special.modified_bessel_k0"),
+ xfail("special.modified_bessel_k1"),
+ xfail("special.ndtri"),
+ xfail("special.scaled_modified_bessel_k0"),
+ xfail("special.scaled_modified_bessel_k1"),
+ xfail("special.spherical_bessel_j0"),
+ xfail("special.xlog1py"),
+ xfail("special.zeta"),
+ xfail("split"),
+ xfail("split", "list_args"),
+ xfail("split_with_sizes"),
+ xfail("signal.windows.cosine"),
+ xfail("signal.windows.exponential"),
+ xfail("signal.windows.gaussian"),
+ xfail("signal.windows.kaiser"),
+ xfail("squeeze"),
+ xfail("stack"),
+ xfail("std"),
+ xfail("std_mean"),
+ xfail("stft"),
+ xfail("svd"),
+ xfail("svd_lowrank"),
+ xfail("symeig"),
+ xfail("t"),
+ xfail("take_along_dim"),
+ xfail("take"),
+ xfail("tensor_split"),
+ xfail("to_sparse"),
+ xfail("topk"),
+ xfail("trace"),
+ xfail("trapezoid"),
+ xfail("trapz"),
+ xfail("triangular_solve"),
+ xfail("tril"),
+ xfail("triu"),
+ xfail("unbind"),
+ xfail("unfold"),
+ xfail("unfold_copy"),
+ xfail("uniform"),
+ xfail("unflatten"),
+ xfail("unique_consecutive"),
+ xfail("unique"),
+ xfail("var_mean"),
+ xfail("vdot"),
+ xfail("view_as_complex"),
+ xfail("vstack"),
+ xfail("zeros"),
+ # ops below might fail even without dtensor involved, because we rescale
+ # the op db common test size factors (i.e. L, M, S), which can make the
+ # original functions' input generation go wrong; we skip them for now
+ # but should enable them later.
+ # TODO: need to clean this list and remove all cases
+ skip("argwhere"),
+ skip("cumprod"),
+ skip("__rmatmul__"),
+ skip("meshgrid", "list_of_tensors"),
+ skip("meshgrid", "variadic_tensors"),
+ skip("nn.functional._scaled_dot_product_attention"),
+ skip("nn.functional.softmin"),
+ skip("nn.functional.embedding"),
+ skip("nn.functional.embedding_bag"),
+ skip("nn.functional.feature_alpha_dropout", "with_train"),
+ skip("nn.functional.feature_alpha_dropout", "without_train"),
+ skip("nn.functional.hinge_embedding_loss"),
+ skip("nn.functional.cosine_embedding_loss"),
+ skip("fft.hfft"),
+ skip("fft.hfft2"),
+ skip("fft.hfft2"),
+ skip("fft.hfftn"),
+ skip("fft.ifftn"),
+ skip("fft.irfft"),
+ skip("istft"),
+ skip("isclose"),
+ skip("isreal"),
+ skip("matmul"),
+ skip("masked.mean"),
+ skip("masked.var"),
+ skip("masked.std"),
+ skip("masked.normalize"),
+ skip("prod"),
+ skip("segment_reduce", "lengths"),
+ skip("segment_reduce", "offsets"),
+}
+
+
+# A list of ops that currently fail the backward (BW) pass
+skip_bw = [
+ None, # corresponds to the transpose ops 'H' and 'T'
+ "torch.bucketize",
+ "torch.conj_physical",
+ "torch.eq",
+ "torch.isfinite",
+ "torch.isnan",
+]
+
+
+def run_dtensor_crossref(test_case, func, args, kwargs):
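+ # run the op with plain tensors to get reference results, then re-run it
+ # for every sharding combination DTensorConverter yields and compare the
+ # replicated dtensor results against the reference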
+ to_dtensor = DTensorConverter(test_case.mesh, args, kwargs)
+
+ # TODO: also handle cases where func raises an exception
+ rs = func(*args, **kwargs)
+
+ def to_replicate(e: object) -> object:
+ return (
+ e.redistribute(test_case.mesh, test_case.mesh.ndim * [Replicate()])
+ if isinstance(e, DTensor)
+ else e
+ )
+
+ try:
+ # Suppress warnings, this doesn't matter for test_meta.py
+ # but it does matter if you want to use this decorator
+ # for cross-ref testing, as some tests may be looking at
+ # errors
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ # for every combination of sharding choices, we test if it works
+ for dtensor_args, dtensor_kwargs in to_dtensor:
+ # Only attempt if we managed to convert all tensors to DTensor
+ # (if any of them failed, we're in a mixed tensor situation and
+ # this is not allowed in DTensor)
+ if to_dtensor.successful():
+ # Handle special cases first if there are any
+ dtensor_rs = func(*dtensor_args, **dtensor_kwargs)
+
+ # we need to skip tests containing tensors of zero elements for now.
+ # see issue: https://github.com/pytorch/tau/issues/470
+ # TODO: remove this once the issue above is fixed.
+ flat_args, _ = tree_flatten(dtensor_rs)
+ if any(
+ isinstance(e, torch.Tensor) and e.numel() == 0
+ for e in flat_args
+ ):
+ continue
+
+ # redistribute/all_gather the results to compare with normal output
+ dtensor_rs = tree_map(to_replicate, dtensor_rs)
+ try:
+ if resolve_name(func) not in skip_bw:
+ if isinstance(dtensor_rs, DTensor):
+ dtensor_rs.to_local().sum().backward()
+ elif isinstance(dtensor_rs, tuple):
+ dtensor_rs[0].to_local().sum().backward()
+
+ except Exception as e:
+ # TODO(anj): Remove this guard exception after gaining more confidence.
+ if torch.distributed.get_rank() == 0:
+ print(
+ f"failed to run BW: {resolve_name(func)}, {func}, {str(e)})"
+ )
+ assert_ref_dtensor_equal(test_case, dtensor_rs, rs)
+ else:
+ raise RuntimeError(
+ f"failed to convert args to DTensor; "
+ f"originally (*{args}, **{kwargs})"
+ )
+ except Exception as e:
+ raise RuntimeError(
+ f"failed to run: {resolve_name(func)}, with (*{args}, **{kwargs})"
+ ) from e
+
+ return rs
+
+
+def check_dtensor_func(test_case, test_func, opinfo, dry_run=False):
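+ # with dry_run=True, failures are not raised; rank 0 prints them as
+ # ready-made xfail(...) entries to help regenerate the dtensor_fails list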
+ try:
+ test_func()
+ except Exception:
+ test_case.destroy_pg()
+ if not dry_run:
+ raise
+ if dist.get_rank() == 0:
+ if opinfo.variant_test_name:
+ print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
+ else:
+ print(f"xfail('{opinfo.name}'),")
+ else:
+ test_case.destroy_pg()
+
+
+class TestDTensorOps(DTensorTestBase):
+ @property
+ def world_size(self) -> int:
+ return 4
+
+ # only allow float dtype for now; we can relax this constraint
+ # when necessary later (i.e. when adding quantization support).
+ @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
+ @suppress_warnings
+ @ops(dtensor_lagging_op_db, allowed_dtypes=(torch.float,))
+ @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
+ def test_dtensor_op_db(self, dtype, op):
+ pg_backend = "nccl" if DEVICE_TYPE == "cuda" else "gloo"
+ if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size:
+ sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+ self.init_pg(backend=pg_backend)
+ self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))
+
+ # test each op with dist tensor inputs and normal inputs
+ def test():
+ samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True)
+ for sample_input in samples:
+ args = [sample_input.input] + list(sample_input.args)
+ kwargs = sample_input.kwargs
+
+ run_dtensor_crossref(self, op.op, args, kwargs)
+ # we need to figure out a way to test the out variant; out variant
+ # testing is tricky, as we need to pre-allocate the dtensor out, and
+ # some ops rely on sharding placements being known ahead of time (i.e. mm.out)
+ # if isinstance(expected, torch.Tensor) and op.supports_out:
+ # func(*args, **kwargs, out=expected)
+
+ check_dtensor_func(self, test, op)
+
+
+# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
+instantiate_device_type_tests(
+ TestDTensorOps, globals(), only_for=(DEVICE_TYPE,)
+)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py
new file mode 100644
index 0000000..1ba3f6d
--- /dev/null
+++ b/test/distributed/_tensor/test_tensor_ops.py
@@ -0,0 +1,365 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorConverter,
+ DTensorTestBase,
+ with_comms,
+)
+from torch.distributed._tensor import distribute_tensor, DeviceMesh, DTensor
+from torch.distributed._tensor.placement_types import Shard, Replicate, _Partial
+
+
+class DistTensorOpsTest(DTensorTestBase):
+ @with_comms
+ def test_aten_contiguous(self):
+ # this op is not covered by dtensor_ops
+ mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ self._test_op(
+ mesh,
+ lambda x: torch.ops.aten.contiguous(x),
+ torch.randn(16, 32),
+ )
+
+ @with_comms
+ def test_detach(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ tensor_to_detach = torch.randn(12, 8, requires_grad=True)
+ mat = distribute_tensor(tensor_to_detach, device_mesh, shard_spec)
+ detached_mat = mat.detach()
+ self.assertFalse(detached_mat is mat)
+
+ @with_comms
+ def test_clone(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ specs = [[Replicate()], [Shard(0)]]
+ tensor_to_clone = torch.randn(12, 8, requires_grad=True)
+ for spec in specs:
+ mat = distribute_tensor(tensor_to_clone, device_mesh, spec)
+ cloned_mat = mat.clone()
+ self.assertFalse(cloned_mat is mat)
+ self.assertEqual(cloned_mat.to_local(), mat.to_local())
+
+ @with_comms
+ def test_contiguous(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ tensor = torch.rand(3, 5, 6, requires_grad=True)
+ sharding = [Shard(0)]
+ dist_tensor = DTensor.from_local(tensor, device_mesh, sharding)
+ self.assertTrue(dist_tensor.is_contiguous())
+ # shard on dim 0 should not change stride (30, 6, 1)
+ self.assertEqual(dist_tensor.stride(), tensor.stride())
+
+ new_dt = dist_tensor.transpose(0, 2)
+ self.assertFalse(new_dt.is_contiguous())
+ self.assertFalse(new_dt.to_local().is_contiguous())
+ # check stride
+ self.assertEqual(new_dt.stride(), (1, 6, 30))
+
+ new_dt = new_dt.contiguous()
+ self.assertTrue(new_dt.is_contiguous())
+ self.assertTrue(new_dt.to_local().is_contiguous())
+ # the original dist_tensor's stride should be unchanged
+ self.assertEqual(dist_tensor.stride(), tensor.stride())
+
+ # check backward
+ new_dt.to_local().sum().backward()
+ self.assertEqual(tensor.grad, torch.ones(3, 5, 6))
+
+ @with_comms
+ def test_inplace_op(self):
+ mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ input_tensor = torch.randn((12, 3), device=self.device_type)
+ dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)])
+ dt_to_mul = dt_to_add.clone()
+ expected_add_dt = dt_to_add.clone() + 3
+ add_res = dt_to_add.add_(3)
+ expected_mul_dt = dt_to_mul.clone() * 3
+ mul_res = dt_to_mul.mul_(3)
+ # inplace op should be the same instance before and after
+ self.assertTrue(add_res is dt_to_add)
+ self.assertEqual(add_res.to_local(), expected_add_dt.to_local())
+
+ self.assertTrue(mul_res is dt_to_mul)
+ self.assertEqual(mul_res.to_local(), expected_mul_dt.to_local())
+
+ # test inplace op where self and the other dtensor have different specs,
+ # and make sure the output spec does not change
+ shard_spec = [Shard(0)]
+ partial_spec = [_Partial()]
+ dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec)
+ partial_grad = DTensor.from_local(
+ torch.randn(12, 3), mesh, partial_spec
+ )
+ res = dt_to_inplace_add.add_(partial_grad)
+ self.assertTrue(res is dt_to_inplace_add)
+ self.assertTrue(res.placements == shard_spec)
+
+ @with_comms
+ def test_op_out_variant(self):
+ mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ input_tensor = torch.randn((12, 3), device=self.device_type)
+ sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)])
+ expected_dt = sharded_dt_input.clone() + 3
+ sharded_dt_out = sharded_dt_input.clone()
+ res = torch.add(sharded_dt_input, 3, out=sharded_dt_out)
+ # op out variant should be the same instance before and after
+ self.assertTrue(res is sharded_dt_out)
+ self.assertEqual(sharded_dt_out.to_local(), expected_dt.to_local())
+
+ # test the out variant with a differently-placed out and make sure the out spec does not change
+ replica_spec = [Replicate()]
+ replicate_out = distribute_tensor(input_tensor, mesh, replica_spec)
+ expected_dt = replicate_out.clone() + 3
+ res = torch.add(sharded_dt_input, 3, out=replicate_out)
+ self.assertTrue(res is replicate_out)
+ self.assertTrue(res.placements == replica_spec)
+ self.assertEqual(replicate_out.to_local(), expected_dt.to_local())
+
+ @with_comms
+ def test_empty_like(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ empty_like_dt = torch.empty_like(dist_tensor)
+ # empty is not deterministic, so we only check that the shard propagation worked
+ self.assertEqual((4, 8), empty_like_dt.to_local().shape)
+
+ @with_comms
+ def test_fill_inplace(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ filled_dt = torch.fill_(dist_tensor, 42.0)
+ full_expected = torch.full((4, 8), 42.0)
+ self.assertEqual(full_expected, filled_dt.to_local())
+ self.assertEqual(full_expected, dist_tensor.to_local())
+
+ @with_comms
+ def test_full_like(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ full_like_dt = torch.full_like(dist_tensor, 42.0)
+ full_expected = torch.full((4, 8), 42.0)
+ self.assertEqual(full_expected, full_like_dt.to_local())
+
+ @with_comms
+ def test_ones_like(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ ones_like_dt = torch.ones_like(dist_tensor)
+ ones_expected = torch.ones(4, 8)
+ self.assertEqual(ones_expected, ones_like_dt.to_local())
+
+ @with_comms
+ def test_ones_like_partial_sum(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [_Partial()]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ assert dist_tensor.shape == (4, 8)
+
+ ones_like_dt = torch.ones_like(dist_tensor)
+ ones_expected = torch.ones(dist_tensor.shape)
+ self.assertEqual(
+ ones_expected,
+ ones_like_dt.redistribute(device_mesh, [Replicate()]).to_local(),
+ )
+
+ @with_comms
+ def test_fill_inplace_partial_sum(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [_Partial()]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ assert dist_tensor.shape == (4, 8)
+
+ torch.fill_(dist_tensor, 42)
+ fill_expected = torch.full(
+ dist_tensor.shape, 42, dtype=input_tensor.dtype
+ )
+ self.assertEqual(
+ fill_expected,
+ dist_tensor.redistribute(device_mesh, [Replicate()]).to_local(),
+ )
+
+ @with_comms
+ def test_zeros_like_partial_sum(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [_Partial()]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ assert dist_tensor.shape == (4, 8)
+
+ zeros_like_dt = torch.zeros_like(dist_tensor)
+ zeros_expected = torch.zeros(dist_tensor.shape)
+ self.assertEqual(
+ zeros_expected,
+ zeros_like_dt.redistribute(device_mesh, [Replicate()]).to_local(),
+ )
+
+ @with_comms
+ def test_zero_inplace(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ zeros_like_dt = torch.zero_(dist_tensor)
+ zeros_expected = torch.zeros(4, 8)
+ self.assertEqual(zeros_expected, zeros_like_dt.to_local())
+ self.assertEqual(zeros_expected, dist_tensor.to_local())
+
+ @with_comms
+ def test_zeros_like(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ input_tensor = torch.randn(4, 8, requires_grad=True)
+ dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+ zeros_like_dt = torch.zeros_like(dist_tensor)
+ zeros_expected = torch.zeros(4, 8)
+ self.assertEqual(zeros_expected, zeros_like_dt.to_local())
+
+ def _test_op(self, mesh, op_call, *args, **kwargs):
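+ # compare the plain-tensor output with the replicated dtensor output for
+ # every sharding combination DTensorConverter yields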
+ out = op_call(*args, **kwargs)
+ dtc = DTensorConverter(mesh, args, kwargs)
+ for d_args, d_kwargs in dtc:
+ self.assertTrue(dtc.successful())
+ d_out = op_call(*d_args, **d_kwargs)
+ self.assertEqual(
+ d_out.redistribute(mesh, [Replicate()] * mesh.ndim).to_local(),
+ out,
+ )
+
+ @with_comms
+ def test_index(self):
+ meshes = [
+ DeviceMesh(
+ self.device_type, list(range(self.world_size))
+ ), # 1D mesh
+ # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh
+ # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh
+ ]
+ for mesh in meshes:
+ self._test_op(
+ mesh,
+ lambda x, y: x[y],
+ torch.randn(16, 32, 16),
+ torch.randint(5, (4, 8)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y: x.index_select(1, y),
+ torch.randn(16, 32, 16),
+ torch.randint(5, (4,)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y: x.index_select(0, y),
+ torch.randn(16, 32, 16),
+ torch.randint(5, (4,)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y: x[y],
+ torch.randn(16, 32, 16),
+ torch.randint(5, (12,)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y: x[:, y],
+ torch.randn(16, 32, 16),
+ torch.randint(5, (4, 8)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y: x[..., y],
+ torch.randn(16, 32, 16),
+ torch.randint(5, (4, 12)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y: x[..., y],
+ torch.randn(16, 32, 16),
+ torch.randint(5, (4, 8, 16)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[z, y],
+ torch.randn(16, 32, 16),
+ torch.randint(5, (12, 8, 12)),
+ torch.randint(2, (12, 8, 12)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[z, :, y],
+ torch.randn(16, 32, 16),
+ torch.randint(5, (12, 8, 12)),
+ torch.randint(2, (12, 8, 12)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[:, z, :, y],
+ torch.randn(16, 32, 16, 12),
+ torch.randint(5, (12, 8, 12)),
+ torch.randint(2, (12, 8, 12)),
+ )
+ # broadcast in inner dimensions
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[:, z, :, y],
+ torch.randn(16, 32, 16, 12),
+ torch.randint(5, (12, 8, 12)),
+ torch.randint(2, (12, 1, 12)),
+ )
+ # implicit (left-padded) broadcast
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[:, z, :, y],
+ torch.randn(16, 32, 16, 12),
+ torch.randint(5, (12, 8, 12)),
+ torch.randint(2, (8, 12)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[z, y, :, :],
+ torch.randn(16, 32, 16, 12),
+ torch.randint(2, (8, 12)),
+ torch.randint(5, (12, 8, 12)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[z, :, y, :],
+ torch.randn(16, 32, 16, 12),
+ torch.randint(2, (8, 12)),
+ torch.randint(5, (12, 8, 12)),
+ )
+ self._test_op(
+ mesh,
+ lambda x, y, z: x[z, :, :, y],
+ torch.randn(16, 32, 16, 12),
+ torch.randint(2, (8, 1)),
+ torch.randint(5, (12, 8, 12)),
+ )
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py b/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
new file mode 100644
index 0000000..abd0ccf
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
@@ -0,0 +1,661 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import List
+from torch.testing._internal.common_methods_invocations import op_db, OpInfo
+
+# Generated via
+# python torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py > torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
+#
+# This approach is copied from functorch:
+# People add new OpInfos to PyTorch all the time.
+# We want them to be able to add OpInfos without breaking our CI.
+# To achieve this, we keep our OpInfo library behind PyTorch's and
+# we periodically update it by regenerating this file
+_dtensor_lagging_meta = {
+ ("H", ""),
+ ("T", ""),
+ ("__getitem__", ""),
+ ("__radd__", ""),
+ ("__rand__", ""),
+ ("__rdiv__", ""),
+ ("__rmatmul__", ""),
+ ("__rmod__", ""),
+ ("__rmul__", ""),
+ ("__ror__", ""),
+ ("__rpow__", ""),
+ ("__rsub__", ""),
+ ("__rxor__", ""),
+ ("abs", ""),
+ ("acos", ""),
+ ("acosh", ""),
+ ("add", ""),
+ ("addbmm", ""),
+ ("addcdiv", ""),
+ ("addcmul", ""),
+ ("addmm", ""),
+ ("addmm", "decomposed"),
+ ("addmv", ""),
+ ("addr", ""),
+ ("all", ""),
+ ("allclose", ""),
+ ("amax", ""),
+ ("amin", ""),
+ ("aminmax", ""),
+ ("angle", ""),
+ ("any", ""),
+ ("arange", ""),
+ ("argmax", ""),
+ ("argmin", ""),
+ ("argsort", ""),
+ ("argwhere", ""),
+ ("as_strided", ""),
+ ("as_strided_scatter", ""),
+ ("asin", ""),
+ ("asinh", ""),
+ ("atan", ""),
+ ("atan2", ""),
+ ("atanh", ""),
+ ("atleast_1d", ""),
+ ("atleast_2d", ""),
+ ("atleast_3d", ""),
+ ("baddbmm", ""),
+ ("bernoulli", ""),
+ ("bfloat16", ""),
+ ("bincount", ""),
+ ("bitwise_and", ""),
+ ("bitwise_left_shift", ""),
+ ("bitwise_not", ""),
+ ("bitwise_or", ""),
+ ("bitwise_right_shift", ""),
+ ("bitwise_xor", ""),
+ ("block_diag", ""),
+ ("bmm", ""),
+ ("bool", ""),
+ ("broadcast_shapes", ""),
+ ("broadcast_tensors", ""),
+ ("broadcast_to", ""),
+ ("bucketize", ""),
+ ("byte", ""),
+ ("cartesian_prod", ""),
+ ("cat", ""),
+ ("cdist", ""),
+ ("cdouble", ""),
+ ("ceil", ""),
+ ("cfloat", ""),
+ ("chalf", ""),
+ ("char", ""),
+ ("cholesky", ""),
+ ("cholesky_inverse", ""),
+ ("cholesky_solve", ""),
+ ("chunk", ""),
+ ("clamp", ""),
+ ("clamp_max", ""),
+ ("clamp_min", ""),
+ ("clone", ""),
+ ("column_stack", ""),
+ ("combinations", ""),
+ ("complex", ""),
+ ("conj", ""),
+ ("conj_physical", ""),
+ ("constant_pad_nd", ""),
+ ("contiguous", ""),
+ ("copysign", ""),
+ ("corrcoef", ""),
+ ("cos", ""),
+ ("cosh", ""),
+ ("count_nonzero", ""),
+ ("cov", ""),
+ ("cross", ""),
+ ("cummax", ""),
+ ("cummin", ""),
+ ("cumprod", ""),
+ ("cumsum", ""),
+ ("cumulative_trapezoid", ""),
+ ("deg2rad", ""),
+ ("diag", ""),
+ ("diag_embed", ""),
+ ("diagflat", ""),
+ ("diagonal", ""),
+ ("diagonal_copy", ""),
+ ("diagonal_scatter", ""),
+ ("diff", ""),
+ ("digamma", ""),
+ ("dist", ""),
+ ("div", "floor_rounding"),
+ ("div", "no_rounding_mode"),
+ ("div", "trunc_rounding"),
+ ("dot", ""),
+ ("double", ""),
+ ("dsplit", ""),
+ ("dstack", ""),
+ ("einsum", ""),
+ ("empty", ""),
+ ("empty_like", ""),
+ ("eq", ""),
+ ("equal", ""),
+ ("erf", ""),
+ ("erfc", ""),
+ ("erfinv", ""),
+ ("exp", ""),
+ ("exp2", ""),
+ ("expand", ""),
+ ("expand_as", ""),
+ ("expm1", ""),
+ ("eye", ""),
+ ("fft.fft", ""),
+ ("fft.fft2", ""),
+ ("fft.fftn", ""),
+ ("fft.fftshift", ""),
+ ("fft.hfft", ""),
+ ("fft.hfft2", ""),
+ ("fft.hfftn", ""),
+ ("fft.ifft", ""),
+ ("fft.ifft2", ""),
+ ("fft.ifftn", ""),
+ ("fft.ifftshift", ""),
+ ("fft.ihfft", ""),
+ ("fft.ihfft2", ""),
+ ("fft.ihfftn", ""),
+ ("fft.irfft", ""),
+ ("fft.irfft2", ""),
+ ("fft.irfftn", ""),
+ ("fft.rfft", ""),
+ ("fft.rfft2", ""),
+ ("fft.rfftn", ""),
+ ("fill", ""),
+ ("flatten", ""),
+ ("flip", ""),
+ ("fliplr", ""),
+ ("flipud", ""),
+ ("float", ""),
+ ("float_power", ""),
+ ("floor", ""),
+ ("floor_divide", ""),
+ ("fmax", ""),
+ ("fmin", ""),
+ ("fmod", ""),
+ ("frac", ""),
+ ("frexp", ""),
+ ("full", ""),
+ ("full_like", ""),
+ ("gather", ""),
+ ("gcd", ""),
+ ("ge", ""),
+ ("geqrf", ""),
+ ("gradient", ""),
+ ("gt", ""),
+ ("half", ""),
+ ("heaviside", ""),
+ ("histc", ""),
+ ("histogram", ""),
+ ("histogramdd", ""),
+ ("hsplit", ""),
+ ("hstack", ""),
+ ("hypot", ""),
+ ("i0", ""),
+ ("igamma", ""),
+ ("igammac", ""),
+ ("imag", ""),
+ ("index_add", ""),
+ ("index_copy", ""),
+ ("index_fill", ""),
+ ("index_put", ""),
+ ("index_reduce", ""),
+ ("index_select", ""),
+ ("inner", ""),
+ ("int", ""),
+ ("isclose", ""),
+ ("isfinite", ""),
+ ("isin", ""),
+ ("isinf", ""),
+ ("isnan", ""),
+ ("isneginf", ""),
+ ("isposinf", ""),
+ ("isreal", ""),
+ ("istft", ""),
+ ("jiterator_2inputs_2outputs", ""),
+ ("jiterator_4inputs_with_extra_args", ""),
+ ("jiterator_binary", ""),
+ ("jiterator_binary_return_by_ref", ""),
+ ("jiterator_unary", ""),
+ ("kron", ""),
+ ("kthvalue", ""),
+ ("lcm", ""),
+ ("ldexp", ""),
+ ("le", ""),
+ ("lerp", ""),
+ ("lgamma", ""),
+ ("linalg.cholesky", ""),
+ ("linalg.cholesky_ex", ""),
+ ("linalg.cond", ""),
+ ("linalg.cross", ""),
+ ("linalg.det", ""),
+ ("linalg.det", "singular"),
+ ("linalg.eig", ""),
+ ("linalg.eigh", ""),
+ ("linalg.eigvals", ""),
+ ("linalg.eigvalsh", ""),
+ ("linalg.householder_product", ""),
+ ("linalg.inv", ""),
+ ("linalg.inv_ex", ""),
+ ("linalg.ldl_factor", ""),
+ ("linalg.ldl_factor_ex", ""),
+ ("linalg.ldl_solve", ""),
+ ("linalg.lstsq", ""),
+ ("linalg.lstsq", "grad_oriented"),
+ ("linalg.lu", ""),
+ ("linalg.lu_factor", ""),
+ ("linalg.lu_factor_ex", ""),
+ ("linalg.lu_solve", ""),
+ ("linalg.matrix_norm", ""),
+ ("linalg.matrix_power", ""),
+ ("linalg.matrix_rank", ""),
+ ("linalg.matrix_rank", "hermitian"),
+ ("linalg.multi_dot", ""),
+ ("linalg.norm", ""),
+ ("linalg.norm", "subgradients_at_zero"),
+ ("linalg.pinv", ""),
+ ("linalg.pinv", "hermitian"),
+ ("linalg.pinv", "singular"),
+ ("linalg.qr", ""),
+ ("linalg.slogdet", ""),
+ ("linalg.solve", ""),
+ ("linalg.solve_ex", ""),
+ ("linalg.solve_triangular", ""),
+ ("linalg.svd", ""),
+ ("linalg.svdvals", ""),
+ ("linalg.tensorinv", ""),
+ ("linalg.tensorsolve", ""),
+ ("linalg.vander", ""),
+ ("linalg.vecdot", ""),
+ ("linalg.vector_norm", ""),
+ ("linspace", ""),
+ ("log", ""),
+ ("log10", ""),
+ ("log1p", ""),
+ ("log2", ""),
+ ("log_softmax", ""),
+ ("log_softmax", "with_dtype"),
+ ("logaddexp", ""),
+ ("logaddexp2", ""),
+ ("logcumsumexp", ""),
+ ("logdet", ""),
+ ("logical_and", ""),
+ ("logical_not", ""),
+ ("logical_or", ""),
+ ("logical_xor", ""),
+ ("logit", ""),
+ ("logspace", ""),
+ ("logsumexp", ""),
+ ("long", ""),
+ ("lt", ""),
+ ("lu", ""),
+ ("lu_solve", ""),
+ ("lu_unpack", ""),
+ ("mH", ""),
+ ("mT", ""),
+ ("masked.amax", ""),
+ ("masked.amin", ""),
+ ("masked.argmax", ""),
+ ("masked.argmin", ""),
+ ("masked.cumprod", ""),
+ ("masked.cumsum", ""),
+ ("masked.log_softmax", ""),
+ ("masked.logaddexp", ""),
+ ("masked.logsumexp", ""),
+ ("masked.mean", ""),
+ ("masked.median", ""),
+ ("masked.norm", ""),
+ ("masked.normalize", ""),
+ ("masked.prod", ""),
+ ("masked.softmax", ""),
+ ("masked.softmin", ""),
+ ("masked.std", ""),
+ ("masked.sum", ""),
+ ("masked.var", ""),
+ ("masked_fill", ""),
+ ("masked_scatter", ""),
+ ("masked_select", ""),
+ ("matmul", ""),
+ ("matrix_exp", ""),
+ ("max", "binary"),
+ ("max", "reduction_no_dim"),
+ ("max", "reduction_with_dim"),
+ ("max_pool2d_with_indices_backward", ""),
+ ("maximum", ""),
+ ("mean", ""),
+ ("median", ""),
+ ("meshgrid", "list_of_tensors"),
+ ("meshgrid", "variadic_tensors"),
+ ("min", "binary"),
+ ("min", "reduction_no_dim"),
+ ("min", "reduction_with_dim"),
+ ("minimum", ""),
+ ("mm", ""),
+ ("mode", ""),
+ ("movedim", ""),
+ ("msort", ""),
+ ("mul", ""),
+ ("multinomial", ""),
+ ("mv", ""),
+ ("mvlgamma", "mvlgamma_p_1"),
+ ("mvlgamma", "mvlgamma_p_3"),
+ ("mvlgamma", "mvlgamma_p_5"),
+ ("nan_to_num", ""),
+ ("nanmean", ""),
+ ("nanmedian", ""),
+ ("nanquantile", ""),
+ ("nansum", ""),
+ ("narrow", ""),
+ ("narrow_copy", ""),
+ ("native_batch_norm", ""),
+ ("native_layer_norm", ""),
+ ("ne", ""),
+ ("neg", ""),
+ ("new_empty", ""),
+ ("new_empty_strided", ""),
+ ("new_full", ""),
+ ("new_ones", ""),
+ ("new_zeros", ""),
+ ("nextafter", ""),
+ ("nn.functional._scaled_dot_product_attention", ""),
+ ("nn.functional.adaptive_avg_pool1d", ""),
+ ("nn.functional.adaptive_avg_pool2d", ""),
+ ("nn.functional.adaptive_avg_pool3d", ""),
+ ("nn.functional.adaptive_max_pool1d", ""),
+ ("nn.functional.adaptive_max_pool2d", ""),
+ ("nn.functional.adaptive_max_pool3d", ""),
+ ("nn.functional.alpha_dropout", ""),
+ ("nn.functional.avg_pool1d", ""),
+ ("nn.functional.avg_pool2d", ""),
+ ("nn.functional.avg_pool3d", ""),
+ ("nn.functional.batch_norm", ""),
+ ("nn.functional.batch_norm", "without_cudnn"),
+ ("nn.functional.bilinear", ""),
+ ("nn.functional.binary_cross_entropy", ""),
+ ("nn.functional.binary_cross_entropy_with_logits", ""),
+ ("nn.functional.celu", ""),
+ ("nn.functional.conv1d", ""),
+ ("nn.functional.conv2d", ""),
+ ("nn.functional.conv_transpose1d", ""),
+ ("nn.functional.conv_transpose2d", ""),
+ ("nn.functional.conv_transpose3d", ""),
+ ("nn.functional.cosine_embedding_loss", ""),
+ ("nn.functional.cosine_similarity", ""),
+ ("nn.functional.cross_entropy", ""),
+ ("nn.functional.ctc_loss", ""),
+ ("nn.functional.dropout", ""),
+ ("nn.functional.dropout2d", ""),
+ ("nn.functional.dropout3d", ""),
+ ("nn.functional.elu", ""),
+ ("nn.functional.embedding", ""),
+ ("nn.functional.embedding_bag", ""),
+ ("nn.functional.feature_alpha_dropout", "with_train"),
+ ("nn.functional.feature_alpha_dropout", "without_train"),
+ ("nn.functional.fractional_max_pool2d", ""),
+ ("nn.functional.fractional_max_pool3d", ""),
+ ("nn.functional.gaussian_nll_loss", ""),
+ ("nn.functional.gelu", ""),
+ ("nn.functional.glu", ""),
+ ("nn.functional.grid_sample", ""),
+ ("nn.functional.group_norm", ""),
+ ("nn.functional.hardshrink", ""),
+ ("nn.functional.hardsigmoid", ""),
+ ("nn.functional.hardswish", ""),
+ ("nn.functional.hardtanh", ""),
+ ("nn.functional.hinge_embedding_loss", ""),
+ ("nn.functional.huber_loss", ""),
+ ("nn.functional.instance_norm", ""),
+ ("nn.functional.interpolate", "area"),
+ ("nn.functional.interpolate", "bicubic"),
+ ("nn.functional.interpolate", "bilinear"),
+ ("nn.functional.interpolate", "linear"),
+ ("nn.functional.interpolate", "nearest"),
+ ("nn.functional.interpolate", "trilinear"),
+ ("nn.functional.kl_div", ""),
+ ("nn.functional.l1_loss", ""),
+ ("nn.functional.layer_norm", ""),
+ ("nn.functional.leaky_relu", ""),
+ ("nn.functional.linear", ""),
+ ("nn.functional.local_response_norm", ""),
+ ("nn.functional.logsigmoid", ""),
+ ("nn.functional.margin_ranking_loss", ""),
+ ("nn.functional.max_pool1d", ""),
+ ("nn.functional.max_pool2d", ""),
+ ("nn.functional.max_pool3d", ""),
+ ("nn.functional.max_unpool1d", ""),
+ ("nn.functional.max_unpool1d", "grad"),
+ ("nn.functional.max_unpool2d", ""),
+ ("nn.functional.max_unpool2d", "grad"),
+ ("nn.functional.max_unpool3d", ""),
+ ("nn.functional.max_unpool3d", "grad"),
+ ("nn.functional.mish", ""),
+ ("nn.functional.mse_loss", ""),
+ ("nn.functional.multi_margin_loss", ""),
+ ("nn.functional.multilabel_margin_loss", ""),
+ ("nn.functional.multilabel_soft_margin_loss", ""),
+ ("nn.functional.nll_loss", ""),
+ ("nn.functional.normalize", ""),
+ ("nn.functional.one_hot", ""),
+ ("nn.functional.pad", "circular"),
+ ("nn.functional.pad", "constant"),
+ ("nn.functional.pad", "reflect"),
+ ("nn.functional.pad", "replicate"),
+ ("nn.functional.pairwise_distance", ""),
+ ("nn.functional.pdist", ""),
+ ("nn.functional.pixel_shuffle", ""),
+ ("nn.functional.pixel_unshuffle", ""),
+ ("nn.functional.poisson_nll_loss", ""),
+ ("nn.functional.prelu", ""),
+ ("nn.functional.relu", ""),
+ ("nn.functional.relu6", ""),
+ ("nn.functional.rrelu", ""),
+ ("nn.functional.selu", ""),
+ ("nn.functional.silu", ""),
+ ("nn.functional.silu", "complex"),
+ ("nn.functional.smooth_l1_loss", ""),
+ ("nn.functional.soft_margin_loss", ""),
+ ("nn.functional.softmin", ""),
+ ("nn.functional.softmin", "with_dtype"),
+ ("nn.functional.softplus", ""),
+ ("nn.functional.softshrink", ""),
+ ("nn.functional.softsign", ""),
+ ("nn.functional.tanhshrink", ""),
+ ("nn.functional.threshold", ""),
+ ("nn.functional.triplet_margin_loss", ""),
+ ("nn.functional.triplet_margin_with_distance_loss", ""),
+ ("nn.functional.unfold", ""),
+ ("nn.functional.upsample_bilinear", ""),
+ ("nn.functional.upsample_nearest", ""),
+ ("nonzero", ""),
+ ("norm", ""),
+ ("norm", "fro"),
+ ("norm", "inf"),
+ ("norm", "nuc"),
+ ("normal", ""),
+ ("normal", "number_mean"),
+ ("ones", ""),
+ ("ones_like", ""),
+ ("ormqr", ""),
+ ("outer", ""),
+ ("pca_lowrank", ""),
+ ("permute", ""),
+ ("pinverse", ""),
+ ("polar", ""),
+ ("polygamma", "polygamma_n_0"),
+ ("polygamma", "polygamma_n_1"),
+ ("polygamma", "polygamma_n_2"),
+ ("polygamma", "polygamma_n_3"),
+ ("polygamma", "polygamma_n_4"),
+ ("positive", ""),
+ ("pow", ""),
+ ("prod", ""),
+ ("put", ""),
+ ("qr", ""),
+ ("quantile", ""),
+ ("rad2deg", ""),
+ ("rand_like", ""),
+ ("randint", ""),
+ ("randint_like", ""),
+ ("randn", ""),
+ ("randn_like", ""),
+ ("ravel", ""),
+ ("real", ""),
+ ("reciprocal", ""),
+ ("remainder", ""),
+ ("renorm", ""),
+ ("repeat", ""),
+ ("repeat_interleave", ""),
+ ("reshape", ""),
+ ("reshape_as", ""),
+ ("resize_", ""),
+ ("resize_as_", ""),
+ ("resolve_conj", ""),
+ ("resolve_neg", ""),
+ ("roll", ""),
+ ("rot90", ""),
+ ("round", ""),
+ ("round", "decimals_0"),
+ ("round", "decimals_3"),
+ ("round", "decimals_neg_3"),
+ ("rsqrt", ""),
+ ("rsub", ""),
+ ("scalar_tensor", ""),
+ ("scatter", ""),
+ ("scatter_add", ""),
+ ("scatter_reduce", "amax"),
+ ("scatter_reduce", "amin"),
+ ("scatter_reduce", "mean"),
+ ("scatter_reduce", "prod"),
+ ("scatter_reduce", "sum"),
+ ("searchsorted", ""),
+ ("segment_reduce", "lengths"),
+ ("segment_reduce", "offsets"),
+ ("select", ""),
+ ("select_scatter", ""),
+ ("sgn", ""),
+ ("short", ""),
+ ("sigmoid", ""),
+ ("sign", ""),
+ ("signal.windows.cosine", ""),
+ ("signal.windows.exponential", ""),
+ ("signal.windows.gaussian", ""),
+ ("signal.windows.kaiser", ""),
+ ("signbit", ""),
+ ("sin", ""),
+ ("sinc", ""),
+ ("sinh", ""),
+ ("slice", ""),
+ ("slice_scatter", ""),
+ ("softmax", ""),
+ ("softmax", "with_dtype"),
+ ("sort", ""),
+ ("sparse.sampled_addmm", ""),
+ ("special.airy_ai", ""),
+ ("special.bessel_j0", ""),
+ ("special.bessel_j1", ""),
+ ("special.bessel_y0", ""),
+ ("special.bessel_y1", ""),
+ ("special.chebyshev_polynomial_t", ""),
+ ("special.chebyshev_polynomial_u", ""),
+ ("special.chebyshev_polynomial_v", ""),
+ ("special.chebyshev_polynomial_w", ""),
+ ("special.entr", ""),
+ ("special.erfcx", ""),
+ ("special.hermite_polynomial_h", ""),
+ ("special.hermite_polynomial_he", ""),
+ ("special.i0e", ""),
+ ("special.i1", ""),
+ ("special.i1e", ""),
+ ("special.laguerre_polynomial_l", ""),
+ ("special.legendre_polynomial_p", ""),
+ ("special.log_ndtr", ""),
+ ("special.modified_bessel_i0", ""),
+ ("special.modified_bessel_i1", ""),
+ ("special.modified_bessel_k0", ""),
+ ("special.modified_bessel_k1", ""),
+ ("special.ndtr", ""),
+ ("special.ndtri", ""),
+ ("special.polygamma", "special_polygamma_n_0"),
+ ("special.scaled_modified_bessel_k0", ""),
+ ("special.scaled_modified_bessel_k1", ""),
+ ("special.shifted_chebyshev_polynomial_t", ""),
+ ("special.shifted_chebyshev_polynomial_u", ""),
+ ("special.shifted_chebyshev_polynomial_v", ""),
+ ("special.shifted_chebyshev_polynomial_w", ""),
+ ("special.spherical_bessel_j0", ""),
+ ("special.xlog1py", ""),
+ ("special.zeta", ""),
+ ("split", ""),
+ ("split", "list_args"),
+ ("split_with_sizes", ""),
+ ("sqrt", ""),
+ ("square", ""),
+ ("squeeze", ""),
+ ("stack", ""),
+ ("std", ""),
+ ("std_mean", ""),
+ ("stft", ""),
+ ("sub", ""),
+ ("sum", ""),
+ ("sum_to_size", ""),
+ ("svd", ""),
+ ("svd_lowrank", ""),
+ ("symeig", ""),
+ ("t", ""),
+ ("take", ""),
+ ("take_along_dim", ""),
+ ("tan", ""),
+ ("tanh", ""),
+ ("tensor_split", ""),
+ ("tensordot", ""),
+ ("tile", ""),
+ ("to", ""),
+ ("to_sparse", ""),
+ ("topk", ""),
+ ("trace", ""),
+ ("transpose", ""),
+ ("trapezoid", ""),
+ ("trapz", ""),
+ ("triangular_solve", ""),
+ ("tril", ""),
+ ("tril_indices", ""),
+ ("triu", ""),
+ ("triu_indices", ""),
+ ("true_divide", ""),
+ ("trunc", ""),
+ ("unbind", ""),
+ ("unflatten", ""),
+ ("unfold", ""),
+ ("unfold_copy", ""),
+ ("uniform", ""),
+ ("unique", ""),
+ ("unique_consecutive", ""),
+ ("unsqueeze", ""),
+ ("var", ""),
+ ("var_mean", ""),
+ ("vdot", ""),
+ ("view", ""),
+ ("view_as", ""),
+ ("view_as_complex", ""),
+ ("view_as_real", ""),
+ ("vsplit", ""),
+ ("vstack", ""),
+ ("where", ""),
+ ("xlogy", ""),
+ ("zero_", ""),
+ ("zeros", ""),
+ ("zeros_like", ""),
+}
+
+
+def in_dtensor_lagging_op_db(opinfo: OpInfo) -> bool:
+ return (opinfo.name, opinfo.variant_test_name) in _dtensor_lagging_meta
+
+
+dtensor_lagging_op_db: List[OpInfo] = [
+ opinfo for opinfo in op_db if in_dtensor_lagging_op_db(opinfo)
+]
diff --git a/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py b/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py
new file mode 100644
index 0000000..f684f77
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py
@@ -0,0 +1,67 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple
+from torch.testing._internal.common_methods_invocations import op_db
+
+
+def num_leading_spaces(line: str) -> int:
+ result = len(line) - len(line.lstrip())
+ # lines with no leading whitespace (e.g. empty lines) should not affect the minimum
+ if result == 0:
+ return 999999
+ return result
+
+
+def deindent(code: str) -> str:
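+ # strip the common leading indentation from every line of the code block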
+ lines = code.split("\n")
+ min_leading_spaces = min(map(num_leading_spaces, lines))
+ lines = [line[min_leading_spaces:] for line in lines]
+ return "\n".join(lines)
+
+
+if __name__ == "__main__":
+ supported: List[Tuple[str, str]] = [
+ (opinfo.name, opinfo.variant_test_name) for opinfo in op_db
+ ]
+ supported = sorted(supported)
+ print(
+ deindent(
+ """\
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ from typing import List
+ from torch.testing._internal.common_methods_invocations import op_db, OpInfo
+ # Generated via
+ # python torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py > torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py
+ #
+ # This approach is copied from functorch:
+ # People add new OpInfos to PyTorch all the time.
+ # We want them to be able to add OpInfos without breaking our CI.
+ # To achieve this, we keep our OpInfo library behind PyTorch's and
+ # we periodically update it by regenerating this file"""
+ )
+ )
+
+ print("_dtensor_lagging_meta = {")
+ for name, variant in supported:
+ print(f" {(name, variant)},")
+ print("}")
+
+ print(
+ deindent(
+ """\
+ def in_dtensor_lagging_op_db(opinfo: OpInfo) -> bool:
+ return (opinfo.name, opinfo.variant_test_name) in _dtensor_lagging_meta
+
+ dtensor_lagging_op_db: List[OpInfo] = [
+ opinfo for opinfo in op_db if in_dtensor_lagging_op_db(opinfo)
+ ]"""
+ )
+ )