Roll back MHLO CustomCallOp layout constraints (operand_layouts/result_layouts): breaks the TF Serving build and the IREE OSS build.
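
For context, the rolled-back feature let mhlo.custom_call carry optional
operand_layouts/result_layouts attributes (arrays of rank-1 index tensors,
one layout per operand/result). A minimal sketch of the removed form,
adapted from the deleted tests below, assuming %x is a tensor<2xf32> value,
%token is an !mhlo.token, and "foo" is a placeholder target:

  %0:3 = "mhlo.custom_call"(%x, %token) {
    call_target_name = "foo",
    operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
    result_layouts = [dense<> : tensor<0xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
  } : (tensor<2xf32>, !mhlo.token) -> (tensor<f32>, tensor<2xf32>, !mhlo.token)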

PiperOrigin-RevId: 410414971
Change-Id: I5cf27ba27291f564395c34c1246cfbc3b4aa505e
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index b25285d..f1f2592 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1490,20 +1490,6 @@
     `call_target_name` should be short as it may be used in labels.
     `backend_config` can encode arbitrarily large amounts of information.
 
-    A custom call can also have layout constraints on operands and results which
-    can be specified as optional `operand_layouts` and `result_layouts`
-    attributes. The layout attribute is an array of rank-1 index tensors and the
-    i-th layout attribute specifies the layout for i-th operand/result.
-
-    The `operand_layouts` & `result_layouts` attributes can be specified under
-    the following constraints:
-    1) Either both `operand_layouts` and `result_layouts` are specified or none.
-    2) None of the operands are of tuple type.
-    3) None of the results are of tuple type except the common case of single
-       tuple result packing non-tuple values is allowed. In this case the i-th
-       `result_layouts` attribute specifies the layout of i-th element in the
-       result tuple.
-
     See https://www.tensorflow.org/xla/operation_semantics#customcall.
   }];
   let arguments = (ins
@@ -1515,9 +1501,7 @@
     // the status-returning API.
     DefaultValuedAttr<HLO_CustomCallApiVersionAttr,
                       "CustomCallApiVersion::API_VERSION_ORIGINAL">:
-                      $api_version,
-    OptionalAttr<HLO_ArrayOfLayoutAttr>:$operand_layouts,
-    OptionalAttr<HLO_ArrayOfLayoutAttr>:$result_layouts
+                      $api_version
   );
   let results = (outs Variadic<HLO_TensorOrTokenOrTuple>);
   let hasCustomHLOConverter = 1;
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index 9da77ef..0b0d517 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -111,11 +111,6 @@
   let convertFromStorage = IndexElementsAttr.convertFromStorage;
 }
 
-// An array of layout (1D tensor) attributes.
-def HLO_ArrayOfLayoutAttr : TypedArrayAttrBase<HLO_LayoutAttr,
-    "Array of layout (1D tensor of index type) attributes">;
-
-
 //===----------------------------------------------------------------------===//
 // Common convolution attributes
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index c96fe77..fac5b2c 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -249,101 +249,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// CustomCallOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult Verify(CustomCallOp op) {
-  // If both operand and result layout attributes are not specified then nothing
-  // to verify.
-  if (!op.operand_layouts().hasValue() && !op.result_layouts().hasValue())
-    return success();
-
-  // Layout constraints for either both operands & results or none should be
-  // specified.
-  if (op.operand_layouts().hasValue() != op.result_layouts().hasValue())
-    return op.emitOpError() << "Layout attributes should be specified for "
-                               "either both operands and results or none.";
-
-  // Helper function to verify types and the corresponding layouts.
-  auto verify_types_and_layouts =
-      [&op](TypeRange types, mlir::ArrayAttr layouts,
-            const std::string& value_name) -> LogicalResult {
-    if (types.size() != layouts.size())
-      return op.emitOpError()
-             << "Number of " << value_name << "s must match the number of "
-             << value_name << " layouts, " << types.size()
-             << " != " << layouts.size();
-
-    for (const auto& indexed_type_and_layout :
-         llvm::enumerate(llvm::zip(types, layouts))) {
-      // Get index for more descriptive error message.
-      auto index = indexed_type_and_layout.index();
-
-      auto type = std::get<0>(indexed_type_and_layout.value());
-      auto layout = std::get<1>(indexed_type_and_layout.value())
-                        .cast<DenseIntElementsAttr>();
-
-      if (type.isa<TupleType>())
-        return op.emitOpError() << "Tuple types are not fully supported with "
-                                   "layout constraints yet";
-      auto tensor_type = type.dyn_cast<TensorType>();
-
-      // For non-tensor types such as !mhlo.token, the layout should be empty.
-      if (!tensor_type) {
-        if (layout.empty()) continue;
-        return op.emitOpError()
-               << "Only tensor types can have non-empty layout: " << value_name
-               << " #" << index << " of type " << type << " has layout "
-               << layout;
-      }
-
-      // For unranked tensors, we cannot verify the compatibility with layout
-      // any further.
-      if (!tensor_type.hasRank()) continue;
-
-      // Layout must be a permutation of [0, N) where N is the rank of the
-      // tensor type.
-      auto range = llvm::iota_range<unsigned>(0, tensor_type.getRank(),
-                                              /*Inclusive=*/false);
-      if (tensor_type.getRank() != layout.size() ||
-          !std::is_permutation(range.begin(), range.end(), layout.begin()))
-        return op.emitOpError()
-               << "incorrect layout " << layout << " for type " << type
-               << ", layout must be a permutation of [0, "
-               << tensor_type.getRank() << ")";
-    }
-    return success();
-  };
-
-  // At this point both `operand_layouts` and `result_layouts` are defined.
-  ArrayAttr operand_layouts = op.operand_layouts().getValue();
-  ArrayAttr result_layouts = op.result_layouts().getValue();
-
-  // Full support for layouts for arbitrary nesting of tuples is not
-  // supported yet.
-  //
-  // If result does not have any tuples, then i-th element of `result_layouts`
-  // specifies the layout constraints on i-th result.
-  //
-  // For the common case of a single tuple result packing non-tuple values, the
-  // i-th element of `result_layouts` specifies layout for i-th element of the
-  // result tuple.
-  TypeRange result_types;
-  if (op->getNumResults() == 1 && op->getResult(0).getType().isa<TupleType>())
-    result_types = op->getResult(0).getType().cast<TupleType>().getTypes();
-  else
-    result_types = op->getResultTypes();
-
-  // Verify that operands and operand layouts match.
-  if (failed(verify_types_and_layouts(op->getOperandTypes(), operand_layouts,
-                                      "operand")))
-    return failure();
-
-  // Verify that results and result layouts match.
-  return verify_types_and_layouts(result_types, result_layouts, "result");
-}
-
-//===----------------------------------------------------------------------===//
 // DotGeneralOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
index 3588699..4d7e091 100644
--- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
@@ -2002,125 +2002,6 @@
 
 // -----
 
-// CHECK: func @custom_call_multiple_inputs_outputs_with_layout
-func @custom_call_multiple_inputs_outputs_with_layout(%x: tensor<2xf32>, %token: !mhlo.token) -> tensor<f32> {
-  %0:3 = "mhlo.custom_call"(%x, %token) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
-    result_layouts = [dense<> : tensor<0xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
-  } : (tensor<2xf32>, !mhlo.token) -> (tensor<f32>, tensor<2xf32>, !mhlo.token)
-  return %0#0 : tensor<f32>
-}
-
-// -----
-
-// CHECK: func @custom_call_tuple_output_with_layout
-func @custom_call_tuple_output_with_layout(%x: tensor<2xf32>, %token: !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token> {
-  %0 = "mhlo.custom_call"(%x, %token) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
-    result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
-  } : (tensor<2xf32>, !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-  return %0 : tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-}
-
-// -----
-
-func @custom_call_only_operand_layout_constraints(%x: tensor<2xf32>, %token: !mhlo.token) -> tensor<2xf32> {
-  // expected-error@+1 {{Layout attributes should be specified for either both operands and results or none}}
-  %0:3 = "mhlo.custom_call"(%x, %token) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
-  } : (tensor<2xf32>, !mhlo.token) -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token)
-  return %0#0 : tensor<2xf32>
-}
-
-// -----
-
-func @custom_call_layout_mismatch_num_operands(%x: tensor<2xf32>, %token: !mhlo.token) -> tensor<2xf32> {
-  // expected-error@+1 {{Number of operands must match the number of operand layouts, 2 != 1}}
-  %0:3 = "mhlo.custom_call"(%x, %token) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0]> : tensor<1xindex>],
-    result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
-  } : (tensor<2xf32>, !mhlo.token) -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token)
-  return %0#0 : tensor<2xf32>
-}
-
-// -----
-
-func @custom_call_layout_mismatch_num_results() -> tensor<2xf32> {
-  // expected-error@+1 {{Number of results must match the number of result layouts, 3 != 2}}
-  %0:3 = "mhlo.custom_call"() {
-    call_target_name = "foo",
-    operand_layouts = [],
-    result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>]
-  } : () -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token)
-  return %0#0 : tensor<2xf32>
-}
-
-// -----
-
-func @custom_call_layout_mismatch_num_results_tuple(%x: tensor<2xf32>, %token: !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token> {
-  // expected-error@+1 {{Number of results must match the number of result layouts, 3 != 2}}
-  %0 = "mhlo.custom_call"(%x, %token) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
-    result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>]
-  } : (tensor<2xf32>, !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-  return %0 : tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-}
-
-// -----
-
-func @custom_call_tuple_operand_input(%x: tuple<tensor<2xf32>>, %token: !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token> {
-  // expected-error@+1 {{Tuple types are not fully supported with layout constraints yet}}
-  %0 = "mhlo.custom_call"(%x, %token) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>],
-    result_layouts = [dense<[0]> : tensor<1xindex>, dense<[0]> : tensor<1xindex>, dense<> : tensor<0xindex>]
-  } : (tuple<tensor<2xf32>>, !mhlo.token) -> tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-  return %0 : tuple<tensor<2xf32>, tensor<2xf32>, !mhlo.token>
-}
-
-// -----
-
-func @custom_call_token_with_layout(%token: !mhlo.token) {
-  // expected-error@+1 {{Only tensor types can have non-empty layout: operand #0 of type '!mhlo.token' has layout dense<[0, 1]> : tensor<2xindex>}}
-  "mhlo.custom_call"(%token) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0, 1]> : tensor<2xindex>],
-    result_layouts = []
-  } : (!mhlo.token) -> ()
-  return
-}
-
-// -----
-
-func @custom_call_mismatch_tensor_and_layout_rank(%arg: tensor<2x3xf32>) {
-  // expected-error@+1 {{incorrect layout dense<[0, 1, 2]> : tensor<3xindex> for type 'tensor<2x3xf32>', layout must be a permutation of [0, 2)}}
-  "mhlo.custom_call"(%arg) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0, 1, 2]> : tensor<3xindex>],
-    result_layouts = []
-  } : (tensor<2x3xf32>) -> ()
-  return
-}
-
-// -----
-
-func @custom_call_mismatch_tensor_and_layout_permutation(%arg: tensor<1x2x3xf32>) {
-  // expected-error@+1 {{incorrect layout dense<[0, 1, 3]> : tensor<3xindex> for type 'tensor<1x2x3xf32>', layout must be a permutation of [0, 3)}}
-  "mhlo.custom_call"(%arg) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0, 1, 3]> : tensor<3xindex>],
-    result_layouts = []
-  } : (tensor<1x2x3xf32>) -> ()
-  return
-}
-
-// -----
-
 // CHECK: func @reduce_window
 func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor<f32>, %init1: tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
   %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index d39b628..6780ec4 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -597,7 +597,6 @@
     hdrs = ["attribute_importer.h"],
     deps = [
         "//tensorflow/compiler/mlir/hlo",
-        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.cc b/tensorflow/compiler/mlir/xla/attribute_importer.cc
index 4fa5d5f..356fd9a 100644
--- a/tensorflow/compiler/mlir/xla/attribute_importer.cc
+++ b/tensorflow/compiler/mlir/xla/attribute_importer.cc
@@ -151,47 +151,4 @@
   }
 }
 
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
-    const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder) {
-  std::vector<mlir::Attribute> layouts;
-  for (auto& shape_and_layout : shapes_with_layouts) {
-    if (shape_and_layout.IsTuple())
-      return tensorflow::errors::Unimplemented(
-          "Layout support for nested tuples is not implemented.");
-    const xla::Layout& xla_layout = shape_and_layout.layout();
-
-    // XLA can have invalid layout for certain values (such as token types).
-    // These are imported as empty layout in MHLO.
-    if (xla_layout.format() == xla::Format::INVALID_FORMAT) {
-      layouts.push_back(builder->getIndexTensorAttr({}));
-      continue;
-    }
-
-    // Only a subset of layout specification in XLA is supported in MHLO
-    // currently. The layout has to be dense, and only specify the order of
-    // dimensions. Sparse, tiled layout or non-default memory space fields
-    // cannot be expressed in MHLO layout yet.
-    if (xla_layout.format() != xla::Format::DENSE)
-      return tensorflow::errors::Unimplemented("Unexpected layout format");
-    if (!xla_layout.tiles().empty())
-      return tensorflow::errors::Unimplemented(
-          "Tiled layout is not supported yet");
-    if (xla_layout.memory_space() != xla::Layout::kDefaultMemorySpace)
-      return tensorflow::errors::Unimplemented(
-          "Layout support for non-default memory space is not yet implemented");
-
-    llvm::SmallVector<int64_t> layout;
-    for (int64_t dim_index : xla_layout.minor_to_major())
-      layout.push_back(dim_index);
-    layouts.push_back(builder->getIndexTensorAttr(layout));
-  }
-  return builder->getArrayAttr(layouts);
-}
-
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromTuple(const Shape shape,
-                                                  mlir::Builder* builder) {
-  if (!shape.IsTuple()) return InvalidArgument("Expected shape to be Tuple");
-  return ExtractLayoutsFromShapes(shape.tuple_shapes(), builder);
-}
-
 }  // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.h b/tensorflow/compiler/mlir/xla/attribute_importer.h
index 0e6709e..39ace0f 100644
--- a/tensorflow/compiler/mlir/xla/attribute_importer.h
+++ b/tensorflow/compiler/mlir/xla/attribute_importer.h
@@ -20,7 +20,6 @@
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
-#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/types.h"
@@ -54,17 +53,6 @@
 StatusOr<mlir::mhlo::CustomCallApiVersion> ConvertCustomCallApiVersion(
     xla::CustomCallApiVersion api_version);
 
-// Extracts layouts from shapes and converts it into layout attributes (array of
-// rank-1 index tensors). Returns an error if any of the shapes is a tuple.
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromShapes(
-    const absl::Span<const Shape> shapes_with_layouts, mlir::Builder* builder);
-
-// Extracts the layouts of each element from a tuple shape and returns them as
-// an array of rank-1 index tensors. Returns an error in presence of nested
-// tuple shapes.
-StatusOr<mlir::ArrayAttr> ExtractLayoutsFromTuple(const xla::Shape shape,
-                                                  mlir::Builder* builder);
-
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 4749456..56aa436 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -347,27 +347,6 @@
     }
     case HloOpcode::kCustomCall: {
       auto custom_call = Cast<HloCustomCallInstruction>(instruction);
-      if (custom_call->layout_constrained()) {
-        TF_ASSIGN_OR_RETURN(
-            mlir::ArrayAttr operand_layouts,
-            ExtractLayoutsFromShapes(custom_call->operand_shapes_with_layout(),
-                                     builder_));
-        attributes.push_back(
-            builder_->getNamedAttr("operand_layouts", operand_layouts));
-        mlir::ArrayAttr result_layouts;
-        if (custom_call->shape().IsTuple()) {
-          TF_ASSIGN_OR_RETURN(
-              result_layouts,
-              ExtractLayoutsFromTuple(custom_call->shape(), builder_));
-        } else {
-          TF_ASSIGN_OR_RETURN(
-              result_layouts,
-              ExtractLayoutsFromShapes({custom_call->shape()}, builder_));
-        }
-        attributes.push_back(
-            builder_->getNamedAttr("result_layouts", result_layouts));
-      }
-
       TF_ASSIGN_OR_RETURN(
           auto mlir_api_version,
           ConvertCustomCallApiVersion(custom_call->api_version()));
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 084ae21..8277265 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -139,20 +139,9 @@
     const Literal* literal, absl::optional<Window> window,
     absl::optional<ConvolutionDimensionNumbers> dnums,
     CustomCallSchedule schedule, CustomCallApiVersion api_version) {
-  mlir::ArrayAttr operand_layouts;
-  mlir::ArrayAttr result_layouts;
-  if (operand_shapes_with_layout.has_value()) {
-    TF_ASSIGN_OR_RETURN(operand_layouts,
-                        ExtractLayoutsFromShapes(
-                            operand_shapes_with_layout.value(), &builder_));
-    if (shape.IsTuple()) {
-      TF_ASSIGN_OR_RETURN(result_layouts,
-                          ExtractLayoutsFromTuple(shape, &builder_));
-    } else {
-      TF_ASSIGN_OR_RETURN(result_layouts,
-                          ExtractLayoutsFromShapes({shape}, &builder_));
-    }
-  }
+  if (operand_shapes_with_layout.has_value())
+    return Unimplemented(
+        "CustomCall doesn't support operands shapes with layout");
   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                          shape, builder_));
   TF_ASSIGN_OR_RETURN(auto mlir_api_version,
@@ -173,8 +162,7 @@
       builder_.getStringAttr(opaque),
       /*api_version=*/
       mlir::mhlo::CustomCallApiVersionAttr::get(builder_.getContext(),
-                                                mlir_api_version),
-      operand_layouts, result_layouts);
+                                                mlir_api_version));
   return MakeXlaOp(op.getResult(0));
 }
 
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 9bc19d7..49b034b 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -183,40 +183,6 @@
   return xla::ConvertReplicaGroups(groups).ValueOrDie();
 }
 
-// Converts types and corresponding layouts into xla shapes with layouts.
-static std::vector<xla::Shape> ConvertTypesToShapesWithLayout(
-    mlir::TypeRange value_types, mlir::ArrayAttr layouts) {
-  std::vector<xla::Shape> shapes_with_layout;
-  for (auto type_and_layout : llvm::zip(value_types, layouts)) {
-    mlir::Type type = std::get<0>(type_and_layout);
-    mlir::Attribute layout = std::get<1>(type_and_layout);
-    assert(!type.isa<mlir::TupleType>() &&
-           "Exporting layout for tuples is not implemented yet");
-    shapes_with_layout.emplace_back(xla::TypeToShape(type));
-    auto& shape = shapes_with_layout.back();
-    shape.mutable_layout()->clear_minor_to_major();
-    for (auto l : layout.cast<mlir::DenseIntElementsAttr>()) {
-      shape.mutable_layout()->mutable_minor_to_major()->push_back(
-          l.getSExtValue());
-    }
-  }
-  return shapes_with_layout;
-}
-
-// CustomCallOp result can be of tuple type to pack multiple results into one
-// value. If the custom call result is a tuple, then result layouts represent
-// the layout of each element of the tuple. Nested tuples are currently not
-// supported for export.
-static xla::Shape GetCustomCallResultShapeWithLayout(mlir::Type type,
-                                                     mlir::ArrayAttr layouts) {
-  auto tuple_type = type.dyn_cast<mlir::TupleType>();
-  if (!tuple_type) return ConvertTypesToShapesWithLayout({type}, layouts)[0];
-
-  std::vector<xla::Shape> shapes_with_layouts =
-      ConvertTypesToShapesWithLayout(tuple_type.getTypes(), layouts);
-  return xla::ShapeUtil::MakeTupleShape(shapes_with_layouts);
-}
-
 // Converts StringRef to xla Transpose enum.
 static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
     llvm::StringRef transpose_str) {
@@ -908,28 +874,11 @@
   auto xla_api_version = xla::ConvertCustomCallApiVersion(op.api_version());
   if (!xla_api_version.ok()) return failure();
   auto& value_map = *ctx.values;
-  if (!op.operand_layouts().hasValue() || !op.result_layouts().hasValue()) {
-    value_map[result] = xla::CustomCall(
-        ctx.builder, std::string(op.call_target_name()), args,
-        xla::TypeToShape(result.getType()), std::string(op.backend_config()),
-        op.has_side_effect(), /*output_operand_aliasing=*/{},
-        /*literal=*/nullptr,
-        /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
-        /*api_version=*/*xla_api_version);
-    return success();
-  }
-
-  auto operand_shapes_with_layout = ConvertTypesToShapesWithLayout(
-      op.getOperandTypes(), op.operand_layouts().getValue());
-  xla::Shape result_shape_with_layout = GetCustomCallResultShapeWithLayout(
-      result.getType(), op.result_layouts().getValue());
-  value_map[result] = xla::CustomCallWithLayout(
+  value_map[result] = xla::CustomCall(
       ctx.builder, std::string(op.call_target_name()), args,
-      result_shape_with_layout, operand_shapes_with_layout,
-      std::string(op.backend_config()), op.has_side_effect(),
-      /*output_operand_aliasing=*/{},
-      /*literal=*/nullptr,
-      /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
+      xla::TypeToShape(result.getType()), std::string(op.backend_config()),
+      op.has_side_effect(), /*output_operand_aliasing=*/{},
+      /*literal=*/nullptr, /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
       /*api_version=*/*xla_api_version);
   return success();
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_and_check_layouts.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_and_check_layouts.mlir
deleted file mode 100644
index 5f15e59..0000000
--- a/tensorflow/compiler/mlir/xla/tests/translate/export_and_check_layouts.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text --print-layouts=true %s | FileCheck %s
-
-// CHECK:  HloModule
-func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> {
-  // CHECK:  ROOT
-  // CHECK-SAME:  f32[1,2,3]{2,0,1} custom-call
-  // CHECK-SAME:  operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}}
-  %0 = "mhlo.custom_call"(%arg0, %arg1) {
-    call_target_name = "foo",
-    operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>],
-    result_layouts = [dense<[2, 0, 1]> : tensor<3xindex>]
-  } : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
-  return %0 : tensor<1x2x3xf32>
-}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index e32a27a..a6fa067 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -292,49 +292,10 @@
 %test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
   %arg1 = f32[2,3] parameter(0)
   %arg2 = f32[5,5] parameter(1)
-  // CHECK:  "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {
-  // CHECK-SAME: api_version = 1 : i32
-  // CHECK-SAME: backend_config = "bar"
-  // CHECK-SAME: call_target_name = "foo"
-  // CHECK-SAME: has_side_effect = true
-  // CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
+// CHECK:  "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {api_version = 1 : i32, backend_config = "bar", call_target_name = "foo", has_side_effect = true, xla_shape = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
 }
 
-// CHECK-LABEL:  func private @test_custom_call_layout
-// CHECK-SAME:  [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>, [[ARG_2:%.*]]: !mhlo.token, [[ARG_3:%.*]]: tensor<i32>) -> tensor<1x2x3xf32>
-%test_custom_call_layout (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
-  %arg1 = f32[2,3] parameter(0)
-  %arg2 = f32[5,5] parameter(1)
-  %arg3 = token[] parameter(2)
-  %arg4 = s32[] parameter(3)
-  // CHECK:  "mhlo.custom_call"([[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_3]]) {
-  // CHECK-SAME: api_version = 1 : i32
-  // CHECK-SAME: backend_config = "bar"
-  // CHECK-SAME: call_target_name = "foo"
-  // CHECK-SAME: has_side_effect = true
-  // CHECK-SAME: operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]
-  // CHECK-SAME: result_layouts = [dense<[0, 2, 1]> : tensor<3xindex>]
-  // CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>, !mhlo.token, tensor<i32>) -> tensor<1x2x3xf32>
-  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2, token[] %arg3, s32[] %arg4), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true, operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}, token[], s32[]}
-}
-
-// CHECK-LABEL:  func private @test_custom_call_tuple_output
-// CHECK-SAME:  [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>) -> tuple<tensor<1x2x3xf32>, tensor<3x7x9xi32>>
-%test_custom_call_tuple_output (arg1: f32[2,3], arg2: f32[5,5]) -> (f32[1,2,3], s32[3,7,9]) {
-  %arg1 = f32[2,3] parameter(0)
-  %arg2 = f32[5,5] parameter(1)
-  // CHECK:  "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {
-  // CHECK-SAME: api_version = 1 : i32
-  // CHECK-SAME: backend_config = "bar"
-  // CHECK-SAME: call_target_name = "foo"
-  // CHECK-SAME: has_side_effect = true
-  // CHECK-SAME: operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>]
-  // CHECK-SAME: result_layouts = [dense<[0, 2, 1]> : tensor<3xindex>, dense<[2, 0, 1]> : tensor<3xindex>]
-  // CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>) -> tuple<tensor<1x2x3xf32>, tensor<3x7x9xi32>>
-  ROOT %custom-call = (f32[1,2,3]{0,2,1}, s32[3,7,9]{2,0,1}) custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true, operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}}
-}
-
 // CHECK-LABEL:  func private @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
 %test_div (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
   %Arg_0.1 = f32[4] parameter(0)
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
index 7cb48fb..657784e 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text --with-layouts=true --print-layouts=true %s | FileCheck %s
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text --with-layouts=true %s | FileCheck %s
 
 // Checks exporting layouts
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 576dc4a..67f14d0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -6036,15 +6036,15 @@
     // using a string.
     if (!op._XlaSharding().hasValue()) return failure();
 
-    mlir::ArrayAttr empty_layout_attr;
     auto custom_call = rewriter.create<mhlo::CustomCallOp>(
         op.getLoc(), op.getType(), op.input(),
-        /*call_target_name=*/"Sharding",
-        /*has_side_effect=*/false,
-        /*backend_config=*/"",
-        /*api_version=*/mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL,
-        /*operand_layouts=*/empty_layout_attr,
-        /*result_layouts=*/empty_layout_attr);
+        /*call_target_name=*/rewriter.getStringAttr("Sharding"),
+        /*has_side_effect=*/rewriter.getBoolAttr(false),
+        /*backend_config=*/rewriter.getStringAttr(""),
+        /*api_version=*/
+        mhlo::CustomCallApiVersionAttr::get(
+            rewriter.getContext(),
+            mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL));
     custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
     rewriter.replaceOp(op, custom_call.getResult(0));
 
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
index 61a1a23..246abe7 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_registration.cc
@@ -55,11 +55,6 @@
     llvm::cl::init(false));
 
 // NOLINTNEXTLINE
-llvm::cl::opt<bool> print_layouts(
-    "print-layouts", llvm::cl::desc("Print layouts in the generated HLO text"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
 llvm::cl::opt<bool> via_builder(
     "via-builder", llvm::cl::desc("Translate MHLO->XLA HLO via XLA Builder"),
     llvm::cl::init(false));
@@ -165,7 +160,7 @@
   HloModule* hlo_module = statusOrHloModule.ValueOrDie().get();
 
   output << hlo_module->ToString(
-      HloPrintOptions().set_include_layout_in_shapes(print_layouts));
+      HloPrintOptions().set_include_layout_in_shapes(with_layouts));
 
   // Output alias information as comments in the HLO text.
   hlo_module->input_output_alias_config().ForEachAlias(