Print layout when checking parallel device shape matching.
PiperOrigin-RevId: 454189228
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index b80d9f2..d3775ab 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1354,15 +1354,16 @@
"The %s expects the async shape at index {0} to match async "
"computation parameter shape (%s vs %s).",
HloOpcodeString(async_op->opcode()),
- async_shape.tuple_shapes(0).ToString(), param_shape.ToString());
+ async_shape.tuple_shapes(0).ToString(/*print_layout=*/true),
+ param_shape.ToString(/*print_layout=*/true));
}
if (async_shape.tuple_shapes(1) != computation_shape.result()) {
return InternalError(
"The %s expects the async shape at index {1} to match the async "
"computation root shape (%s vs %s).",
HloOpcodeString(async_op->opcode()),
- async_shape.tuple_shapes(1).ToString(),
- computation_shape.result().ToString());
+ async_shape.tuple_shapes(1).ToString(/*print_layout=*/true),
+ computation_shape.result().ToString(/*print_layout=*/true));
}
return Status::OK();
}
@@ -1378,8 +1379,8 @@
"The %s expects the shape of operand %d to match the async shape at "
"index {0} (%s vs %s).",
HloOpcodeString(async_start->opcode()), i,
- async_start->operand(i)->shape().ToString(),
- param_shape.tuple_shapes(i).ToString());
+ async_start->operand(i)->shape().ToString(/*print_layout=*/true),
+ param_shape.tuple_shapes(i).ToString(/*print_layout=*/true));
}
}
return Status::OK();