[dtensor] switch pointwise op tests to use DTensorOpsTestBase (#92197)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92197
Approved by: https://github.com/XilunWu
diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py
index fca2d0d..19cf9e4 100644
--- a/test/distributed/_tensor/test_pointwise_ops.py
+++ b/test/distributed/_tensor/test_pointwise_ops.py
@@ -19,12 +19,9 @@
 from torch.distributed.distributed_c10d import ReduceOp
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorOpTestBase,
     skip_unless_torch_gpu,
 )
-from torch.testing._internal.common_distributed import (
-    MultiThreadedTestCase,
-    DEFAULT_WORLD_SIZE,
-)
 
 def no_op():
     return None
@@ -74,18 +71,7 @@
     return pytree.tree_map(f, [val])[0]
 
 
-class DistElementwiseOpsTest(MultiThreadedTestCase):
-    @property
-    def world_size(self) -> int:
-        return DEFAULT_WORLD_SIZE
-
-    @property
-    def device_type(self) -> str:
-        return "cuda" if torch.cuda.is_available() else "cpu"
-
-    def build_device_mesh(self):
-        return DeviceMesh(self.device_type, list(range(self.world_size)))
-
+class DistElementwiseOpsTest(DTensorOpTestBase):
     def _compare_pairwise_ops(
         self,
         *,
diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
index 8ad6b0e..21a23bc 100644
--- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py
+++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -180,6 +180,13 @@
     def world_size(self) -> int:
         return NUM_DEVICES
 
+    @property
+    def device_type(self) -> str:
+        return "cuda" if torch.cuda.is_available() else "cpu"
+
+    def build_device_mesh(self):
+        return DeviceMesh(self.device_type, list(range(self.world_size)))
+
     def setUp(self) -> None:
         super().setUp()
         self._spawn_threads()