[dtensor] make add_.Tensor/div_.Scalar to be linear pointwise instead (#121294)

add_.Tensor and div_.Scalar should support linearity so that we can delay reducing the partial
results (i.e. keep the Partial placement) until the value is actually needed.

This fixes the additional collective we were seeing in the layernorm layer.
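
A minimal sketch (not part of this PR) of what linearity buys us: a reduction over a sharded dim produces a `Partial` DTensor, and when `div_.Scalar` is treated as a linear pointwise op, the division can be applied to the local partial values so the allreduce is delayed until the result is materialized. The import paths below assume the current private `torch.distributed._tensor` API and a `torchrun` launch.

```python
import os

import torch
from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh

world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
mesh = init_device_mesh("cuda", (world_size,))

# shard the embedding dim across the mesh
x = distribute_tensor(torch.randn(20, 5, 10), mesh, [Shard(2)])

# summing over the sharded dim leaves each rank holding a partial sum
s = x.sum(dim=-1, keepdim=True)

# with div_.Scalar registered as linear pointwise, this divides the local
# partial values in place and keeps the Partial placement (no collective yet)
s.div_(x.size(-1))

# the allreduce only happens once the value is materialized
mean = s.full_tensor()
```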

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121294
Approved by: https://github.com/tianyu-l
diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py
index 4fd45fe..3ae03b4 100644
--- a/test/distributed/_tensor/test_math_ops.py
+++ b/test/distributed/_tensor/test_math_ops.py
@@ -296,7 +296,7 @@
         # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
         batch, sentence_length, embedding_dim = 20, 5, 10
         norm_shape_idx_list = list(range(3))
-        shard_dims = [-1, 0, 1, 2]
+        shard_dims = [0, 1, 2]
         elementwise_affine_list = [False, True]
         test_config_list = list(
             itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
@@ -347,9 +347,10 @@
             with comm_mode:
                 y_dist = layer_norm_dist(x_dist)
 
-            self.assertLessEqual(
+            expected_fwd_comm = 0 if shard_dim < norm_idx else 1
+            self.assertEqual(
                 comm_mode.get_total_counts(),
-                1,  # TODO: This should be 0!
+                expected_fwd_comm,
                 f"comm count={comm_mode.get_total_counts()}, "
                 f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
             )
@@ -361,9 +362,11 @@
             with comm_mode:
                 y_dist.sum().backward()
 
-            self.assertLessEqual(
+            expected_bwd_comm = 0 if shard_dim < norm_idx else 1
+
+            self.assertEqual(
                 comm_mode.get_total_counts(),
-                1,  # TODO: This should be 0!
+                expected_bwd_comm,
                 f"comm count={comm_mode.get_total_counts()}, "
                 f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
             )
diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py
index f67567d..dd9ca8e 100644
--- a/torch/distributed/_tensor/ops/pointwise_ops.py
+++ b/torch/distributed/_tensor/ops/pointwise_ops.py
@@ -51,8 +51,10 @@
 
 linear_pointwise_ops = [
     aten.div.Scalar,  # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op.
+    aten.div_.Scalar,  # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op.
     aten.to.dtype,
     aten.add.Tensor,
+    aten.add_.Tensor,
 ]
 
 
@@ -70,7 +72,6 @@
     aten.add.Scalar,
     aten.add.out,
     aten.add_.Scalar,
-    aten.add_.Tensor,
     aten.addcdiv.default,
     aten.addcdiv.out,
     aten.addcdiv_.default,