We just need to copy the tiling info of result shape in layout assignment instead
the whole shape since the shapes already match.
PiperOrigin-RevId: 357367492
Change-Id: I05be8ebe78265f6d5d4de00c093b8008631275b2
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 4882c5d..72561cb 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1949,9 +1949,19 @@
computation->root_instruction()));
computation->set_root_instruction(new_root);
} else {
- // Use the specified shape including tiling info in layout.
- *(computation->root_instruction()->mutable_shape()) =
- constraints.ResultLayout()->shape();
+ // Copy the specified tiling info.
+ auto assign_tiling = [&constraints](xla::Shape* subshape,
+ const xla::ShapeIndex& index) {
+ if (subshape->IsArray()) {
+ const Shape& result_shape = ShapeUtil::GetSubshape(
+ constraints.ResultLayout()->shape(), index);
+ subshape->mutable_layout()->mutable_tiles()->assign(
+ result_shape.layout().tiles().begin(),
+ result_shape.layout().tiles().end());
+ }
+ };
+ xla::ShapeUtil::ForEachMutableSubshape(
+ computation->root_instruction()->mutable_shape(), assign_tiling);
}
}
return Status::OK();