[XLA] Correctly propagate builder and partition_id for inner_dot_builder.
Fixes crash when some order of partitioning is triggered.
PiperOrigin-RevId: 359869663
Change-Id: I9704fa1d0b2e53ed1f6926ceb2e120aec8618341
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 7e4d759..12530f3 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -2237,7 +2237,10 @@
}
// Use resharding to slice the output. Use a temporary reshard cache since
// we are faking with replicated sharding.
- auto new_state = lhs.state();
+ PartitionedHlo::PartitioningState new_state = lhs.state();
+ new_state.b = b;
+ new_state.partition_id =
+ lhs.state().collective_ops_creator.create_partition_id(b);
PartitionedHlo::ReshardCache tmp_cache;
new_state.reshard_cache = &tmp_cache;
ar->set_sharding(HloSharding::Replicate());
@@ -2246,8 +2249,9 @@
output_sharding, get_non_slice_dims()))
.hlo();
};
- auto inner_state = CreatePerGroupPartitioningState(
- lhs.state(), lhs_grouped.device_groups, b);
+ PartitionedHlo::PartitioningState inner_state =
+ CreatePerGroupPartitioningState(lhs.state(), lhs_grouped.device_groups,
+ b);
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDot(