[spmd] Add a few more loss ops to the reduction op list (#99900)

This PR adds a few more loss ops (binary cross entropy, binary cross entropy with logits, MSE loss, NLL loss forward, and soft margin loss) to the reduction op list, so that a loss that fully reduces over the batch dimension is detected and its output is marked as sharded instead of partial placement.
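
As a quick illustration (a self-contained sketch using `torch.ops.aten` overloads, not the actual `data_parallel.py` code, which inspects FX node metadata), the check this list feeds treats a node as a full reduction over the batch dimension when its target is one of these ops and its output is a 0-dim tensor:

```python
# Minimal sketch of the membership check this PR extends; the real code looks at
# node.meta["val"].shape on an FX node, while here we just look at a tensor.
import torch

aten = torch.ops.aten

# the expanded reduction op list from this PR
reduction_ops = [
    aten.binary_cross_entropy.default,
    aten.binary_cross_entropy_with_logits.default,
    aten.mean.default,
    aten.mse_loss.default,
    aten.nll_loss_forward.default,
    aten.soft_margin_loss.default,
    aten.sum.default,
]


def is_full_batch_reduction(target, out: torch.Tensor) -> bool:
    # a scalar (0-dim) output from one of the reduction ops means the batch
    # dimension has been reduced away
    return target in reduction_ops and out.ndim == 0


pred = torch.rand(8, 10)
label = torch.rand(8, 10)
loss = aten.mse_loss.default(pred, label)  # default "mean" reduction -> 0-dim
print(is_full_batch_reduction(aten.mse_loss.default, loss))  # True
```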
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99900
Approved by: https://github.com/mrshenli
diff --git a/torch/distributed/_spmd/data_parallel.py b/torch/distributed/_spmd/data_parallel.py
index 0b885cd..28c42c5 100644
--- a/torch/distributed/_spmd/data_parallel.py
+++ b/torch/distributed/_spmd/data_parallel.py
@@ -155,6 +155,19 @@
         # batch dim size is used to track the batch dim size of the input tensor
         self.batch_dim_size = -1
 
+        # reduction ops used to detect whether there's a reduction over the batch
+        # dimension; if there is, we mark the output as sharded instead of
+        # partial placement
+        self.reduction_ops = [
+            aten.binary_cross_entropy.default,
+            aten.binary_cross_entropy_with_logits.default,
+            aten.mean.default,
+            aten.mse_loss.default,
+            aten.nll_loss_forward.default,
+            aten.soft_margin_loss.default,
+            aten.sum.default,
+        ]
+
     def init_batch_dim_size(self, batch_dim_size: int) -> None:
         """
         initialize batch dim size based on the first input batch size
@@ -220,9 +233,8 @@
         # do full reduction across batch dimension, it would still
         # keep the reduction activation as sharded.
         reduction_over_batch = False
-        reduction_ops = [aten.sum.default, aten.mean.default]
         shape = node.meta["val"].shape
-        if node.target in reduction_ops and len(shape) == 0:
+        if node.target in self.reduction_ops and len(shape) == 0:
             operand = node.all_input_nodes[0]
             if operand in self.batch_dim_map:
                 operand_batch_dim = self.get_batch_dim(operand)