[xla:jitrt] Relax shape comparison to ShapeUtil::ReshapeIsBitcast.
Consider the shape [1][1][1] with minor_to_major {0, 1, 2}. In actual memory layout, it's exactly the same as the shape [1][1][1] with minor_to_major {2, 1, 0}. However, ShapeUtil::Equal treats these layouts as different things. This subtle difference doesn't matter for memcpy.
PiperOrigin-RevId: 465174265
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
index 2167c83..81fb31f 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
@@ -929,7 +929,7 @@
// Check that destination shape matches the source shape.
Shape dest_shape = ToShape(*dest);
- if (!ShapeUtil::Equal(dest_shape, source_shape)) {
+ if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) {
return MakeStringError(
"The destination shape does not match the source shape");
}
@@ -998,7 +998,7 @@
// Check that destination shape matches the source shape.
Shape source_shape = ToShape(*source);
- if (!ShapeUtil::Equal(dest_shape, source_shape)) {
+ if (!ShapeUtil::ReshapeIsBitcast(dest_shape, source_shape)) {
return MakeStringError(
"The destination shape does not match the source shape");
}