[XLA] Correctly propagate builder and partition_id for inner_dot_builder.

Fixes crash when some order of partitioning is triggered.

PiperOrigin-RevId: 359869663
Change-Id: I9704fa1d0b2e53ed1f6926ceb2e120aec8618341
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 7e4d759..12530f3 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -2237,7 +2237,10 @@
     }
     // Use resharding to slice the output. Use a temporary reshard cache since
     // we are faking with replicated sharding.
-    auto new_state = lhs.state();
+    PartitionedHlo::PartitioningState new_state = lhs.state();
+    new_state.b = b;
+    new_state.partition_id =
+        lhs.state().collective_ops_creator.create_partition_id(b);
     PartitionedHlo::ReshardCache tmp_cache;
     new_state.reshard_cache = &tmp_cache;
     ar->set_sharding(HloSharding::Replicate());
@@ -2246,8 +2249,9 @@
             output_sharding, get_non_slice_dims()))
         .hlo();
   };
-  auto inner_state = CreatePerGroupPartitioningState(
-      lhs.state(), lhs_grouped.device_groups, b);
+  PartitionedHlo::PartitioningState inner_state =
+      CreatePerGroupPartitioningState(lhs.state(), lhs_grouped.device_groups,
+                                      b);
   TF_ASSIGN_OR_RETURN(
       auto dot,
       PartitionDot(