[dtensor] switch pointwise op tests to use DTensorOpsTestBase (#92197)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92197
Approved by: https://github.com/XilunWu
diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py
index fca2d0d..19cf9e4 100644
--- a/test/distributed/_tensor/test_pointwise_ops.py
+++ b/test/distributed/_tensor/test_pointwise_ops.py
@@ -19,12 +19,9 @@
from torch.distributed.distributed_c10d import ReduceOp
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorOpTestBase,
skip_unless_torch_gpu,
)
-from torch.testing._internal.common_distributed import (
- MultiThreadedTestCase,
- DEFAULT_WORLD_SIZE,
-)
def no_op():
return None
@@ -74,18 +71,7 @@
return pytree.tree_map(f, [val])[0]
-class DistElementwiseOpsTest(MultiThreadedTestCase):
- @property
- def world_size(self) -> int:
- return DEFAULT_WORLD_SIZE
-
- @property
- def device_type(self) -> str:
- return "cuda" if torch.cuda.is_available() else "cpu"
-
- def build_device_mesh(self):
- return DeviceMesh(self.device_type, list(range(self.world_size)))
-
+class DistElementwiseOpsTest(DTensorOpTestBase):
def _compare_pairwise_ops(
self,
*,
diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
index 8ad6b0e..21a23bc 100644
--- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py
+++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -180,6 +180,13 @@
def world_size(self) -> int:
return NUM_DEVICES
+ @property
+ def device_type(self) -> str:
+ return "cuda" if torch.cuda.is_available() else "cpu"
+
+ def build_device_mesh(self):
+ return DeviceMesh(self.device_type, list(range(self.world_size)))
+
def setUp(self) -> None:
super().setUp()
self._spawn_threads()