Roll back MHLO CustomCall layout-constraint support: it breaks the TF Serving build and the IREE OSS build.
PiperOrigin-RevId: 410414971
Change-Id: I5cf27ba27291f564395c34c1246cfbc3b4aa505e
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index b25285d..f1f2592 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1490,20 +1490,6 @@
`call_target_name` should be short as it may be used in labels.
`backend_config` can encode arbitrarily large amounts of information.
- A custom call can also have layout constraints on operands and results which
- can be specified as optional `operand_layouts` and `result_layouts`
-    attributes. The layout attribute is an array of rank-1 index tensors; the
-    i-th layout attribute specifies the layout of the i-th operand/result.
-
- The `operand_layouts` & `result_layouts` attributes can be specified under
- the following constraints:
-    1) Either both `operand_layouts` and `result_layouts` are specified or
-       neither is.
- 2) None of the operands are of tuple type.
-    3) None of the results are of tuple type, except for the common case of a
-       single tuple result packing non-tuple values. In this case the i-th
-       `result_layouts` attribute specifies the layout of the i-th element in
-       the result tuple.
-
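-    For example, a custom call with a rank-2 operand and result, both
-    constrained to column-major layout, could look like the following sketch
-    (`foo` is a placeholder target name):
-
-      %0 = "mhlo.custom_call"(%arg) {
-        call_target_name = "foo",
-        operand_layouts = [dense<[0, 1]> : tensor<2xindex>],
-        result_layouts = [dense<[0, 1]> : tensor<2xindex>]
-      } : (tensor<2x3xf32>) -> tensor<2x3xf32>
-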
See https://www.tensorflow.org/xla/operation_semantics#customcall.
}];
let arguments = (ins
@@ -1515,9 +1501,7 @@
// the status-returning API.
DefaultValuedAttr<HLO_CustomCallApiVersionAttr,
"CustomCallApiVersion::API_VERSION_ORIGINAL">:
- $api_version,
- OptionalAttr<HLO_ArrayOfLayoutAttr>:$operand_layouts,
- OptionalAttr<HLO_ArrayOfLayoutAttr>:$result_layouts
+ $api_version
);
let results = (outs Variadic<HLO_TensorOrTokenOrTuple>);
let hasCustomHLOConverter = 1;
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index 9da77ef..0b0d517 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -111,11 +111,6 @@
let convertFromStorage = IndexElementsAttr.convertFromStorage;
}
-// An array of layout (1D tensor) attributes.
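-// For instance, an array holding a rank-2 layout followed by an empty layout
-// (e.g. for a token value) could look like
-// [dense<[1, 0]> : tensor<2xindex>, dense<> : tensor<0xindex>].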
-def HLO_ArrayOfLayoutAttr : TypedArrayAttrBase<HLO_LayoutAttr,
- "Array of layout (1D tensor of index type) attributes">;
-
-
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index c96fe77..fac5b2c 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -249,101 +249,6 @@
}
//===----------------------------------------------------------------------===//
-// CustomCallOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult Verify(CustomCallOp op) {
-  // If neither operand nor result layout attributes are specified, there is
-  // nothing to verify.
- if (!op.operand_layouts().hasValue() && !op.result_layouts().hasValue())
- return success();
-
-  // Layout constraints must be specified for either both operands and
-  // results, or for neither.
- if (op.operand_layouts().hasValue() != op.result_layouts().hasValue())
- return op.emitOpError() << "Layout attributes should be specified for "
- "either both operands and results or none.";
-
- // Helper function to verify types and the corresponding layouts.
- auto verify_types_and_layouts =
- [&op](TypeRange types, mlir::ArrayAttr layouts,
- const std::string& value_name) -> LogicalResult {
- if (types.size() != layouts.size())
- return op.emitOpError()
- << "Number of " << value_name << "s must match the number of "
- << value_name << " layouts, " << types.size()
- << " != " << layouts.size();
-
- for (const auto& indexed_type_and_layout :
- llvm::enumerate(llvm::zip(types, layouts))) {
-      // Get the index for a more descriptive error message.
- auto index = indexed_type_and_layout.index();
-
- auto type = std::get<0>(indexed_type_and_layout.value());
- auto layout = std::get<1>(indexed_type_and_layout.value())
- .cast<DenseIntElementsAttr>();
-
- if (type.isa<TupleType>())
- return op.emitOpError() << "Tuple types are not fully supported with "
- "layout constraints yet";
- auto tensor_type = type.dyn_cast<TensorType>();
-
- // For non-tensor types such as !mhlo.token, the layout should be empty.
- if (!tensor_type) {
- if (layout.empty()) continue;
- return op.emitOpError()
- << "Only tensor types can have non-empty layout: " << value_name
- << " #" << index << " of type " << type << " has layout "
- << layout;
- }
-
-      // For unranked tensors, we cannot verify compatibility with the layout
-      // any further.
- if (!tensor_type.hasRank()) continue;
-
- // Layout must be a permutation of [0, N) where N is the rank of the
- // tensor type.
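-      // For example, [1, 0] is a valid layout for a rank-2 tensor, while
-      // [0, 2] or [0, 1, 2] would be rejected.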
- auto range = llvm::iota_range<unsigned>(0, tensor_type.getRank(),
- /*Inclusive=*/false);
- if (tensor_type.getRank() != layout.size() ||
- !std::is_permutation(range.begin(), range.end(), layout.begin()))
- return op.emitOpError()
- << "incorrect layout " << layout << " for type " << type
- << ", layout must be a permutation of [0, "
- << tensor_type.getRank() << ")";
- }
- return success();
- };
-
- // At this point both `operand_layouts` and `result_layouts` are defined.
- ArrayAttr operand_layouts = op.operand_layouts().getValue();
- ArrayAttr result_layouts = op.result_layouts().getValue();
-
-  // Layouts for arbitrarily nested tuples are not fully supported yet.
- //
- // If result does not have any tuples, then i-th element of `result_layouts`
- // specifies the layout constraints on i-th result.
- //
- // For the common case of a single tuple result packing non-tuple values, the
- // i-th element of `result_layouts` specifies layout for i-th element of the
- // result tuple.
- TypeRange result_types;
- if (op->getNumResults() == 1 && op->getResult(0).getType().isa<TupleType>())
- result_types = op->getResult(0).getType().cast<TupleType>().getTypes();
- else
- result_types = op->getResultTypes();
-
- // Verify that operands and operand layouts match.
- if (failed(verify_types_and_layouts(op->getOperandTypes(), operand_layouts,
- "operand")))
- return failure();
-
- // Verify that results and result layouts match.
- return verify_types_and_layouts(result_types, result_layouts, "result");
-}
-
-//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
index 3588699..4d7e091 100644
--- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
@@ -2002,125 +2002,6 @@
// -----
-// CHECK: func @custom_call_multiple_inputs_outputs_with_layout
-func @custom_call_multiple_inputs_outputs_with_layout(%x: tensor<2xf32>, %token: !mhlo.token) -> tensor<f32> {
- %0:3 = "mhlo.custom_call"(%x, %token) {
- call_target_name = "foo",
- operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
- result_layouts = [dense<> : tensor<0xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
- } : (tensor<2xf32>, !mhlo.token) -> (tensor<f32>, tensor<2xf32>, !mhlo.token)
- return %0#0 : tensor<f32>
-}
-
-// -----
-
-// CHECK: func @custom_call_tuple_output_with_layout
-func @custom_call_tuple_output_with_layout(%x: tensor<2xf32>, %token: !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token> {
- %0 = "mhlo.custom_call"(%x, %token) {
- call_target_name = "foo",
- operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
- result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
- } : (tensor<2xf32>, !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
- return %0 : tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-}
-
-// -----
-
-func @custom_call_only_operand_layout_constraints(%x: tensor<2xf32>, %token: !mhlo.token) -> tensor<2xf32> {
- // expected-error@+1 {{Layout attributes should be specified for either both operands and results or none}}
- %0:3 = "mhlo.custom_call"(%x, %token) {
- call_target_name = "foo",
- operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
- } : (tensor<2xf32>, !mhlo.token) -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token)
- return %0#0 : tensor<2xf32>
-}
-
-// -----
-
-func @custom_call_layout_mismatch_num_operands(%x: tensor<2xf32>, %token: !mhlo.token) -> tensor<2xf32> {
- // expected-error@+1 {{Number of operands must match the number of operand layouts, 2 != 1}}
- %0:3 = "mhlo.custom_call"(%x, %token) {
- call_target_name = "foo",
- operand_layouts = [dense<[0]> : tensor<1xindex>],
- result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
- } : (tensor<2xf32>, !mhlo.token) -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token)
- return %0#0 : tensor<2xf32>
-}
-
-// -----
-
-func @custom_call_layout_mismatch_num_results() -> tensor<2xf32> {
- // expected-error@+1 {{Number of results must match the number of result layouts, 3 != 2}}
- %0:3 = "mhlo.custom_call"() {
- call_target_name = "foo",
- operand_layouts = [],
- result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>]
- } : () -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token)
- return %0#0 : tensor<2xf32>
-}
-
-// -----
-
-func @custom_call_layout_mismatch_num_results_tuple(%x: tensor<2xf32>, %token: !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token> {
- // expected-error@+1 {{Number of results must match the number of result layouts, 3 != 2}}
- %0 = "mhlo.custom_call"(%x, %token) {
- call_target_name = "foo",
- operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
- result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>]
- } : (tensor<2xf32>, !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
- return %0 : tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-}
-
-// -----
-
-func @custom_call_tuple_operand_input(%x: tuple<tensor<2xf32>>, %token: !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token> {
- // expected-error@+1 {{Tuple types are not fully supported with layout constraints yet}}
- %0 = "mhlo.custom_call"(%x, %token) {
- call_target_name = "foo",
- operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
- result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
- } : (tuple<tensor<2xf32>>, !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
- return %0 : tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-}
-
-// -----
-
-func @custom_call_token_with_layout(%token: !mhlo.token) {
- // expected-error@+1 {{Only tensor types can have non-empty layout: operand #0 of type '!mhlo.token' has layout dense<[0, 1]> : tensor<2xindex>}}
- "mhlo.custom_call"(%token) {
- call_target_name = "foo",
- operand_layouts = [dense<[0, 1]> : tensor<2xindex>],
- result_layouts = []
- } : (!mhlo.token) -> ()
- return
-}
-
-// -----
-
-func @custom_call_mismatch_tensor_and_layout_rank(%arg: tensor<2x3xf32>) {
- // expected-error@+1 {{incorrect layout dense<[0, 1, 2]> : tensor<3xindex> for type 'tensor<2x3xf32>', layout must be a permutation of [0, 2)}}
- "mhlo.custom_call"(%arg) {
- call_target_name = "foo",
- operand_layouts = [dense<[0, 1, 2]> : tensor<3xindex>],
- result_layouts = []
- } : (tensor<2x3xf32>) -> ()
- return
-}
-
-// -----
-
-func @custom_call_mismatch_tensor_and_layout_permutation(%arg: tensor<1x2x3xf32>) {
- // expected-error@+1 {{incorrect layout dense<[0, 1, 3]> : tensor<3xindex> for type 'tensor<1x2x3xf32>', layout must be a permutation of [0, 3)}}
- "mhlo.custom_call"(%arg) {
- call_target_name = "foo",
- operand_layouts = [dense<[0, 1, 3]> : tensor<3xindex>],
- result_layouts = []
- } : (tensor<1x2x3xf32>) -> ()
- return
-}
-
-// -----
-
// CHECK: func @reduce_window
func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor<f32>, %init1: tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
%0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index d39b628..6780ec4 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -597,7 +597,6 @@
hdrs = ["attribute_importer.h"],
deps = [
"//tensorflow/compiler/mlir/hlo",
- "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.cc b/tensorflow/compiler/mlir/xla/attribute_importer.cc
index 4fa5d5f..356fd9a 100644
--- a/tensorflow/compiler/mlir/xla/attribute_importer.cc
+++ b/tensorflow/compiler/mlir/xla/attribute_importer.cc
@@ -151,47 +151,4 @@
}
}
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
- const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder) {
- std::vector<mlir::Attribute> layouts;
- for (auto& shape_and_layout : shapes_with_layouts) {
- if (shape_and_layout.IsTuple())
- return tensorflow::errors::Unimplemented(
- "Layout support for nested tuples is not implemented.");
- const xla::Layout& xla_layout = shape_and_layout.layout();
-
-    // XLA can have an invalid layout for certain values (such as token
-    // types). These are imported with an empty layout in MHLO.
- if (xla_layout.format() == xla::Format::INVALID_FORMAT) {
- layouts.push_back(builder->getIndexTensorAttr({}));
- continue;
- }
-
-    // Only a subset of XLA's layout specification is currently supported in
-    // MHLO. The layout has to be dense and may only specify the order of
-    // dimensions. Sparse layouts, tiled layouts, and non-default memory
-    // spaces cannot be expressed in MHLO layouts yet.
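-    // When these constraints hold, an XLA shape such as f32[2,3]{0,1} becomes
-    // the attribute dense<[0, 1]> : tensor<2xindex>.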
- if (xla_layout.format() != xla::Format::DENSE)
- return tensorflow::errors::Unimplemented("Unexpected layout format");
- if (!xla_layout.tiles().empty())
- return tensorflow::errors::Unimplemented(
- "Tiled layout is not supported yet");
- if (xla_layout.memory_space() != xla::Layout::kDefaultMemorySpace)
- return tensorflow::errors::Unimplemented(
- "Layout support for non-default memory space is not yet implemented");
-
- llvm::SmallVector<int64_t> layout;
- for (int64_t dim_index : xla_layout.minor_to_major())
- layout.push_back(dim_index);
- layouts.push_back(builder->getIndexTensorAttr(layout));
- }
- return builder->getArrayAttr(layouts);
-}
-
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromTuple(const Shape shape,
- mlir::Builder* builder) {
- if (!shape.IsTuple()) return InvalidArgument("Expected shape to be Tuple");
- return ExtractLayoutsFromShapes(shape.tuple_shapes(), builder);
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.h b/tensorflow/compiler/mlir/xla/attribute_importer.h
index 0e6709e..39ace0f 100644
--- a/tensorflow/compiler/mlir/xla/attribute_importer.h
+++ b/tensorflow/compiler/mlir/xla/attribute_importer.h
@@ -20,7 +20,6 @@
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
-#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -54,17 +53,6 @@
StatusOr<mlir::mhlo::CustomCallApiVersion> ConvertCustomCallApiVersion(
xla::CustomCallApiVersion api_version);
-// Extracts layouts from shapes and converts them into layout attributes (an
-// array of rank-1 index tensors). Returns an error if any of the shapes is a
-// tuple.
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
- const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder);
-
-// Extracts the layouts of each element from a tuple shape and returns them as
-// an array of rank-1 index tensors. Returns an error in the presence of
-// nested tuple shapes.
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromTuple(const xla::Shape shape,
- mlir::Builder* builder);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 4749456..56aa436 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -347,27 +347,6 @@
}
case HloOpcode::kCustomCall: {
auto custom_call = Cast<HloCustomCallInstruction>(instruction);
- if (custom_call->layout_constrained()) {
- TF_ASSIGN_OR_RETURN(
- mlir::ArrayAttr operand_layouts,
- ExtractLayoutsFromShapes(custom_call->operand_shapes_with_layout(),
- builder_));
- attributes.push_back(
- builder_->getNamedAttr("operand_layouts", operand_layouts));
- mlir::ArrayAttr result_layouts;
- if (custom_call->shape().IsTuple()) {
- TF_ASSIGN_OR_RETURN(
- result_layouts,
- ExtractLayoutsFromTuple(custom_call->shape(), builder_));
- } else {
- TF_ASSIGN_OR_RETURN(
- result_layouts,
- ExtractLayoutsFromShapes({custom_call->shape()}, builder_));
- }
- attributes.push_back(
- builder_->getNamedAttr("result_layouts", result_layouts));
- }
-
TF_ASSIGN_OR_RETURN(
auto mlir_api_version,
ConvertCustomCallApiVersion(custom_call->api_version()));
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 084ae21..8277265 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -139,20 +139,9 @@
const Literal* literal, absl::optional<Window> window,
absl::optional<ConvolutionDimensionNumbers> dnums,
CustomCallSchedule schedule, CustomCallApiVersion api_version) {
- mlir::ArrayAttr operand_layouts;
- mlir::ArrayAttr result_layouts;
- if (operand_shapes_with_layout.has_value()) {
- TF_ASSIGN_OR_RETURN(operand_layouts,
- ExtractLayoutsFromShapes(
- operand_shapes_with_layout.value(), &builder_));
- if (shape.IsTuple()) {
- TF_ASSIGN_OR_RETURN(result_layouts,
- ExtractLayoutsFromTuple(shape, &builder_));
- } else {
- TF_ASSIGN_OR_RETURN(result_layouts,
- ExtractLayoutsFromShapes({shape}, &builder_));
- }
- }
+  if (operand_shapes_with_layout.has_value())
+    return Unimplemented(
+        "CustomCall doesn't support operand shapes with layout");
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
TF_ASSIGN_OR_RETURN(auto mlir_api_version,
@@ -173,8 +162,7 @@
builder_.getStringAttr(opaque),
/*api_version=*/
mlir::mhlo::CustomCallApiVersionAttr::get(builder_.getContext(),
- mlir_api_version),
- operand_layouts, result_layouts);
+ mlir_api_version));
return MakeXlaOp(op.getResult(0));
}
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 9bc19d7..49b034b 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -183,40 +183,6 @@
return xla::ConvertReplicaGroups(groups).ValueOrDie();
}
-// Converts types and corresponding layouts into xla shapes with layouts.
-static std::vector<xla::Shape> ConvertTypesToShapesWithLayout(
- mlir::TypeRange value_types, mlir::ArrayAttr layouts) {
- std::vector<xla::Shape> shapes_with_layout;
- for (auto type_and_layout : llvm::zip(value_types, layouts)) {
- mlir::Type type = std::get<0>(type_and_layout);
- mlir::Attribute layout = std::get<1>(type_and_layout);
- assert(!type.isa<mlir::TupleType>() &&
- "Exporting layout for tuples is not implemented yet");
- shapes_with_layout.emplace_back(xla::TypeToShape(type));
- auto& shape = shapes_with_layout.back();
- shape.mutable_layout()->clear_minor_to_major();
- for (auto l : layout.cast<mlir::DenseIntElementsAttr>()) {
- shape.mutable_layout()->mutable_minor_to_major()->push_back(
- l.getSExtValue());
- }
- }
- return shapes_with_layout;
-}
-
-// A CustomCallOp result can be of tuple type to pack multiple results into
-// one value. If the custom call result is a tuple, then the result layouts
-// represent the layout of each element of the tuple. Nested tuples are
-// currently not supported for export.
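-// For example, a result of type tuple<tensor<2x3xf32>, tensor<4xf32>> with
-// result layouts [dense<[1, 0]>, dense<[0]>] maps to the XLA shape
-// (f32[2,3]{1,0}, f32[4]{0}).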
-static xla::Shape GetCustomCallResultShapeWithLayout(mlir::Type type,
- mlir::ArrayAttr layouts) {
- auto tuple_type = type.dyn_cast<mlir::TupleType>();
- if (!tuple_type) return ConvertTypesToShapesWithLayout({type}, layouts)[0];
-
- std::vector<xla::Shape> shapes_with_layouts =
- ConvertTypesToShapesWithLayout(tuple_type.getTypes(), layouts);
- return xla::ShapeUtil::MakeTupleShape(shapes_with_layouts);
-}
-
// Converts StringRef to xla Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
llvm::StringRef transpose_str) {
@@ -908,28 +874,11 @@
auto xla_api_version = xla::ConvertCustomCallApiVersion(op.api_version());
if (!xla_api_version.ok()) return failure();
auto& value_map = *ctx.values;
- if (!op.operand_layouts().hasValue() || !op.result_layouts().hasValue()) {
- value_map[result] = xla::CustomCall(
- ctx.builder, std::string(op.call_target_name()), args,
- xla::TypeToShape(result.getType()), std::string(op.backend_config()),
- op.has_side_effect(), /*output_operand_aliasing=*/{},
- /*literal=*/nullptr,
- /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
- /*api_version=*/*xla_api_version);
- return success();
- }
-
- auto operand_shapes_with_layout = ConvertTypesToShapesWithLayout(
- op.getOperandTypes(), op.operand_layouts().getValue());
- xla::Shape result_shape_with_layout = GetCustomCallResultShapeWithLayout(
- result.getType(), op.result_layouts().getValue());
- value_map[result] = xla::CustomCallWithLayout(
+ value_map[result] = xla::CustomCall(
ctx.builder, std::string(op.call_target_name()), args,
- result_shape_with_layout, operand_shapes_with_layout,
- std::string(op.backend_config()), op.has_side_effect(),
- /*output_operand_aliasing=*/{},
- /*literal=*/nullptr,
- /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
+ xla::TypeToShape(result.getType()), std::string(op.backend_config()),
+ op.has_side_effect(), /*output_operand_aliasing=*/{},
+ /*literal=*/nullptr, /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/*xla_api_version);
return success();
}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_and_check_layouts.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_and_check_layouts.mlir
deleted file mode 100644
index 5f15e59..0000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/export_and_check_layouts.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text --print-layouts=true %s | FileCheck %s
-
-// CHECK: HloModule
-func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> {
- // CHECK: ROOT
- // CHECK-SAME: f32[1,2,3]{2,0,1} custom-call
- // CHECK-SAME: operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}}
- %0 = "mhlo.custom_call"(%arg0, %arg1) {
- call_target_name = "foo",
- operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>],
- result_layouts = [dense<[2, 0, 1]> : tensor<3xindex>]
- } : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
- return %0 : tensor<1x2x3xf32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index e32a27a..a6fa067 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -292,49 +292,10 @@
%test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
%arg1 = f32[2,3] parameter(0)
%arg2 = f32[5,5] parameter(1)
- // CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {
- // CHECK-SAME: api_version = 1 : i32
- // CHECK-SAME: backend_config = "bar"
- // CHECK-SAME: call_target_name = "foo"
- // CHECK-SAME: has_side_effect = true
- // CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
+// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {api_version = 1 : i32, backend_config = "bar", call_target_name = "foo", has_side_effect = true, xla_shape = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
}
-// CHECK-LABEL: func private @test_custom_call_layout
-// CHECK-SAME: [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>, [[ARG_2:%.*]]: !mhlo.token, [[ARG_3:%.*]]: tensor<i32>) -> tensor<1x2x3xf32>
-%test_custom_call_layout (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
- %arg1 = f32[2,3] parameter(0)
- %arg2 = f32[5,5] parameter(1)
- %arg3 = token[] parameter(2)
- %arg4 = s32[] parameter(3)
- // CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_3]]) {
- // CHECK-SAME: api_version = 1 : i32
- // CHECK-SAME: backend_config = "bar"
- // CHECK-SAME: call_target_name = "foo"
- // CHECK-SAME: has_side_effect = true
- // CHECK-SAME: operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]
- // CHECK-SAME: result_layouts = [dense<[0, 2, 1]> : tensor<3xindex>]
- // CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>, !mhlo.token, tensor<i32>) -> tensor<1x2x3xf32>
- ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2, token[] %arg3, s32[] %arg4), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true, operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}, token[], s32[]}
-}
-
-// CHECK-LABEL: func private @test_custom_call_tuple_output
-// CHECK-SAME: [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>) -> tuple<tensor<1x2x3xf32>, tensor<3x7x9xi32>>
-%test_custom_call_tuple_output (arg1: f32[2,3], arg2: f32[5,5]) -> (f32[1,2,3], s32[3,7,9]) {
- %arg1 = f32[2,3] parameter(0)
- %arg2 = f32[5,5] parameter(1)
- // CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {
- // CHECK-SAME: api_version = 1 : i32
- // CHECK-SAME: backend_config = "bar"
- // CHECK-SAME: call_target_name = "foo"
- // CHECK-SAME: has_side_effect = true
- // CHECK-SAME: operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>]
- // CHECK-SAME: result_layouts = [dense<[0, 2, 1]> : tensor<3xindex>, dense<[2, 0, 1]> : tensor<3xindex>]
- // CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>) -> tuple<tensor<1x2x3xf32>, tensor<3x7x9xi32>>
- ROOT %custom-call = (f32[1,2,3]{0,2,1}, s32[3,7,9]{2,0,1}) custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true, operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}}
-}
-
// CHECK-LABEL: func private @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_div (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
index 7cb48fb..657784e 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text --with-layouts=true --print-layouts=true %s | FileCheck %s
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text --with-layouts=true %s | FileCheck %s
// Checks exporting layouts
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 576dc4a..67f14d0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -6036,15 +6036,15 @@
// using a string.
if (!op._XlaSharding().hasValue()) return failure();
- mlir::ArrayAttr empty_layout_attr;
auto custom_call = rewriter.create<mhlo::CustomCallOp>(
op.getLoc(), op.getType(), op.input(),
- /*call_target_name=*/"Sharding",
- /*has_side_effect=*/false,
- /*backend_config=*/"",
- /*api_version=*/mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL,
- /*operand_layouts=*/empty_layout_attr,
- /*result_layouts=*/empty_layout_attr);
+ /*call_target_name=*/rewriter.getStringAttr("Sharding"),
+ /*has_side_effect=*/rewriter.getBoolAttr(false),
+ /*backend_config=*/rewriter.getStringAttr(""),
+ /*api_version=*/
+ mhlo::CustomCallApiVersionAttr::get(
+ rewriter.getContext(),
+ mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL));
custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
rewriter.replaceOp(op, custom_call.getResult(0));
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
index 61a1a23..246abe7 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
@@ -55,11 +55,6 @@
llvm::cl::init(false));
// NOLINTNEXTLINE
-llvm::cl::opt<bool> print_layouts(
- "print-layouts", llvm::cl::desc("Print layouts in the generated HLO text"),
- llvm::cl::init(false));
-
-// NOLINTNEXTLINE
llvm::cl::opt<bool> via_builder(
"via-builder", llvm::cl::desc("Translate MHLO->XLA HLO via XLA Builder"),
llvm::cl::init(false));
@@ -165,7 +160,7 @@
HloModule* hlo_module = statusOrHloModule.ValueOrDie().get();
output << hlo_module->ToString(
- HloPrintOptions().set_include_layout_in_shapes(print_layouts));
+ HloPrintOptions().set_include_layout_in_shapes(with_layouts));
// Output alias information as comments in the HLO text.
hlo_module->input_output_alias_config().ForEachAlias(