[spmd] Add a few more loss ops to the reduction op list (#99900)
This PR adds a few more loss ops (binary cross entropy, BCE with logits, MSE loss, NLL loss forward, and soft margin loss) to the reduction op list, so a full reduction over the batch dimension by any of these ops keeps the output sharded instead of marking it as a partial placement.
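For context, a minimal, hedged illustration (not part of this PR) of why these loss ops belong in the list: with their default `reduction="mean"` they collapse the batch dimension into a 0-dim tensor, the same shape signature that `aten.sum` / `aten.mean` produce when reducing the whole input. A sketch of the detection check itself follows the diff.

```python
# Illustrative only: the default reduction of these losses yields a scalar,
# i.e. the batch dimension is fully reduced.
import torch
import torch.nn.functional as F

pred = torch.randn(8, 4)           # batch dimension of size 8
target = torch.randn(8, 4)

loss = F.mse_loss(pred, target)    # default reduction="mean"
print(loss.shape)                  # torch.Size([]) -> full batch reduction
```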
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99900
Approved by: https://github.com/mrshenli
diff --git a/torch/distributed/_spmd/data_parallel.py b/torch/distributed/_spmd/data_parallel.py
index 0b885cd..28c42c5 100644
--- a/torch/distributed/_spmd/data_parallel.py
+++ b/torch/distributed/_spmd/data_parallel.py
@@ -155,6 +155,19 @@
# batch dim size is used to track the batch dim size of the input tensor
self.batch_dim_size = -1
+ # reduction ops that are used to detect whether there is a reduction over the
+ # batch dimension; if there is, we mark the output as sharded instead of
+ # partial placement
+ self.reduction_ops = [
+ aten.binary_cross_entropy.default,
+ aten.binary_cross_entropy_with_logits.default,
+ aten.mean.default,
+ aten.mse_loss.default,
+ aten.nll_loss_forward.default,
+ aten.soft_margin_loss.default,
+ aten.sum.default,
+ ]
+
def init_batch_dim_size(self, batch_dim_size: int) -> None:
"""
initialize batch dim size based on the first input batch size
@@ -220,9 +233,8 @@
# do full reduction across batch dimension, it would still
# keep the reduction activation as sharded.
reduction_over_batch = False
- reduction_ops = [aten.sum.default, aten.mean.default]
shape = node.meta["val"].shape
- if node.target in reduction_ops and len(shape) == 0:
+ if node.target in self.reduction_ops and len(shape) == 0:
operand = node.all_input_nodes[0]
if operand in self.batch_dim_map:
operand_batch_dim = self.get_batch_dim(operand)
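Roughly, the condition this hunk generalizes can be sketched standalone as below. This is a simplified, hedged sketch: `is_batch_reduction` is a hypothetical helper, not an API of `torch.distributed._spmd`, and it only mirrors the `node.target in self.reduction_ops and len(shape) == 0` check above.

```python
# Hedged sketch, assuming only that the aten overloads below are reachable via
# torch.ops.aten; the helper name is illustrative, not PyTorch API.
import torch

aten = torch.ops.aten

reduction_ops = [
    aten.binary_cross_entropy.default,
    aten.binary_cross_entropy_with_logits.default,
    aten.mean.default,
    aten.mse_loss.default,
    aten.nll_loss_forward.default,
    aten.soft_margin_loss.default,
    aten.sum.default,
]

def is_batch_reduction(target, output_shape) -> bool:
    # A node fully reduces the batch dimension only when it is one of the
    # reduction/loss ops above and its output is a 0-dim (scalar) tensor.
    return target in reduction_ops and len(output_shape) == 0
```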