Some legacy tests have inconsistencies layout when compiling the infeed
op and transferring infeed data. Adding this api allows the tests to
get the layout right for transferring data.
PiperOrigin-RevId: 431494217
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 8749439..59f4f16 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -423,4 +423,8 @@
return LayoutUtil::GetWithDefaultLayout(host_shape);
}
+xla::Shape TransferManager::ChooseGoodInfeedLayout(const Shape& shape) const {
+ return LayoutUtil::GetWithDefaultLayout(shape);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index f5c2ddb..4a86c5b 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -241,6 +241,11 @@
virtual StatusOr<Shape> ChooseCompactLayoutForShape(
const Shape& host_shape) const;
+ // For the given shape, chooses a layout for infeed. The returned shape
+ // has the same dimensions as the original shape, and only the layout is
+ // changed.
+ virtual Shape ChooseGoodInfeedLayout(const Shape& shape) const;
+
typedef std::function<Shape(const Shape&)> DeviceShapeRepresentationFn;
// Allocates a ScopedShapedBuffer which can hold data with the given on-host