| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" |
| |
| #include <memory> |
| #include <string> |
| |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/MemoryBuffer.h" |
| #include "llvm/Support/SMLoc.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/Matchers.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
| #include "mlir/IR/UseDefLists.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Pass/PassManager.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" |
| #include "tensorflow/compiler/mlir/utils/name_utils.h" |
| #include "tensorflow/compiler/mlir/xla/attribute_exporter.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" |
| #include "tensorflow/compiler/mlir/xla/type_to_shape.h" |
| #include "tensorflow/compiler/tf2xla/layout_util.h" |
| #include "tensorflow/compiler/tf2xla/shape_util.h" |
| #include "tensorflow/compiler/xla/client/lib/matrix.h" |
| #include "tensorflow/compiler/xla/client/lib/quantize.h" |
| #include "tensorflow/compiler/xla/client/lib/slicing.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/comparison_util.h" |
| #include "tensorflow/compiler/xla/literal_util.h" |
| #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/service/hlo_parser.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/stream_executor/lib/statusor.h" |
| |
| using ::int64_t; |
| using ::stream_executor::port::StatusOr; |
| using ::tensorflow::int16; |
| using ::tensorflow::int32; |
| using ::tensorflow::int8; |
| using ::tensorflow::uint16; |
| using ::tensorflow::uint32; |
| using ::tensorflow::uint64; |
| using ::tensorflow::uint8; |
| |
| constexpr char kShapeIndicesAttr[] = "shape_indices"; |
| constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices"; |
| constexpr char kShardingAttr[] = "mhlo.sharding"; |
| constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes"; |
| constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas"; |
| |
| // Array attribute. Same shape as infeed result, but contains a |
| // minor_to_major array for every tensor. |
| constexpr char kLayoutAttr[] = "layout"; |
| constexpr char kDefaultLayoutAttrName[] = "xla_shape"; |
| |
| // Passes through everything except for unique_ptr, on which it calls get(). |
| // This exists to allow the generated code to call XLA functions that take a raw |
| // pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv |
| // as a pointer and there is otherwise no way to avoid a memory leak. |
| template <typename T> |
| T Unwrap(T t) { |
| return t; |
| } |
| |
| template <typename T> |
| T* Unwrap(const std::unique_ptr<T>& t) { |
| return t.get(); |
| } |
| |
| static mlir::LogicalResult GetXlaOp( |
| mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map, |
| xla::XlaOp* result, mlir::Operation* op) { |
| auto iter = val_map.find(val); |
| if (iter == val_map.end()) { |
| return op->emitOpError( |
| "requires all operands to be defined in the parent region for export"); |
| } |
| *result = iter->second; |
| return mlir::success(); |
| } |
| |
| // Convert APInt into an int. |
| // TODO(hpucha): This should be consolidated into a general place. |
| static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); } |
| |
| static uint32_t Convertuint32_t(uint32_t i) { return i; } |
| static uint64_t Convertuint64_t(uint64_t i) { return i; } |
| |
| // Convert APFloat to double. |
| static double ConvertAPFloat(llvm::APFloat value) { |
| const auto& semantics = value.getSemantics(); |
| bool losesInfo = false; |
| if (&semantics != &llvm::APFloat::IEEEdouble()) |
| value.convert(llvm::APFloat::IEEEdouble(), |
| llvm::APFloat::rmNearestTiesToEven, &losesInfo); |
| return value.convertToDouble(); |
| } |
| |
| static inline bool Convertbool(bool value) { return value; } |
| |
| static absl::string_view ConvertStringRef(mlir::StringRef value) { |
| return {value.data(), value.size()}; |
| } |
| |
| static std::vector<int64_t> ConvertDenseIntAttr( |
| mlir::DenseIntElementsAttr attr) { |
| auto values = attr.getValues<int64_t>(); |
| return {values.begin(), values.end()}; |
| } |
| |
| static std::vector<int64_t> ConvertDenseIntAttr( |
| llvm::Optional<mlir::DenseIntElementsAttr> attr) { |
| if (!attr) return {}; |
| return ConvertDenseIntAttr(*attr); |
| } |
| |
| // Converts the broadcast_dimensions attribute into a vector of dimension |
| // numbers (empty if the attribute is absent). |
| static std::vector<int64_t> Convert_broadcast_dimensions( |
| llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) { |
| if (!broadcast_dimensions.hasValue()) return {}; |
| |
| return ConvertDenseIntAttr(*broadcast_dimensions); |
| } |
| |
| // Converts StringRef to xla FftType enum |
| static xla::FftType Convert_fft_type(mlir::mhlo::FftType fft_type) { |
| xla::FftType fft_type_enum; |
| // Illegal fft_type string would be caught by the verifier, so 'FftType_Parse' |
| // call below should never return false. |
| if (!FftType_Parse(std::string(mlir::mhlo::stringifyFftType(fft_type)), |
| &fft_type_enum)) |
| return xla::FftType::FFT; |
| return fft_type_enum; |
| } |
| |
| static std::vector<std::pair<int64_t, int64_t>> Convert_padding( |
| llvm::Optional<mlir::DenseIntElementsAttr> padding) { |
| return xla::ConvertNx2Attribute(padding).ValueOrDie(); |
| } |
| |
| static std::vector<std::pair<int64_t, int64_t>> Convert_source_target_pairs( |
| llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) { |
| return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie(); |
| } |
| |
| static std::vector<xla::ReplicaGroup> Convert_replica_groups( |
| mlir::DenseIntElementsAttr groups) { |
| return xla::ConvertReplicaGroups(groups).ValueOrDie(); |
| } |
| |
| // Converts types and corresponding layouts into xla shapes with layouts. |
| static std::vector<xla::Shape> ConvertTypesToShapesWithLayout( |
| mlir::TypeRange value_types, mlir::ArrayAttr layouts) { |
| std::vector<xla::Shape> shapes_with_layout; |
| for (auto type_and_layout : llvm::zip(value_types, layouts)) { |
| mlir::Type type = std::get<0>(type_and_layout); |
| mlir::Attribute layout = std::get<1>(type_and_layout); |
| assert(!type.isa<mlir::TupleType>() && |
| "Exporting layout for tuples is not implemented yet"); |
| shapes_with_layout.emplace_back(xla::TypeToShape(type)); |
| auto& shape = shapes_with_layout.back(); |
| shape.mutable_layout()->clear_minor_to_major(); |
| for (auto l : layout.cast<mlir::DenseIntElementsAttr>()) { |
| shape.mutable_layout()->mutable_minor_to_major()->push_back( |
| l.getSExtValue()); |
| } |
| } |
| return shapes_with_layout; |
| } |
| |
| // CustomCallOp result can be of tuple type to pack multiple results into one |
| // value. If the custom call result is a tuple, then result layouts represent |
| // the layout of each element of the tuple. Nested tuples are currently not |
| // supported for export. |
| static xla::Shape GetCustomCallResultShapeWithLayout(mlir::Type type, |
| mlir::ArrayAttr layouts) { |
| auto tuple_type = type.dyn_cast<mlir::TupleType>(); |
| if (!tuple_type) return ConvertTypesToShapesWithLayout({type}, layouts)[0]; |
| |
| std::vector<xla::Shape> shapes_with_layouts = |
| ConvertTypesToShapesWithLayout(tuple_type.getTypes(), layouts); |
| return xla::ShapeUtil::MakeTupleShape(shapes_with_layouts); |
| } |
| |
| // Converts StringRef to xla Transpose enum. |
| static xla::TriangularSolveOptions::Transpose Convert_transpose_a( |
| mlir::mhlo::Transpose transpose) { |
| return xla::ConvertTranspose(mlir::mhlo::stringifyTranspose(transpose)) |
| .ValueOrDie(); |
| } |
| |
| static xla::Layout ExtractLayout( |
| mlir::Operation* op, int rank, |
| llvm::StringRef attr_name = kDefaultLayoutAttrName) { |
| if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(attr_name)) { |
| llvm::SmallVector<int64_t, 4> minor_to_major; |
| DCHECK_EQ(rank, attr.size()); |
| minor_to_major.reserve(attr.size()); |
| for (const llvm::APInt& i : attr) { |
| minor_to_major.push_back(i.getZExtValue()); |
| } |
| return xla::LayoutUtil::MakeLayout(minor_to_major); |
| } |
| return xla::LayoutUtil::MakeDescendingLayout(rank); |
| } |
| |
| static xla::Shape ExtractXlaShape(mlir::Operation* op) { |
| if (auto attr = op->getAttrOfType<mlir::StringAttr>(kDefaultLayoutAttrName)) { |
| return *xla::ParseShape( |
| absl::string_view(attr.getValue().data(), attr.getValue().size())); |
| } else { |
| std::vector<xla::Shape> subshapes; |
| for (mlir::Value result : op->getResults()) { |
| subshapes.push_back(xla::TypeToShape(result.getType())); |
| } |
| if (subshapes.size() > 1) { |
| return xla::ShapeUtil::MakeTupleShape(subshapes); |
| } |
| return subshapes[0]; |
| } |
| } |
| |
| #define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \ |
| static std::vector<int64_t> Convert_##attribute( \ |
| llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \ |
| return ConvertDenseIntAttr(attribute); \ |
| } |
| |
| I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes); |
| I64_ELEMENTS_ATTR_TO_VECTOR(permutation); |
| I64_ELEMENTS_ATTR_TO_VECTOR(start_indices); |
| I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices); |
| I64_ELEMENTS_ATTR_TO_VECTOR(strides); |
| I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes); |
| I64_ELEMENTS_ATTR_TO_VECTOR(fft_length); |
| I64_ELEMENTS_ATTR_TO_VECTOR(dimensions); |
| I64_ELEMENTS_ATTR_TO_VECTOR(window_strides); |
| I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation); |
| I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation); |
| |
| #undef I64_ELEMENTS_ATTR_TO_VECTOR |
| |
| #define BOOL_ELEMENTS_ATTR_TO_VECTOR(attribute) \ |
| static std::vector<bool> Convert_##attribute( \ |
| llvm::Optional<mlir::DenseElementsAttr> attribute) { \ |
| if (!attribute) return {}; \ |
| auto values = attribute->getValues<bool>(); \ |
| return {values.begin(), values.end()}; \ |
| } |
| |
| BOOL_ELEMENTS_ATTR_TO_VECTOR(window_reversal); |
| |
| #undef BOOL_ELEMENTS_ATTR_TO_VECTOR |
| |
| static std::vector<int64_t> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) { |
| return {values.begin(), values.end()}; |
| } |
| |
| // Converts the precision config array of strings attribute into the |
| // corresponding XLA proto. All the strings are assumed to be valid names of the |
| // Precision enum. This should have been checked in the op verify method. |
| static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config( |
| llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) { |
| if (!optional_precision_config_attr.hasValue()) return nullptr; |
| |
| auto precision_config = std::make_unique<xla::PrecisionConfig>(); |
| for (auto attr : optional_precision_config_attr.getValue()) { |
| xla::PrecisionConfig::Precision p; |
| auto operand_precision = |
| mlir::mhlo::stringifyPrecision( |
| attr.cast<mlir::mhlo::PrecisionAttr>().getValue()) |
| .str(); |
| // TODO(jpienaar): Update this to ensure this is captured by verify. |
| if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) { |
| precision_config->add_operand_precision(p); |
| } else { |
| auto* context = attr.getContext(); |
| mlir::emitError(mlir::UnknownLoc::get(context)) |
| << "unexpected operand precision " << operand_precision; |
| return nullptr; |
| } |
| } |
| |
| return precision_config; |
| } |
| |
| static xla::DotDimensionNumbers Convert_dot_dimension_numbers( |
| mlir::mhlo::DotDimensionNumbersAttr dot_dimension_numbers_attr) { |
| xla::DotDimensionNumbers dot_dimension_numbers; |
| |
| auto rhs_contracting_dimensions = |
| dot_dimension_numbers_attr.getRhsContractingDimensions(); |
| auto lhs_contracting_dimensions = |
| dot_dimension_numbers_attr.getLhsContractingDimensions(); |
| auto rhs_batch_dimensions = |
| dot_dimension_numbers_attr.getRhsBatchingDimensions(); |
| auto lhs_batch_dimensions = |
| dot_dimension_numbers_attr.getLhsBatchingDimensions(); |
| |
| for (const auto& val : rhs_contracting_dimensions) { |
| dot_dimension_numbers.add_rhs_contracting_dimensions(val); |
| } |
| for (const auto& val : lhs_contracting_dimensions) { |
| dot_dimension_numbers.add_lhs_contracting_dimensions(val); |
| } |
| |
| for (const auto& val : rhs_batch_dimensions) { |
| dot_dimension_numbers.add_rhs_batch_dimensions(val); |
| } |
| |
| for (const auto& val : lhs_batch_dimensions) { |
| dot_dimension_numbers.add_lhs_batch_dimensions(val); |
| } |
| |
| return dot_dimension_numbers; |
| } |
| |
| static xla::ConvolutionDimensionNumbers Convert_dimension_numbers( |
| mlir::mhlo::ConvDimensionNumbersAttr input) { |
| return xla::ConvertConvDimensionNumbers(input); |
| } |
| |
| xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandleAttr attr) { |
| xla::ChannelHandle channel_handle; |
| channel_handle.set_handle(attr.getHandle()); |
| channel_handle.set_type( |
| static_cast<xla::ChannelHandle::ChannelType>(attr.getType())); |
| return channel_handle; |
| } |
| |
| std::optional<xla::ChannelHandle> Convert_channel_handle( |
| llvm::Optional<mlir::mhlo::ChannelHandleAttr> attr) { |
| if (!attr.hasValue()) return std::nullopt; |
| return Convert_channel_handle(attr.getValue()); |
| } |
| |
| // Converts the comparison_direction string attribute into the XLA enum. The |
| // string is assumed to correspond to exactly one of the allowed strings |
| // representing the enum. This should have been checked in the op verify method. |
| static xla::ComparisonDirection Convert_comparison_direction( |
| llvm::StringRef comparison_direction_string) { |
| return xla::StringToComparisonDirection(comparison_direction_string.str()) |
| .ValueOrDie(); |
| } |
| |
| static xla::GatherDimensionNumbers Convert_dimension_numbers( |
| mlir::mhlo::GatherDimensionNumbersAttr input) { |
| xla::GatherDimensionNumbers output; |
| |
| auto offset_dims = input.getOffsetDims(); |
| std::copy(offset_dims.begin(), offset_dims.end(), |
| tensorflow::protobuf::RepeatedFieldBackInserter( |
| output.mutable_offset_dims())); |
| |
| auto collapsed_slice_dims = input.getCollapsedSliceDims(); |
| std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(), |
| tensorflow::protobuf::RepeatedFieldBackInserter( |
| output.mutable_collapsed_slice_dims())); |
| |
| auto start_index_map = input.getStartIndexMap(); |
| std::copy(start_index_map.begin(), start_index_map.end(), |
| tensorflow::protobuf::RepeatedFieldBackInserter( |
| output.mutable_start_index_map())); |
| |
| output.set_index_vector_dim(input.getIndexVectorDim()); |
| return output; |
| } |
| |
| static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( |
| mlir::mhlo::ScatterDimensionNumbersAttr input) { |
| xla::ScatterDimensionNumbers output; |
| |
| auto update_window_dims = input.getUpdateWindowDims(); |
| std::copy(update_window_dims.begin(), update_window_dims.end(), |
| tensorflow::protobuf::RepeatedFieldBackInserter( |
| output.mutable_update_window_dims())); |
| |
| auto inserted_window_dims = input.getInsertedWindowDims(); |
| std::copy(inserted_window_dims.begin(), inserted_window_dims.end(), |
| tensorflow::protobuf::RepeatedFieldBackInserter( |
| output.mutable_inserted_window_dims())); |
| |
| auto scatter_dims_to_operand_dims = input.getScatterDimsToOperandDims(); |
| std::copy(scatter_dims_to_operand_dims.begin(), |
| scatter_dims_to_operand_dims.end(), |
| tensorflow::protobuf::RepeatedFieldBackInserter( |
| output.mutable_scatter_dims_to_operand_dims())); |
| |
| output.set_index_vector_dim(input.getIndexVectorDim()); |
| return output; |
| } |
| |
| // Extracts sharding from attribute string. |
| static std::optional<xla::OpSharding> CreateOpShardingFromStringRef( |
| llvm::StringRef sharding) { |
| xla::OpSharding sharding_proto; |
| if (!sharding_proto.ParseFromString(sharding.str())) return std::nullopt; |
| return sharding_proto; |
| } |
| |
| // Returns an OpSharding proto from the "sharding" attribute of the op. If the |
| // op doesn't have a sharding attribute or the sharding attribute is invalid, |
| // returns std::nullopt. |
| static std::optional<xla::OpSharding> CreateOpShardingFromAttribute( |
| mlir::Operation* op) { |
| auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr); |
| if (!sharding) return std::nullopt; |
| return CreateOpShardingFromStringRef(sharding.getValue()); |
| } |
| |
| // Returns a FrontendAttributes proto from the "frontend_attributes" attribute |
| // of the op. An empty FrontendAttributes proto is returned if an op does not |
| // have frontend attributes. |
| static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute( |
| mlir::Operation* op) { |
| xla::FrontendAttributes frontend_attributes; |
| auto frontend_attributes_dict = |
| op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr); |
| |
| if (!frontend_attributes_dict) return frontend_attributes; |
| |
| for (const auto& attr : frontend_attributes_dict) |
| if (auto value_str_attr = attr.getValue().dyn_cast<mlir::StringAttr>()) |
| frontend_attributes.mutable_map()->insert( |
| {attr.getName().str(), value_str_attr.getValue().str()}); |
| |
| return frontend_attributes; |
| } |
| |
| // Returns a OpMetadata proto based on the location of the op. If the location |
| // is unknown, an empty proto is returned. `op_name` are populated with the op |
| // location (converted). FileLineColLoc locations are populated by taking the |
| // file name and line number, and populating `source_file` and `source_line` |
| // respectively. |
| static xla::OpMetadata CreateOpMetadataFromLocation( |
| mlir::Operation* op, mlir::MlirToHloConversionOptions options) { |
| xla::OpMetadata metadata; |
| mlir::Location loc = op->getLoc(); |
| if (loc.isa<mlir::UnknownLoc>()) return metadata; |
| |
| std::string name = mlir::GetNameFromLoc(loc); |
| if (options.legalize_node_names) { |
| mlir::LegalizeNodeName(name); |
| } |
| metadata.set_op_name(name); |
| std::string op_type = mlir::GetOpTypeFromLoc(loc); |
| mlir::LegalizeNodeName(op_type); |
| metadata.set_op_type(op_type); |
| |
| if (auto name_loc = op->getLoc().dyn_cast<mlir::NameLoc>()) { |
| loc = name_loc.getChildLoc(); |
| if (loc.isa<mlir::UnknownLoc>()) return metadata; |
| } |
| |
| if (auto file_line_col_loc = loc.dyn_cast<mlir::FileLineColLoc>()) { |
| metadata.set_source_file(file_line_col_loc.getFilename().str()); |
| metadata.set_source_line(file_line_col_loc.getLine()); |
| } |
| |
| return metadata; |
| } |
| |
| // Checks if all shardings are set. |
| static bool AllOptionalShardingsAreSet( |
| llvm::ArrayRef<std::optional<xla::OpSharding>> shardings) { |
| return llvm::all_of(shardings, |
| [](const std::optional<xla::OpSharding>& sharding) { |
| return sharding.has_value(); |
| }); |
| } |
| |
| // Extracts argument and result shardings from function. |
| static void ExtractShardingsFromFunction( |
| mlir::func::FuncOp function, |
| llvm::SmallVectorImpl<std::optional<xla::OpSharding>>* arg_shardings, |
| llvm::SmallVectorImpl<std::optional<xla::OpSharding>>* ret_shardings) { |
| arg_shardings->resize(function.getNumArguments(), |
| std::optional<xla::OpSharding>()); |
| for (int i = 0, end = function.getNumArguments(); i < end; ++i) |
| if (auto sharding = |
| function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) |
| (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); |
| |
| ret_shardings->resize(function.getNumResults(), |
| std::optional<xla::OpSharding>()); |
| for (int i = 0, end = function.getNumResults(); i < end; ++i) |
| if (auto sharding = |
| function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr)) |
| (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); |
| } |
| |
| namespace mlir { |
| namespace { |
| class ConvertToHloModule { |
| public: |
| using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>; |
| using FunctionLoweringMap = |
| llvm::DenseMap<mlir::func::FuncOp, xla::XlaComputation>; |
| |
| // If use_tuple_args is true, then the entry function's arguments are |
| // converted to a tuple and passed as a single parameter. |
| // Similarly, if return tuple is true, then the entry function's return values |
| // are converted to a tuple even when there is only a single return value. |
| // Multiple return values are always converted to a tuple and returned as a |
| // single value. |
| explicit ConvertToHloModule( |
| mlir::ModuleOp module, xla::XlaBuilder& module_builder, |
| bool use_tuple_args, bool return_tuple, |
| tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns |
| shape_determination_fns, |
| MlirToHloConversionOptions options) |
| : module_(module), |
| module_builder_(module_builder), |
| use_tuple_args_(use_tuple_args), |
| return_tuple_(return_tuple), |
| shape_determination_fns_(shape_determination_fns), |
| options_(options) {} |
| |
| // Perform the lowering to XLA. This function returns failure if an error was |
| // encountered. |
| // |
| // TODO(hinsu): Check for dynamic shapes and exit instead of crashing. |
| LogicalResult Run() { |
| auto main = module_.lookupSymbol<mlir::func::FuncOp>("main"); |
| if (!main) |
| return module_.emitError( |
| "conversion requires module with `main` function"); |
| |
| for (auto func : module_.getOps<func::FuncOp>()) { |
| if (func.empty()) continue; |
| if (failed(RunOnFunction(func))) return failure(); |
| } |
| return success(); |
| } |
| |
| // Lower a specific function to HLO. |
| LogicalResult RunOnFunction(mlir::func::FuncOp f); |
| |
| // Lower a `mlir::Region` to a `XlaComputation` |
| LogicalResult LowerRegionAsComputation( |
| mlir::Region* region, xla::XlaComputation* func, |
| llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands = |
| llvm::None, |
| bool ensure_single_arg = false); |
| |
| // Lower a single `Block` to a `XlaComputation` |
| LogicalResult LowerBasicBlockAsFunction( |
| Block* block, xla::XlaBuilder* builder, bool is_entry_function, |
| bool ensure_single_arg, |
| const std::vector<bool>& entry_args_same_across_replicas, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings, |
| xla::XlaComputation* result, |
| llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands = |
| llvm::None); |
| |
| ::xla::HloModuleProto ConsumeMainProto() { |
| auto main = module_.lookupSymbol<mlir::func::FuncOp>("main"); |
| // This is an invariant check as Run returns failure if there is no main |
| // function and so the main proto shouldn't be consumed in that case. |
| CHECK(main) << "requires module to have main function"; // Crash Ok. |
| return lowered_computation_[main].proto(); |
| } |
| |
| // Lower function call to HLO call instruction |
| LogicalResult LowerFunctionCall( |
| mlir::func::CallOp call_op, xla::XlaBuilder* builder, |
| ConvertToHloModule::ValueLoweringMap* value_lowering); |
| |
| // Look up a symbol with the specified name, returning null if no such name |
| // exists. |
| func::FuncOp LookUpSymbol(FlatSymbolRefAttr symbol) { |
| return module_.lookupSymbol<mlir::func::FuncOp>(symbol); |
| } |
| |
| // Get Reference to lowered XLA computation for a function. |
| xla::XlaComputation& GetLoweredComputation(func::FuncOp func) { |
| return lowered_computation_[func]; |
| } |
| |
| LogicalResult Lower( |
| mlir::Operation* inst, bool is_entry_function, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings, |
| xla::XlaBuilder* builder, |
| ConvertToHloModule::ValueLoweringMap* value_lowering, |
| xla::XlaOp* return_value); |
| |
| const MlirToHloConversionOptions& GetOptions() const { return options_; } |
| |
| private: |
| LogicalResult SetEntryTupleShapesAndLeafReplication( |
| Block* block, const std::vector<bool>& entry_args_same_across_replicas, |
| llvm::SmallVectorImpl<xla::Shape>* arg_shapes, |
| std::vector<bool>* leaf_replication); |
| |
| LogicalResult SetEntryTupleShardings( |
| Block* block, xla::XlaBuilder* builder, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings, |
| llvm::SmallVectorImpl<xla::Shape>* arg_shapes); |
| |
| // The module being lowered. |
| mlir::ModuleOp module_; |
| |
| // The top-level XlaBuilder. |
| xla::XlaBuilder& module_builder_; |
| |
| // Map between function and lowered computation. |
| FunctionLoweringMap lowered_computation_; |
| |
| // Whether the entry function should take a single tuple as input. |
| bool use_tuple_args_; |
| |
| // Whether to always return a tuple. |
| bool return_tuple_; |
| |
| // Shape determination functions to determine entry function argument and |
| // result shapes. |
| tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns |
| shape_determination_fns_; |
| |
| // Unique suffix to give to the name of the next lowered region. |
| size_t region_id_ = 0; |
| |
| MlirToHloConversionOptions options_; |
| }; |
| |
| } // namespace |
| } // namespace mlir |
| |
| namespace { |
| |
| struct OpLoweringContext { |
| llvm::DenseMap<mlir::Value, xla::XlaOp>* values; |
| mlir::ConvertToHloModule* converter; |
| xla::XlaBuilder* builder; |
| }; |
| |
| mlir::LogicalResult GetTuple(mlir::Operation* op, |
| mlir::Operation::operand_range values, |
| OpLoweringContext ctx, |
| llvm::SmallVectorImpl<xla::XlaOp>& results) { |
| results.reserve(values.size()); |
| for (mlir::Value value : values) { |
| if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op))) |
| return mlir::failure(); |
| } |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult GetXlaOps(mlir::Operation* op, |
| llvm::ArrayRef<mlir::Value> values, |
| OpLoweringContext ctx, |
| llvm::SmallVectorImpl<xla::XlaOp>& results) { |
| results.reserve(values.size()); |
| for (mlir::Value value : values) { |
| if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op))) |
| return mlir::failure(); |
| } |
| return mlir::success(); |
| } |
| |
| } // namespace |
| |
| namespace mlir { |
| namespace mhlo { |
| namespace { |
| |
| LogicalResult ExportXlaOp(ComputeReshapeShapeOp, OpLoweringContext) { |
| // This op has no expression in the legacy export format. It can be expanded |
| // to a sequence of operations if needed in the future, but would feed into |
| // ops creating unsupported dynamic shapes. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(CstrReshapableOp, OpLoweringContext) { |
| // This op has no expression in the legacy export format. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp token; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| auto operand_shape = ctx.builder->GetShape(operand).value(); |
| value_map[op] = xla::internal::XlaBuilderFriend::BuildAddDependency( |
| ctx.builder, operand, token, operand_shape); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| TensorType operand_type = op.operand().getType().cast<TensorType>(); |
| TensorType result_type = op.getType(); |
| if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) |
| return failure(); |
| auto all_gather_dim = op.all_gather_dim(); |
| int64_t shard_count = result_type.getDimSize(all_gather_dim) / |
| operand_type.getDimSize(all_gather_dim); |
| value_map[op] = xla::AllGather(operand, all_gather_dim, shard_count, |
| Convert_replica_groups(op.replica_groups()), |
| Convert_channel_handle(op.channel_handle())); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaComputation computation; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(), |
| &computation))) { |
| return failure(); |
| } |
| |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| |
| value_map[op] = xla::AllReduce(operand, computation, |
| Convert_replica_groups(op.replica_groups()), |
| Convert_channel_handle(op.channel_handle())); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(ReduceScatterOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| TensorType operand_type = op.operand().getType().cast<TensorType>(); |
| TensorType result_type = op.getType(); |
| if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) |
| return failure(); |
| auto scatter_dim = op.scatter_dimension(); |
| int64_t shard_count = operand_type.getDimSize(scatter_dim) / |
| result_type.getDimSize(scatter_dim); |
| |
| xla::XlaComputation computation; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(), |
| &computation))) { |
| return failure(); |
| } |
| |
| value_map[op] = |
| xla::ReduceScatter(operand, computation, scatter_dim, shard_count, |
| Convert_replica_groups(op.replica_groups()), |
| Convert_channel_handle(op.channel_handle())); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| |
| value_map[op] = xla::BitcastConvertType( |
| operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { |
| auto type = op.getType().dyn_cast<RankedTensorType>(); |
| if (!type) return failure(); |
| auto& value_map = *ctx.values; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| |
| value_map[op] = |
| BroadcastInDim(operand, Convert_ArrayRef(type.getShape()), |
| Convert_broadcast_dimensions(op.broadcast_dimensions())); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(DotOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp lhs, rhs; |
| if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure(); |
| if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure(); |
| xla::PrimitiveType preferred_element_type = |
| xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())); |
| value_map[op] = xla::Dot( |
| lhs, rhs, Unwrap(Convert_precision_config(op.precision_config())), |
| preferred_element_type); |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp lhs, rhs; |
| if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure(); |
| if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure(); |
| xla::PrimitiveType preferred_element_type = |
| xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())); |
| value_map[op] = xla::DotGeneral( |
| lhs, rhs, Convert_dot_dimension_numbers(op.dot_dimension_numbers()), |
| Unwrap(Convert_precision_config(op.precision_config())), |
| preferred_element_type); |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(DomainOp op, OpLoweringContext ctx) { |
| auto& valueMap = *ctx.values; |
| |
| xla::Shape shape = xla::TypeToShape(op.getResult().getType()); |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), valueMap, &operand, op))) return failure(); |
| |
| auto entry = CreateOpShardingFromStringRef(op.entry_metadata()); |
| if (!entry) return failure(); |
| auto exit = CreateOpShardingFromStringRef(op.exit_metadata()); |
| if (!exit) return failure(); |
| |
| valueMap[op] = xla::internal::XlaBuilderFriend::BuildDomain( |
| ctx.builder, operand, *exit, *entry, shape); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { |
| // This op has no expression in the legacy export format. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) { |
| // This op has no expression in the legacy export format. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { |
| // This op has no expression in the legacy export format. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { |
| xla::XlaComputation true_branch; |
| xla::XlaComputation false_branch; |
| auto& value_map = *ctx.values; |
| |
| // mhlo.IfOp does not have any operands or blocks-arguments. The computation |
| // inside the region-blocks use implicit captures of values defined above. |
| // In order to create the xla parameters for functions corresponding to |
| // IfOp regions, we need to infer the a region-block's arguments, using all |
| // the values used in the region but defined above. Note that in case there |
| // are zero implicit capture for a region, we use an empty tuple as the xla |
| // parameter. |
| // |
| // Note that the implicit values used in true and false branch regions might |
| // be different and, as a result, the xla parameters for the corresponding |
| // regions could have different shapes. |
| llvm::SetVector<mlir::Value> implicit_true_operand_set, |
| implicit_false_operand_set; |
| getUsedValuesDefinedAbove(op.true_branch(), op.true_branch(), |
| implicit_true_operand_set); |
| getUsedValuesDefinedAbove(op.false_branch(), op.false_branch(), |
| implicit_false_operand_set); |
| |
| llvm::SmallVector<mlir::Value> implicit_true_operands( |
| implicit_true_operand_set.begin(), implicit_true_operand_set.end()); |
| llvm::SmallVector<mlir::Value> implicit_false_operands( |
| implicit_false_operand_set.begin(), implicit_false_operand_set.end()); |
| |
| // Create xla parameters for functions corresponding to ifOp regions using the |
| // implicit captures operands. Also export the instructions within those |
| // regions. |
| if (failed(ctx.converter->LowerRegionAsComputation( |
| &op.true_branch(), &true_branch, |
| llvm::makeArrayRef(implicit_true_operands), |
| /*ensure_single_arg*/ true)) || |
| failed(ctx.converter->LowerRegionAsComputation( |
| &op.false_branch(), &false_branch, |
| llvm::makeArrayRef(implicit_false_operands), |
| /*ensure_single_arg*/ true))) { |
| return failure(); |
| } |
| |
| // Create the Xla pred argument. |
| xla::XlaOp pred; |
| if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure(); |
| |
| // Create the true branch Xla argument. |
| llvm::SmallVector<xla::XlaOp> true_args; |
| if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args))) |
| return failure(); |
| xla::XlaOp true_arg = |
| true_args.size() == 1 ? true_args[0] : Tuple(ctx.builder, true_args); |
| |
| // Create the false branch Xla argument. |
| llvm::SmallVector<xla::XlaOp> false_args; |
| if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args))) |
| return failure(); |
| xla::XlaOp false_arg = |
| false_args.size() == 1 ? false_args[0] : Tuple(ctx.builder, false_args); |
| |
| // Create XLA Conditional op. |
| auto ifop = |
| xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch); |
| |
| // mhlo.IfOp have multiple returns, untuple all the results of XLA's. |
| if (op.getNumResults() == 1) { |
| value_map[op.getResult(0)] = ifop; |
| } else { |
| for (const auto& item : llvm::enumerate(op.getResults())) { |
| value_map[item.value()] = xla::GetTupleElement(ifop, item.index()); |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { |
| llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values; |
| // OperandRange operands = op.branch_operands(); |
| MutableArrayRef<Region> branches = op.branches(); |
| llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size()); |
| std::vector<xla::XlaComputation> computations(branches.size()); |
| std::vector<xla::XlaComputation*> computations_p(branches.size()); |
| |
| // mhlo.CaseOp does not have any operands or blocks-arguments. The computation |
| // inside the region-blocks use implicit captures of values defined above. |
| // In order to create the xla parameters for functions corresponding to |
| // CaseOp regions, we need to infer the a region-block's arguments, using all |
| // the values used in the region but defined above. Note that in case there |
| // are zero implicit captures for a region, we use an empty tuple as the xla |
| // parameter. |
| // |
| // Note that the implicit values used in the regions might |
| // be different and, as a result, the xla parameters for the corresponding |
| // regions could have different shapes. |
| for (unsigned i = 0; i < branches.size(); ++i) { |
| llvm::SetVector<mlir::Value> implicit_operand_set; |
| getUsedValuesDefinedAbove(branches[i], branches[i], implicit_operand_set); |
| llvm::SmallVector<mlir::Value> implicit_operands( |
| implicit_operand_set.begin(), implicit_operand_set.end()); |
| |
| // Create the branches[i]'s Xla argument. |
| llvm::SmallVector<xla::XlaOp> args; |
| if (failed(GetXlaOps(op, implicit_operands, ctx, args))) return failure(); |
| branch_operands[i] = args.size() == 1 ? args[0] : Tuple(ctx.builder, args); |
| |
| // Create xla parameters for functions corresponding to region branches[i] |
| // using the implicit captures operands. Also export the instructions within |
| // that region. |
| computations_p[i] = &computations[i]; |
| if (failed(ctx.converter->LowerRegionAsComputation( |
| &branches[i], computations_p[i], |
| llvm::makeArrayRef(implicit_operands), |
| /*ensure_single_arg*/ true))) |
| return failure(); |
| } |
| |
| xla::XlaOp index; |
| if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure(); |
| |
| xla::XlaOp caseop = xla::Conditional(index, computations_p, branch_operands); |
| |
| // mhlo.CaseOp have multiple returns, untuple all the results of XLA's. |
| if (op.getNumResults() == 1) { |
| value_map[op.getResult(0)] = caseop; |
| } else { |
| for (const auto& item : llvm::enumerate(op.getResults())) { |
| value_map[item.value()] = xla::GetTupleElement(caseop, item.index()); |
| } |
| } |
| return success(); |
| } |
| |
| // Specialize CompareOp export to set broadcast_dimensions argument. |
| mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op, |
| OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp lhs, rhs; |
| if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure(); |
| if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure(); |
| auto dir = Convert_comparison_direction( |
| mlir::mhlo::stringifyComparisonDirection(op.comparison_direction())); |
| auto type_attr = op.compare_typeAttr(); |
| |
| xla::XlaOp xla_result; |
| if (type_attr && type_attr.getValue() != mlir::mhlo::ComparisonType::NOTYPE) { |
| auto type = xla::StringToComparisonType( |
| stringifyComparisonType(type_attr.getValue()).str()) |
| .ValueOrDie(); |
| xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type); |
| } else { |
| xla_result = xla::Compare(lhs, rhs, dir); |
| } |
| value_map[op] = xla_result; |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(ConstantOp op, OpLoweringContext ctx) { |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(mlir::mhlo::ConvolutionOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp lhs, rhs; |
| if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure(); |
| if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure(); |
| xla::PrimitiveType preferred_element_type = |
| xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())); |
| xla::XlaOp xla_result = xla::ConvGeneralDilated( |
| lhs, rhs, Convert_window_strides(op.window_strides()), |
| Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()), |
| Convert_rhs_dilation(op.rhs_dilation()), |
| xla::ConvertConvDimensionNumbers(op.dimension_numbers()), |
| Convertuint64_t(op.feature_group_count()), |
| Convertuint64_t(op.batch_group_count()), |
| Unwrap(Convert_precision_config(op.precision_config())), |
| preferred_element_type, Convert_window_reversal(op.window_reversal())); |
| value_map[op] = xla_result; |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| |
| value_map[op] = xla::ConvertElementType( |
| operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { |
| if (op.getNumResults() != 1) |
| return op.emitOpError() << "with multiple results cannot be exported"; |
| |
| if (op.called_computations().size() > 1) |
| return op.emitOpError() |
| << "cannot export with more than one called computations"; |
| |
| // Custom call can be exported either with called computation or with layout |
| // attributes. The XlaBuilder API does not allow both. |
| if (!op.called_computations().empty() && op.operand_layouts() && |
| op.result_layouts()) { |
| return op.emitOpError() << "cannot export if both called computation and " |
| "layouts are specified"; |
| } |
| |
| Value result = op.getResult(0); |
| llvm::SmallVector<xla::XlaOp> args; |
| if (failed(GetTuple(op, op.operands(), ctx, args))) return failure(); |
| auto xla_api_version = xla::ConvertCustomCallApiVersion(op.api_version()); |
| if (!xla_api_version.ok()) return failure(); |
| auto& value_map = *ctx.values; |
| |
| if (op.called_computations().size() == 1) { |
| mlir::func::FuncOp callee = ctx.converter->LookUpSymbol( |
| op.called_computations()[0].cast<FlatSymbolRefAttr>()); |
| if (failed(ctx.converter->RunOnFunction(callee))) return failure(); |
| xla::XlaComputation& computation = |
| ctx.converter->GetLoweredComputation(callee); |
| value_map[result] = xla::CustomCallWithComputation( |
| ctx.builder, std::string(op.call_target_name()), args, computation, |
| xla::TypeToShape(result.getType()), std::string(op.backend_config()), |
| op.has_side_effect(), |
| /*output_operand_aliasing=*/{}, |
| /*literal=*/nullptr, |
| /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE, |
| /*api_version=*/*xla_api_version); |
| return success(); |
| } |
| |
| if (op.operand_layouts() && op.result_layouts()) { |
| auto operand_shapes_with_layout = ConvertTypesToShapesWithLayout( |
| op.getOperandTypes(), op.operand_layouts().getValue()); |
| xla::Shape result_shape_with_layout = GetCustomCallResultShapeWithLayout( |
| result.getType(), op.result_layouts().getValue()); |
| value_map[result] = xla::CustomCallWithLayout( |
| ctx.builder, std::string(op.call_target_name()), args, |
| result_shape_with_layout, operand_shapes_with_layout, |
| std::string(op.backend_config()), op.has_side_effect(), |
| /*output_operand_aliasing=*/{}, |
| /*literal=*/nullptr, |
| /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE, |
| /*api_version=*/*xla_api_version); |
| return success(); |
| } |
| |
| value_map[result] = xla::CustomCall( |
| ctx.builder, std::string(op.call_target_name()), args, |
| xla::TypeToShape(result.getType()), std::string(op.backend_config()), |
| op.has_side_effect(), /*output_operand_aliasing=*/{}, |
| /*literal=*/nullptr, |
| /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE, |
| /*api_version=*/*xla_api_version); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp token; |
| if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); |
| |
| // mhlo.infeed produces multiple results. The shape argument expected by the |
| // xla client API is a tuple type with two element-types: |
| // data_type : A tuple containing all the mhlo.infeedOp result types except |
| // the token type. |
| // token_type : The last result type of mhlo.infeedOp. |
| auto result_types = op.getResultTypes(); |
| auto num_results = op.getNumResults(); |
| |
| xla::Shape token_shape = xla::TypeToShape(result_types[num_results - 1]); |
| std::vector<xla::Shape> subshapes; |
| for (const auto& item : llvm::enumerate(result_types)) { |
| if (item.index() == num_results - 1) break; |
| subshapes.push_back(xla::TypeToShape(item.value())); |
| } |
| |
| xla::Shape data_shape = xla::ShapeUtil::MakeTupleShape(subshapes); |
| auto xla_result = |
| xla::InfeedWithToken(token, data_shape, std::string(op.infeed_config())); |
| ctx.builder->ClearSharding(); |
| |
| if (!subshapes.empty()) { |
| auto data_tuple_element = xla::GetTupleElement(xla_result, 0); |
| for (const auto& item : llvm::enumerate(op.getResults())) { |
| if (item.index() == num_results - 1) break; |
| value_map[item.value()] = |
| xla::GetTupleElement(data_tuple_element, item.index()); |
| } |
| } |
| |
| value_map[op.getResult(num_results - 1)] = |
| xla::GetTupleElement(xla_result, 1); |
| |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()), |
| op.iota_dimension()); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaComputation computation; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(), |
| &computation))) { |
| return failure(); |
| } |
| llvm::SmallVector<xla::XlaOp> operands; |
| if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure(); |
| value_map[op] = xla::Map(ctx.builder, operands, computation, |
| Convert_dimensions(op.dimensions())); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| |
| llvm::SmallVector<xla::XlaOp> operands; |
| if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure(); |
| |
| xla::XlaOp operand = Tuple(ctx.builder, operands); |
| |
| std::vector<xla::Shape> subshapes; |
| for (auto operand : op.operands()) |
| subshapes.push_back(xla::TypeToShape(operand.getType())); |
| |
| xla::Shape shape_with_layout = xla::ShapeUtil::MakeTupleShape(subshapes); |
| |
| xla::XlaOp token; |
| if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); |
| |
| value_map[op] = xla::OutfeedWithToken(operand, token, shape_with_layout, |
| std::string(op.outfeed_config())); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(PartitionIdOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::Shape shape = xla::TypeToShape(op.getResult().getType()); |
| value_map[op] = |
| xla::internal::XlaBuilderFriend::BuildPartitionId(ctx.builder, shape); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::PaddingConfig padding_config; |
| auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low()); |
| auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high()); |
| auto interior_padding = ConvertDenseIntAttr(op.interior_padding()); |
| for (int64_t i = 0, end = edge_padding_low.size(); i < end; ++i) { |
| auto* dims = padding_config.add_dimensions(); |
| dims->set_edge_padding_low(edge_padding_low[i]); |
| dims->set_edge_padding_high(edge_padding_high[i]); |
| dims->set_interior_padding(interior_padding[i]); |
| } |
| xla::XlaOp operand, padding_value; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op))) |
| return failure(); |
| |
| value_map[op] = xla::Pad(operand, padding_value, padding_config); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| |
| xla::XlaOp token; |
| if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); |
| |
| // mhlo.recvOp produces multiple results. The shape argument expected by the |
| // xla client API is a tuple type with two element-types: |
| // data_type : A tuple containing all the mhlo.RecvOp result types except |
| // the token type. |
| // token_type : The last result type of mhlo.recvOp. |
| auto result_types = op.getResultTypes(); |
| auto num_results = op.getNumResults(); |
| |
| xla::Shape token_shape = xla::TypeToShape(result_types[num_results - 1]); |
| std::vector<xla::Shape> subshapes; |
| for (const auto& item : llvm::enumerate(result_types)) { |
| if (item.index() == num_results - 1) break; |
| subshapes.push_back(xla::TypeToShape(item.value())); |
| } |
| |
| xla::Shape data_shape; |
| if (subshapes.size() == 1) |
| data_shape = subshapes[0]; |
| else |
| data_shape = xla::ShapeUtil::MakeTupleShape(subshapes); |
| |
| xla::XlaOp xla_result; |
| if (op.is_host_transfer()) { |
| xla_result = xla::RecvFromHost(token, data_shape, |
| Convert_channel_handle(op.channel_handle())); |
| } else { |
| xla_result = xla::RecvWithToken( |
| token, data_shape, Convert_channel_handle(op.channel_handle())); |
| } |
| |
| auto data_tuple_element = xla::GetTupleElement(xla_result, 0); |
| if (subshapes.size() == 1) { |
| value_map[op.getResult(0)] = data_tuple_element; |
| } else { |
| for (const auto& item : llvm::enumerate(op.getResults())) { |
| if (item.index() == num_results - 1) break; |
| value_map[item.value()] = |
| xla::GetTupleElement(data_tuple_element, item.index()); |
| } |
| } |
| |
| value_map[op.getResult(num_results - 1)] = |
| xla::GetTupleElement(xla_result, 1); |
| |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaComputation body; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) { |
| return failure(); |
| } |
| llvm::SmallVector<xla::XlaOp> operands, init_values; |
| if (failed(GetTuple(op, op.operands(), ctx, operands)) || |
| failed(GetTuple(op, op.init_values(), ctx, init_values))) { |
| return failure(); |
| } |
| xla::XlaOp result = |
| xla::Reduce(ctx.builder, operands, init_values, body, |
| Convert_broadcast_dimensions(op.dimensions())); |
| if (op.getNumResults() == 1) { |
| value_map[op.getResult(0)] = result; |
| } else { |
| for (const auto& item : llvm::enumerate(op.getResults())) { |
| value_map[item.value()] = xla::GetTupleElement(result, item.index()); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaComputation body; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) { |
| return failure(); |
| } |
| llvm::SmallVector<xla::XlaOp> operands, init_values; |
| if (failed(GetTuple(op, op.operands(), ctx, operands)) || |
| failed(GetTuple(op, op.init_values(), ctx, init_values))) { |
| return failure(); |
| } |
| |
| xla::XlaOp result = xla::ReduceWindowWithGeneralPadding( |
| operands, init_values, body, ConvertDenseIntAttr(op.window_dimensions()), |
| ConvertDenseIntAttr(op.window_strides()), |
| ConvertDenseIntAttr(op.base_dilations()), |
| ConvertDenseIntAttr(op.window_dilations()), |
| Convert_padding(op.padding())); |
| |
| if (op.getNumResults() == 1) { |
| value_map[op.getResult(0)] = result; |
| } else { |
| for (const auto& item : llvm::enumerate(op.getResults())) { |
| value_map[item.value()] = xla::GetTupleElement(result, item.index()); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| |
| value_map[op] = |
| xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions()); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) { |
| // Failure on purpose because `mhlo::ReturnOp` will be handled by |
| // special purpose logic in `ConvertToHloModule::Lower`. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| auto results = op.getResults(); |
| auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()]; |
| auto xla_result = xla::RngBitGenerator( |
| static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1), |
| xla::TypeToShape(results[1].getType())); |
| |
| for (const auto& item : llvm::enumerate(results)) |
| value_map[item.value()] = xla::GetTupleElement(xla_result, item.index()); |
| |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(XlaRngGetAndUpdateStateOp op, OpLoweringContext ctx) { |
| // This op does not exist in the XLA builder interface. |
| (*ctx.values)[op.getResult()] = |
| xla::internal::XlaBuilderFriend::BuildRngGetAndUpdateState( |
| ctx.builder, static_cast<int64_t>(op.delta()), |
| xla::TypeToShape(op.getType())); |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(BatchNormGradOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| auto results = op.getResults(); |
| |
| xla::XlaOp operand, scale, mean, variance, grad_output; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| if (failed(GetXlaOp(op.scale(), value_map, &scale, op))) return failure(); |
| if (failed(GetXlaOp(op.mean(), value_map, &mean, op))) return failure(); |
| if (failed(GetXlaOp(op.variance(), value_map, &variance, op))) |
| return failure(); |
| if (failed(GetXlaOp(op.grad_output(), value_map, &grad_output, op))) |
| return failure(); |
| |
| auto xla_result = |
| xla::BatchNormGrad(operand, scale, mean, variance, grad_output, |
| ConvertAPFloat(op.epsilon()), op.feature_index()); |
| |
| for (const auto& item : llvm::enumerate(results)) |
| value_map[item.value()] = xla::GetTupleElement(xla_result, item.index()); |
| |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(BatchNormTrainingOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| auto results = op.getResults(); |
| |
| xla::XlaOp operand, scale, offset; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| if (failed(GetXlaOp(op.scale(), value_map, &scale, op))) return failure(); |
| if (failed(GetXlaOp(op.offset(), value_map, &offset, op))) return failure(); |
| |
| auto xla_result = xla::BatchNormTraining( |
| operand, scale, offset, ConvertAPFloat(op.epsilon()), op.feature_index()); |
| |
| for (const auto& item : llvm::enumerate(results)) |
| value_map[item.value()] = xla::GetTupleElement(xla_result, item.index()); |
| |
| return mlir::success(); |
| } |
| |
| LogicalResult ExportXlaOp(RngOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp a, b; |
| if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure(); |
| if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure(); |
| |
| if (op.rng_distribution() == RngDistribution::UNIFORM) { |
| value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType())); |
| return success(); |
| } else if (op.rng_distribution() == RngDistribution::NORMAL) { |
| value_map[op] = xla::RngNormal(a, b, xla::TypeToShape(op.getType())); |
| return success(); |
| } |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaComputation update_computation; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(), |
| &update_computation))) { |
| return failure(); |
| } |
| xla::ScatterDimensionNumbers dimension_numbers = |
| Convert_scatter_dimension_numbers(op.scatter_dimension_numbers()); |
| |
| llvm::SmallVector<xla::XlaOp> operands; |
| llvm::SmallVector<xla::XlaOp> updates; |
| if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure(); |
| if (failed(GetTuple(op, op.updates(), ctx, updates))) return failure(); |
| |
| xla::XlaOp scatter_indices; |
| if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op))) |
| return failure(); |
| |
| auto scatter_op = xla::Scatter(operands, scatter_indices, updates, |
| update_computation, dimension_numbers, |
| op.indices_are_sorted(), op.unique_indices()); |
| if (op->getNumResults() == 1) { |
| value_map[op.getResult(0)] = scatter_op; |
| return success(); |
| } |
| |
| // mhlo.ScatterOp supports multiple returns, untuple all the results of XLA's. |
| for (const auto& it : llvm::enumerate(op.getResults())) { |
| value_map[it.value()] = xla::GetTupleElement(scatter_op, it.index()); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaComputation select; |
| xla::XlaComputation scatter; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) || |
| failed( |
| ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) { |
| return failure(); |
| } |
| xla::XlaOp operand, source, init_value; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure(); |
| if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op))) |
| return failure(); |
| |
| value_map[op] = xla::SelectAndScatterWithGeneralPadding( |
| operand, select, ConvertDenseIntAttr(op.window_dimensions()), |
| ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()), |
| source, init_value, scatter); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| |
| llvm::SmallVector<xla::XlaOp> operands; |
| if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure(); |
| |
| xla::XlaOp operand; |
| if (operands.size() == 1) |
| operand = operands[0]; |
| else |
| operand = Tuple(ctx.builder, operands); |
| |
| xla::XlaOp token; |
| if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); |
| |
| if (op.is_host_transfer()) { |
| value_map[op] = xla::SendToHost( |
| operand, token, operand.builder()->GetShape(operand).value(), |
| Convert_channel_handle(op.channel_handle())); |
| return success(); |
| } |
| value_map[op] = xla::SendWithToken( |
| operand, token, Convert_channel_handle(op.channel_handle())); |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) { |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { |
| xla::XlaComputation comparator; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(), |
| &comparator))) |
| return failure(); |
| |
| llvm::SmallVector<xla::XlaOp> operands; |
| if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure(); |
| auto sorted = xla::Sort(operands, comparator, op.dimension(), op.is_stable()); |
| |
| auto& value_map = *ctx.values; |
| auto shape_or = sorted.builder()->GetShape(sorted); |
| if (!shape_or.ok()) { |
| return op.emitError(shape_or.status().ToString()); |
| } |
| |
| xla::Shape& shape = shape_or.ValueOrDie(); |
| if (!shape.IsTuple()) { |
| value_map[op.getResult(0)] = sorted; |
| return success(); |
| } |
| |
| // MLIR's sort supports multiple returns, untuple all the results of XLA's. |
| for (const auto& it : llvm::enumerate(op.getResults())) { |
| value_map[it.value()] = xla::GetTupleElement(sorted, it.index()); |
| } |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) { |
| // TODO(atondwal): remove mhlo.trace |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) { |
| // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two |
| // operands. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { |
| xla::XlaComputation condition; |
| xla::XlaComputation body; |
| if (failed(ctx.converter->LowerRegionAsComputation( |
| &op.body(), &body, llvm::None, /*ensure_single_arg*/ true)) || |
| failed(ctx.converter->LowerRegionAsComputation( |
| &op.cond(), &condition, llvm::None, /*ensure_single_arg*/ true))) { |
| return failure(); |
| } |
| |
| // In case MHLO's whileOp has multiple operands, create xla::Tuple, using |
| // those operands, to be used as sole operand of xla::While. |
| llvm::SmallVector<xla::XlaOp> operands; |
| if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure(); |
| |
| xla::XlaOp operand = operands[0]; |
| if (operands.size() > 1) operand = Tuple(ctx.builder, operands); |
| |
| auto whileop = xla::While(condition, body, operand); |
| |
| auto& value_map = *ctx.values; |
| auto shape_or = whileop.builder()->GetShape(whileop); |
| if (!shape_or.ok()) { |
| return op.emitError(shape_or.status().ToString()); |
| } |
| |
| xla::Shape& shape = shape_or.ValueOrDie(); |
| if (!shape.IsTuple()) { |
| value_map[op.getResult(0)] = whileop; |
| return success(); |
| } |
| |
| // mhlo.WhileOp supports multiple returns, untuple all the results of XLA's. |
| for (const auto& it : llvm::enumerate(op.getResults())) { |
| value_map[it.value()] = xla::GetTupleElement(whileop, it.index()); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(OptimizationBarrierOp op, OpLoweringContext ctx) { |
| // In case MHLO's OptimizationBarrierOp has multiple operands, |
| // create xla::Tuple, using those operands, to be used as |
| // sole operand of xla::OptimizationBarrier. |
| llvm::SmallVector<xla::XlaOp> operands; |
| if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure(); |
| if (operands.empty()) return success(); |
| |
| auto& value_map = *ctx.values; |
| if (operands.size() == 1) { |
| value_map[op.getResult(0)] = xla::OptimizationBarrier(operands[0]); |
| } else { |
| auto result = xla::OptimizationBarrier(Tuple(ctx.builder, operands)); |
| |
| for (const auto& it : llvm::enumerate(op.getResults())) { |
| value_map[it.value()] = xla::GetTupleElement(result, it.index()); |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { |
| if (!op.fusion_kind()) { |
| op.emitOpError() << "requires fusion kind for HLO translation"; |
| return failure(); |
| } |
| |
| xla::XlaComputation fused_computation; |
| if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(), |
| &fused_computation))) |
| return failure(); |
| |
| auto& values = *ctx.values; |
| llvm::SmallVector<xla::XlaOp, 4> operands; |
| for (auto operand : op.operands()) operands.push_back(values[operand]); |
| |
| auto fusion_kind_string = |
| mlir::mhlo::stringifyFusionKind(op.fusion_kind().getValue()); |
| xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion( |
| ctx.builder, operands, |
| absl::string_view(fusion_kind_string.data(), fusion_kind_string.size()), |
| fused_computation); |
| if (op.getNumResults() == 1) { |
| values[op.getResult(0)] = fusion; |
| } else { |
| for (const auto& item : llvm::enumerate(op.getResults())) { |
| values[item.value()] = xla::GetTupleElement(fusion, item.index()); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) { |
| auto& value_map = *ctx.values; |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); |
| xla::XlaOp bitcast = xla::internal::XlaBuilderFriend::BuildBitcast( |
| ctx.builder, operand, xla::TypeToShape(op.getType())); |
| value_map[op] = bitcast; |
| if (ctx.converter->GetOptions().propagate_bitcast_layouts_to_backend_config) { |
| // Encode the source and result layout of the bitcast into the XLA HLO |
| // backend config as a protobuf. Note that this is a temporary solution |
| // which will go away once XLA:GPU stops falling back to XLA HLO Elemental |
| // IR emitters. |
| xla::HloInstructionProto* bitcast_proto = |
| xla::internal::XlaBuilderFriend::GetInstruction(bitcast); |
| xla::HloInstructionProto* operand_proto = |
| xla::internal::XlaBuilderFriend::GetInstruction(operand); |
| xla::LayoutProto result_layout = |
| ExtractLayout(op, bitcast_proto->shape().dimensions_size(), |
| "result_layout") |
| .ToProto(); |
| xla::LayoutProto source_layout = |
| ExtractLayout(op, operand_proto->shape().dimensions_size(), |
| "source_layout") |
| .ToProto(); |
| xla::gpu::BitcastBackendConfig bitcast_config; |
| *bitcast_config.mutable_source_layout() = source_layout; |
| *bitcast_config.mutable_result_layout() = result_layout; |
| *bitcast_proto->mutable_backend_config() = |
| bitcast_config.SerializeAsString(); |
| } |
| return success(); |
| } |
| |
| LogicalResult ExportXlaOp(RealDynamicSliceOp op, OpLoweringContext ctx) { |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(DynamicPadOp op, OpLoweringContext ctx) { |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(DynamicGatherOp op, OpLoweringContext ctx) { |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(DynamicConvOp op, OpLoweringContext ctx) { |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(UniformQuantizeOp op, OpLoweringContext ctx) { |
| // Currently, it doesn't have an XLA builder equivalent. |
| // TODO(b/230671877): Implement XLA import/export for quantized MHLO ops. |
| return failure(); |
| } |
| |
| LogicalResult ExportXlaOp(UniformDequantizeOp op, OpLoweringContext ctx) { |
| // Currently, it doesn't have an XLA builder equivalent. |
| // TODO(b/230671877): Implement XLA import/export for quantized MHLO ops. |
| return failure(); |
| } |
| |
| } // namespace |
| } // namespace mhlo |
| } // namespace mlir |
| |
| #include "tensorflow/compiler/mlir/xla/operator_writers.inc" |
| |
| namespace mlir { |
| namespace { |
| |
| StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr, |
| xla::Layout layout) { |
| if (attr.isa<OpaqueElementsAttr>()) |
| return tensorflow::errors::Unimplemented( |
| "Opaque elements attr not supported"); |
| |
| xla::Shape shape = xla::TypeToShape(attr.getType()); |
| |
| #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ |
| case xla_type: { \ |
| xla::Array<cpp_type> source_data(shape.dimensions()); \ |
| source_data.SetValues( \ |
| attr.cast<DenseElementsAttr>().getValues<cpp_type>()); \ |
| return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \ |
| } |
| |
| switch (shape.element_type()) { |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64_t) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half) |
| ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16) |
| default: |
| return tensorflow::errors::Internal(absl::StrCat( |
| "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type()))); |
| } |
| #undef ELEMENTS_ATTR_TO_LITERAL |
| } |
| |
| LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout, |
| xla::ShapeProto* shape) { |
| // In the case of tuples, ShapeProtos can be nested, and so can the mlir |
| // attribute describing the layout. So recurse into the subshapes in both data |
| // structures in parallel. |
| if (shape->element_type() == xla::TUPLE) { |
| auto subshapes = shape->mutable_tuple_shapes(); |
| |
| // 'layout' does not take the token attribute into account, so skip the |
| // corresponding entry from xla shape proto. |
| size_t subshapes_data_size = subshapes->size(); |
| if (!subshapes->empty() && |
| subshapes->Mutable(subshapes->size() - 1)->element_type() == xla::TOKEN) |
| subshapes_data_size = subshapes->size() - 1; |
| |
| if (layout.size() != subshapes_data_size) { |
| op->emitOpError() << "Expected layout of size " << layout.size() |
| << ", but found " << subshapes->size(); |
| return failure(); |
| } |
| for (int i = 0; i < subshapes_data_size; i++) { |
| mlir::Attribute child = layout[i]; |
| if (child.isa<mlir::UnitAttr>()) { |
| // ignore unit attributes, they are used only for tokens. |
| continue; |
| } |
| mlir::ArrayAttr c = child.dyn_cast<mlir::ArrayAttr>(); |
| if (!c) { |
| op->emitOpError() << "Type Error: Expected layout array attribute"; |
| return failure(); |
| } |
| if (failed(ConvertLayout(op, c, subshapes->Mutable(i)))) { |
| return failure(); |
| } |
| } |
| } else { |
| int rank = shape->dimensions().size(); |
| if (rank) { |
| if (layout.size() != rank) { |
| return failure(); // pass error down |
| } |
| std::vector<int64_t> array(rank); |
| for (int i = 0; i < rank; i++) { |
| mlir::IntegerAttr attr = layout[i].dyn_cast<mlir::IntegerAttr>(); |
| if (!attr) { |
| op->emitOpError() << "Type Error: Expected layout integer attribute"; |
| return failure(); |
| } |
| array[i] = attr.getInt(); |
| } |
| *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto(); |
| } |
| } |
| return success(); |
| } |
| |
| // Assigns layouts from 'layout' to shape. |
| // The function accepts any of the following shapes |
| // one or more array-shape(s) of infeed data |
| // Tuple(Tuple(zero or more array-shape w.r.t data), token_type) |
| // |
| // 'layout' of the mhlo.InfedOp 'op' is |
| // [zero or more layout for each array-shape w.r.t data] |
| // 'layout_index' indexes into 'layout' accessing a layout corresponding to a |
| // shape. |
| LogicalResult ConvertInfeedtLayout(mlir::Operation* op, |
| const mlir::ArrayAttr& layout, |
| xla::ShapeProto* shape, |
| int64_t layout_index = 0) { |
| if (shape->element_type() != xla::TUPLE) { |
| // Handles following shape: |
| // single array-shape of infeed data |
| mlir::ArrayAttr child_layout = |
| layout[layout_index].dyn_cast<mlir::ArrayAttr>(); |
| if (!child_layout) { |
| op->emitOpError() << "Type Error: Expected layout array attribute"; |
| return failure(); |
| } |
| |
| int rank = shape->dimensions().size(); |
| if (rank) { |
| if (child_layout.size() != rank) { |
| return failure(); // pass error down |
| } |
| std::vector<int64_t> array(rank); |
| for (int i = 0; i < rank; i++) { |
| mlir::IntegerAttr attr = child_layout[i].dyn_cast<mlir::IntegerAttr>(); |
| if (!attr) { |
| op->emitOpError() << "Type Error: Expected layout integer attribute"; |
| return failure(); |
| } |
| array[i] = attr.getInt(); |
| } |
| *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto(); |
| } |
| |
| return success(); |
| } |
| |
| auto subshapes = shape->mutable_tuple_shapes(); |
| auto datashape = subshapes->Mutable(0); |
| |
| if (datashape->element_type() == xla::TUPLE) { |
| // Handles following shapes: |
| // (Tuple(zero or more array-shape w.r.t data), token_type) |
| auto data_subshapes = datashape->mutable_tuple_shapes(); |
| if (layout.size() != data_subshapes->size()) { |
| op->emitOpError() << "Expected " << data_subshapes->size() |
| << " layout attribute(s) for infeed data, but found " |
| << layout.size(); |
| return failure(); |
| } |
| |
| for (int i = 0; i < data_subshapes->size(); i++) { |
| if (failed( |
| ConvertInfeedtLayout(op, layout, data_subshapes->Mutable(i), i))) |
| return failure(); |
| } |
| } else { |
| // Handles following shapes: |
| // array-shapes of two or more infeed data |
| if (layout.size() != subshapes->size()) { |
| op->emitOpError() << "Expected " << subshapes->size() |
| << " layout attribute(s) for infeed data, but found " |
| << layout.size(); |
| return failure(); |
| } |
| |
| for (int i = 0; i < subshapes->size(); i++) { |
| if (failed(ConvertInfeedtLayout(op, layout, subshapes->Mutable(i), i))) |
| return failure(); |
| } |
| } |
| |
| return success(); |
| } |
| |
| // MHLO and XLA HLO disagree on the meaning of addition of `pred` / `i1`, so |
| // there has to be a special case somewhere to account for the difference. To |
| // get the expected behavior of an `AddOp` on `i1`, we have to use `xor`. Since |
| // the majority of the conversion is generated code, we just sidestep it here |
| // for this single case, and inline the code to emit an `xor`. |
| LogicalResult ExportXlaOperatorWrapped(mlir::Operation* inst, |
| OpLoweringContext ctx) { |
| auto op = dyn_cast<mlir::mhlo::AddOp>(inst); |
| if (op && op.getResult() |
| .getType() |
| .cast<mlir::TensorType>() |
| .getElementType() |
| .isSignlessInteger(1)) { |
| auto& value_map = *ctx.values; |
| auto result = op.getResult(); |
| xla::XlaOp xla_arg_0; |
| if (failed(GetXlaOp(op.lhs(), value_map, &xla_arg_0, op))) |
| return mlir::failure(); |
| xla::XlaOp xla_arg_1; |
| if (failed(GetXlaOp(op.rhs(), value_map, &xla_arg_1, op))) |
| return mlir::failure(); |
| auto xla_result = xla::Xor(Unwrap(xla_arg_0), Unwrap(xla_arg_1)); |
| value_map[result] = xla_result; |
| return mlir::success(); |
| } |
| |
| return ExportXlaOperator(inst, ctx); |
| } |
| |
| LogicalResult ConvertToHloModule::Lower( |
| mlir::Operation* inst, bool is_entry_function, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings, |
| xla::XlaBuilder* builder, |
| ConvertToHloModule::ValueLoweringMap* value_lowering, |
| xla::XlaOp* return_value) { |
| // Explicitly fail for ops that are not supported for export. |
| if (inst->getDialect() != |
| inst->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>() && |
| !mlir::isa<mlir::func::ConstantOp, mlir::arith::ConstantOp, |
| mlir::func::CallOp, mlir::tensor::CastOp, |
| mlir::func::ReturnOp>(inst)) { |
| inst->emitOpError("unsupported op for export to XLA"); |
| return failure(); |
| } |
| |
| *return_value = xla::XlaOp(); |
| |
| // See MlirToHloConversionOptions for more about layouts. |
| auto propagate_layouts = [this](mlir::Operation* inst, |
| xla::XlaOp xla_op) -> mlir::LogicalResult { |
| if (options_.propagate_layouts) { |
| auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op) |
| ->mutable_shape(); |
| // TODO(kramm): merge this with ConvertLayout. |
| *shape = ExtractXlaShape(inst).ToProto(); |
| } |
| |
| return success(); |
| }; |
| |
| if (succeeded( |
| ExportXlaOperatorWrapped(inst, {value_lowering, this, builder}))) { |
| if (inst->getNumResults() == 1) { |
| auto iter = value_lowering->find(inst->getResult(0)); |
| if (iter == value_lowering->end()) { |
| inst->emitOpError( |
| "inst has a result, but it's not found in value_lowering"); |
| return failure(); |
| } |
| if (failed(propagate_layouts(inst, iter->second))) { |
| return failure(); |
| } |
| } |
| // For infeed ops stemming back to InfeedDequeueTuple, respect the |
| // layout attribute, and create the corresponding layout in hlo. |
| if (isa<mhlo::InfeedOp>(inst)) { |
| mlir::ArrayAttr layout = |
| inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr); |
| |
| if (layout) { |
| // We propagate layout to the following three ops: |
| // L1: For each data-result of mhlo.InfeedOp, we find the exported |
| // xla::kGetTupleElement and propagate the layout. |
| // |
| // L2: For the token-result of mhlo.InfeedOp (result at last index), |
| // we extract the xla::kInfeed op using the corresponding |
| // xla::kGetTupleElement and propagate the layout to it. |
| // |
| // L3: In case there are non-zero data-results, there exists an |
| // additional xla::kGetTupleElement accessing a tuple of the |
| // data-results. We need to propagate the layout to that |
| // xla::kGetTupleElement as well. |
| auto num_results = inst->getNumResults(); |
| bool propagate_layout_to_data_tuple = true; |
| for (unsigned i = 0; i < num_results; i++) { |
| auto iter = value_lowering->find(inst->getResult(i)); |
| if (iter == value_lowering->end()) { |
| inst->emitOpError() << "inst's result value at index " << i |
| << " has no match in value_lowering"; |
| return failure(); |
| } |
| auto xla_gte_op = iter->second; |
| xla::HloInstructionProto* get_tuple_element_proto = |
| xla::internal::XlaBuilderFriend::GetInstruction(xla_gte_op); |
| |
| assert(xla::StringToHloOpcode(get_tuple_element_proto->opcode()) |
| .ValueOrDie() == xla::HloOpcode::kGetTupleElement && |
| "The token-result of mhlo.InfeedOp should be mapped to a " |
| "xla::HloOpcode::kGetTupleElement"); |
| |
| if (i == num_results - 1) { |
| // L2 |
| xla::HloInstructionProto* xla_infeed_op_proto = |
| xla::internal::XlaBuilderFriend::GetInstructionByHandle( |
| xla_gte_op.builder(), |
| get_tuple_element_proto->operand_ids(0)); |
| |
| assert(xla::StringToHloOpcode(xla_infeed_op_proto->opcode()) |
| .ValueOrDie() == xla::HloOpcode::kInfeed && |
| "Expected xla::HloOpcode::kInfeed op"); |
| |
| auto* shape = xla_infeed_op_proto->mutable_shape(); |
| if (failed(ConvertInfeedtLayout(inst, layout, shape))) |
| return failure(); |
| |
| } else { |
| // L1 |
| auto* shape = get_tuple_element_proto->mutable_shape(); |
| if (failed(ConvertInfeedtLayout(inst, layout, shape, i))) |
| return failure(); |
| |
| // L3 |
| if (propagate_layout_to_data_tuple) { |
| xla::HloInstructionProto* data_tuple_proto = |
| xla::internal::XlaBuilderFriend::GetInstructionByHandle( |
| xla_gte_op.builder(), |
| get_tuple_element_proto->operand_ids(0)); |
| auto* data_tuple_shape = data_tuple_proto->mutable_shape(); |
| |
| assert(xla::StringToHloOpcode(data_tuple_proto->opcode()) |
| .ValueOrDie() == |
| xla::HloOpcode::kGetTupleElement && |
| "Expected a xla:tupleOp for all the data results."); |
| if (failed(ConvertInfeedtLayout(inst, layout, data_tuple_shape))) |
| return failure(); |
| } |
| propagate_layout_to_data_tuple = false; |
| } |
| } |
| } |
| } |
| return success(); |
| } |
| |
| auto& value_map = *value_lowering; |
| ElementsAttr const_attr; |
| |
| if (auto call_op = dyn_cast<mlir::func::CallOp>(inst)) { |
| return LowerFunctionCall(call_op, builder, &value_map); |
| } |
| |
| if (auto op = dyn_cast<mlir::tensor::CastOp>(inst)) { |
| Value operand = op.getOperand(); |
| auto ty = operand.getType().dyn_cast<ShapedType>(); |
| // If this was a cast from a static shaped tensors, then it is a noop for |
| // export to HLO and we can use the operand. |
| if (!ty || !ty.hasStaticShape()) { |
| inst->emitOpError() |
| << "requires static shaped operand for HLO translation"; |
| return failure(); |
| } |
| |
| xla::XlaOp xla_operand; |
| if (failed(GetXlaOp(operand, value_map, &xla_operand, op))) |
| return failure(); |
| value_map[op.getResult()] = xla_operand; |
| if (failed(propagate_layouts(inst, xla_operand))) { |
| return failure(); |
| } |
| return success(); |
| } |
| |
| if (matchPattern(inst, m_Constant(&const_attr))) { |
| if (!inst->getResult(0).getType().isa<ShapedType>()) { |
| return inst->emitError( |
| "expected shaped type during constant mhlo -> hlo translation"); |
| } |
| |
| auto literal_or = |
| CreateArrayLiteralFromAttr(const_attr, ExtractXlaShape(inst).layout()); |
| if (!literal_or.ok()) |
| return inst->emitError(literal_or.status().ToString()); |
| auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie()); |
| value_map[inst->getResult(0)] = constant; |
| |
| return success(); |
| } |
| |
| if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) { |
| // Construct the return value for the function. If there is a single value |
| // returned, then return it directly, else create a tuple and return. |
| unsigned num_return_values = inst->getNumOperands(); |
| const bool has_ret_shardings = |
| !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings); |
| if ((return_tuple_ && is_entry_function) || num_return_values != 1) { |
| std::vector<xla::XlaOp> returns(num_return_values); |
| for (OpOperand& ret : inst->getOpOperands()) { |
| unsigned index = ret.getOperandNumber(); |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(ret.get(), value_map, &operand, inst))) |
| return failure(); |
| |
| returns[index] = operand; |
| if (!is_entry_function || !has_ret_shardings) continue; |
| |
| xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); |
| StatusOr<xla::XlaOp> reshape = |
| tensorflow::ReshapeWithCorrectRepresentationAndSharding( |
| builder, returns[index], return_shape, shape_determination_fns_, |
| ret_shardings[index], /*fast_mem=*/false); |
| if (!reshape.ok()) |
| return inst->emitError() << reshape.status().error_message(); |
| |
| returns[index] = reshape.ValueOrDie(); |
| } |
| |
| if (has_ret_shardings) { |
| xla::OpSharding sharding; |
| sharding.set_type(xla::OpSharding::TUPLE); |
| for (auto& ret_sharding : ret_shardings) |
| *sharding.add_tuple_shardings() = *ret_sharding; |
| |
| builder->SetSharding(sharding); |
| } |
| |
| *return_value = xla::Tuple(builder, returns); |
| builder->ClearSharding(); |
| } else if (num_return_values == 1) { |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst))) |
| return failure(); |
| |
| if (has_ret_shardings) { |
| auto tuple = Tuple(builder, {operand}); |
| builder->SetSharding(*ret_shardings[0]); |
| *return_value = GetTupleElement(tuple, 0); |
| builder->ClearSharding(); |
| } else { |
| *return_value = operand; |
| } |
| } |
| |
| return success(); |
| } |
| |
| inst->emitOpError() << "can't be translated to XLA HLO"; |
| return failure(); |
| } |
| |
| LogicalResult ConvertToHloModule::LowerFunctionCall( |
| mlir::func::CallOp call_op, xla::XlaBuilder* builder, |
| ConvertToHloModule::ValueLoweringMap* value_lowering) { |
| auto& value_map = *value_lowering; |
| mlir::func::FuncOp callee = |
| module_.lookupSymbol<mlir::func::FuncOp>(call_op.getCallee()); |
| if (failed(RunOnFunction(callee))) return failure(); |
| std::vector<xla::XlaOp> operands; |
| for (auto operand : call_op.getOperands()) { |
| xla::XlaOp xla_operand; |
| if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op))) |
| return failure(); |
| operands.push_back(xla_operand); |
| } |
| // Each call to xla::Call would insert a copy of the computation to |
| // the HLO. Thus each callsite would have a unique callee in the |
| // exported HLO. HLO syntactically does not require all calls to have unique |
| // callees, but eventually before lowering call graph is "flattened" to |
| // make that true. This is done before lowering because buffer assignment |
| // needs this invariant. |
| xla::XlaOp call_result = |
| xla::Call(builder, lowered_computation_[callee], operands); |
| // Use GetTupleElement for multiple outputs |
| unsigned num_results = call_op.getNumResults(); |
| if (num_results > 1) { |
| for (unsigned i = 0; i != num_results; ++i) { |
| value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i); |
| } |
| } else if (num_results == 1) { |
| value_map[call_op.getResult(0)] = call_result; |
| } |
| return success(); |
| } |
| |
| LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { |
| if (lowered_computation_.count(f)) return success(); |
| if (!llvm::hasSingleElement(f)) { |
| return f.emitError("only single block Function supported"); |
| } |
| |
| // Create a sub-builder if this is not the main function. |
| std::unique_ptr<xla::XlaBuilder> builder_up; |
| bool entry_function = f.getName() == "main"; |
| if (!entry_function) |
| builder_up = module_builder_.CreateSubBuilder(f.getName().str()); |
| auto& builder = entry_function ? module_builder_ : *builder_up; |
| |
| xla::XlaComputation computation; |
| std::vector<bool> entry_args_same_across_replicas; |
| llvm::SmallVector<std::optional<xla::OpSharding>, 4> arg_shardings; |
| llvm::SmallVector<std::optional<xla::OpSharding>, 4> ret_shardings; |
| if (entry_function) { |
| bool any_arg_replicated = false; |
| entry_args_same_across_replicas.reserve(f.getNumArguments()); |
| for (int64_t i = 0; i < f.getNumArguments(); ++i) { |
| auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kReplicationAttr); |
| entry_args_same_across_replicas.push_back(attr != nullptr); |
| any_arg_replicated |= entry_args_same_across_replicas.back(); |
| // Pass the alias info to the builder so that it will build the alias info |
| // into the resulting HloModule. |
| auto aliasing_output = |
| f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output"); |
| if (!aliasing_output) continue; |
| xla::ShapeIndex output_index; |
| if ((return_tuple_ && entry_function) || f.getNumResults() != 1) { |
| output_index = {aliasing_output.getInt()}; |
| } else { |
| if (aliasing_output.getInt() != 0) { |
| return f.emitError( |
| "Aliasing output must be 0 if only one output exists"); |
| } |
| output_index = {}; |
| } |
| if (use_tuple_args_) { |
| builder.SetUpAlias(output_index, /*param_number=*/0, |
| /*param_index=*/{i}); |
| } else { |
| builder.SetUpAlias(output_index, /*param_number=*/i, |
| /*param_index=*/{}); |
| } |
| } |
| // Do not populate this field when nothing is replicated, since empty field |
| // means no replication. This avoids the need for unrelated tests to handle |
| // this field. |
| if (!any_arg_replicated) entry_args_same_across_replicas.clear(); |
| |
| ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings); |
| } |
| if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function, |
| false, entry_args_same_across_replicas, |
| arg_shardings, ret_shardings, |
| &computation))) { |
| return failure(); |
| } |
| lowered_computation_[f] = std::move(computation); |
| return success(); |
| } |
| |
| LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication( |
| Block* block, const std::vector<bool>& entry_args_same_across_replicas, |
| llvm::SmallVectorImpl<xla::Shape>* arg_shapes, |
| std::vector<bool>* leaf_replication) { |
| arg_shapes->reserve(block->getNumArguments()); |
| leaf_replication->reserve(block->getNumArguments()); |
| for (BlockArgument& arg : block->getArguments()) { |
| arg_shapes->push_back(xla::TypeToShape(arg.getType())); |
| xla::Shape& arg_shape = arg_shapes->back(); |
| tensorflow::TensorShape arg_tensor_shape; |
| auto status = |
| tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape); |
| if (!status.ok()) |
| return block->getParentOp()->emitError() << status.error_message(); |
| |
| tensorflow::DataType arg_dtype; |
| status = tensorflow::ConvertToDataType(arg.getType(), &arg_dtype); |
| if (!status.ok()) |
| return block->getParentOp()->emitError() << status.error_message(); |
| |
| CHECK(shape_determination_fns_.layout_preference_fn && // Crash OK |
| shape_determination_fns_.shape_representation_fn); |
| auto layout_preference = shape_determination_fns_.layout_preference_fn( |
| arg_tensor_shape, arg_dtype, std::nullopt); |
| auto arg_shape_status = shape_determination_fns_.shape_representation_fn( |
| arg_tensor_shape, arg_dtype, /*use_fast_memory=*/false, |
| layout_preference); |
| if (!arg_shape_status.ok()) |
| return block->getParentOp()->emitError() |
| << arg_shape_status.status().error_message(); |
| |
| arg_shape = std::move(arg_shape_status.ValueOrDie()); |
| |
| if (entry_args_same_across_replicas.empty()) continue; |
| for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i) |
| leaf_replication->push_back( |
| entry_args_same_across_replicas[arg.getArgNumber()]); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult ConvertToHloModule::SetEntryTupleShardings( |
| Block* block, xla::XlaBuilder* builder, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings, |
| llvm::SmallVectorImpl<xla::Shape>* arg_shapes) { |
| if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) { |
| xla::OpSharding sharding; |
| sharding.set_type(xla::OpSharding::TUPLE); |
| for (const auto& arg_sharding : llvm::enumerate(arg_shardings)) { |
| auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value()); |
| if (!hlo_sharding.ok()) |
| return block->getParentOp()->emitError() |
| << hlo_sharding.status().error_message(); |
| |
| auto status = tensorflow::RewriteLayoutWithShardedShape( |
| hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false, |
| shape_determination_fns_, &(*arg_shapes)[arg_sharding.index()]); |
| if (!status.ok()) |
| return block->getParentOp()->emitError() << status.error_message(); |
| |
| *sharding.add_tuple_shardings() = *arg_sharding.value(); |
| } |
| |
| builder->SetSharding(sharding); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( |
| Block* block, xla::XlaBuilder* builder, bool is_entry_function, |
| bool ensure_single_arg, |
| const std::vector<bool>& entry_args_same_across_replicas, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings, |
| llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings, |
| xla::XlaComputation* result, |
| llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands) { |
| // Mapping from the Value to lowered XlaOp. |
| ValueLoweringMap lowering; |
| |
| // If using tuples as input, then there is only one input parameter that is a |
| // tuple. |
| if (is_entry_function && use_tuple_args_) { |
| llvm::SmallVector<xla::Shape, 4> arg_shapes; |
| std::vector<bool> leaf_replication; |
| if (failed(SetEntryTupleShapesAndLeafReplication( |
| block, entry_args_same_across_replicas, &arg_shapes, |
| &leaf_replication))) |
| return failure(); |
| |
| if (failed( |
| SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes))) |
| return failure(); |
| |
| xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes); |
| auto tuple = |
| xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication); |
| builder->ClearSharding(); |
| |
| bool set_tuple_element_sharding = |
| !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings); |
| for (BlockArgument& arg : block->getArguments()) { |
| if (set_tuple_element_sharding) |
| builder->SetSharding(*arg_shardings[arg.getArgNumber()]); |
| lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber()); |
| } |
| builder->ClearSharding(); |
| } else { |
| if (ensure_single_arg) { |
| // Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp. |
| llvm::SmallVector<xla::Shape> arg_shapes; |
| |
| auto args_size = block->getNumArguments(); |
| if (implicit_operands) args_size = implicit_operands->size(); |
| |
| arg_shapes.reserve(args_size); |
| if (implicit_operands) { |
| for (auto implicit_operand : *implicit_operands) |
| arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType())); |
| } else { |
| for (BlockArgument& arg : block->getArguments()) |
| arg_shapes.push_back(xla::TypeToShape(arg.getType())); |
| } |
| |
| if (args_size > 1) { |
| auto tuple = xla::Parameter(builder, 0, |
| xla::ShapeUtil::MakeTupleShape(arg_shapes), |
| "arg_tuple"); |
| |
| if (implicit_operands) { |
| int arg_index = 0; |
| for (auto implicit_operand : *implicit_operands) |
| lowering[implicit_operand] = |
| xla::GetTupleElement(tuple, arg_index++); |
| } else { |
| for (BlockArgument& arg : block->getArguments()) |
| lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber()); |
| } |
| } else if (args_size == 1) { |
| if (implicit_operands) { |
| lowering[(*implicit_operands)[0]] = |
| xla::Parameter(builder, 0, arg_shapes[0], "Arg_"); |
| } else { |
| lowering[block->getArgument(0)] = |
| xla::Parameter(builder, 0, arg_shapes[0], "Arg_"); |
| } |
| } else { |
| // Applicable only for IfOp or CaseOp. No implicit operands implies no |
| // xla parameters. In this case, we create an empty tuple as the |
| // block-parameter. |
| xla::Parameter(builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes), |
| "arg_empty_tuple"); |
| } |
| } else { |
| for (BlockArgument& arg : block->getArguments()) { |
| auto num = arg.getArgNumber(); |
| xla::Shape shape = xla::TypeToShape(arg.getType()); |
| if (!arg_shardings.empty() && arg_shardings[num]) { |
| builder->SetSharding(*arg_shardings[num]); |
| } |
| if (entry_args_same_across_replicas.empty()) { |
| lowering[arg] = |
| xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num)); |
| } else { |
| lowering[arg] = xla::Parameter( |
| builder, num, shape, absl::StrCat("Arg_", num), |
| std::vector<bool>(entry_args_same_across_replicas[num], |
| xla::ShapeUtil::GetLeafCount(shape))); |
| } |
| builder->ClearSharding(); |
| } |
| } |
| } |
| |
| xla::XlaOp return_value; |
| for (auto& inst : *block) |
| if (failed(Lower(&inst, is_entry_function, ret_shardings, builder, |
| &lowering, &return_value))) |
| return failure(); |
| |
| // Build the XlaComputation and check for failures. |
| auto computation_or = |
| return_value.valid() ? builder->Build(return_value) : builder->Build(); |
| if (!computation_or.ok()) { |
| block->back().emitError( |
| llvm::Twine(computation_or.status().error_message())); |
| return failure(); |
| } |
| *result = std::move(computation_or.ValueOrDie()); |
| return success(); |
| } |
| |
| LogicalResult ConvertToHloModule::LowerRegionAsComputation( |
| mlir::Region* region, xla::XlaComputation* func, |
| llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands, |
| bool ensure_single_arg) { |
| std::unique_ptr<xla::XlaBuilder> builder = |
| module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++)); |
| return LowerBasicBlockAsFunction(®ion->front(), builder.get(), |
| /*is_entry_function=*/false, |
| /*ensure_single_arg*/ ensure_single_arg, |
| /*entry_args_same_across_replicas=*/{}, |
| /*arg_shardings=*/{}, /*ret_shardings=*/{}, |
| func, implicit_operands); |
| } |
| |
| void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding, |
| int arg_index, int32_t shape_index, |
| int32_t padding_arg_index, |
| bool use_tuple_args) { |
| auto* entry = binding->add_entries(); |
| entry->set_target_param_dim_num(shape_index); |
| if (use_tuple_args) { |
| entry->set_target_param_num(0); |
| entry->add_target_param_index(arg_index); |
| entry->set_dynamic_param_num(0); |
| entry->add_dynamic_param_index(padding_arg_index); |
| } else { |
| entry->set_target_param_num(arg_index); |
| entry->set_dynamic_param_num(padding_arg_index); |
| } |
| } |
| |
| // Runs the PrepareForExport pass on the ModuleOp. |
| Status PrepareForExport(mlir::ModuleOp module) { |
| // Prepare for export to XLA HLO. |
| mlir::PassManager pm(module.getContext()); |
| pm.addNestedPass<mlir::func::FuncOp>(mhlo::CreatePrepareForExport()); |
| if (failed(pm.run(module))) |
| return tensorflow::errors::Internal("Unable to optimize for XLA export"); |
| return ::tensorflow::OkStatus(); |
| } |
| |
| } // namespace |
| |
| Status ConvertRegionToComputation(mlir::Region* region, |
| xla::XlaComputation* func, |
| MlirToHloConversionOptions options) { |
| mlir::ModuleOp module; |
| xla::XlaBuilder module_builder("main"); |
| ConvertToHloModule converter(module, module_builder, true, true, {}, options); |
| if (failed(converter.LowerRegionAsComputation(region, func))) |
| return tensorflow::errors::Internal( |
| "failed to convert region to computation"); |
| return ::tensorflow::OkStatus(); |
| } |
| |
| Status ConvertMlirHloToHlo( |
| mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args, |
| bool return_tuple, |
| const tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns |
| shape_determination_fns, |
| MlirToHloConversionOptions options) { |
| TF_RETURN_IF_ERROR(PrepareForExport(module)); |
| mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); |
| xla::XlaBuilder module_builder("main"); |
| ConvertToHloModule converter(module, module_builder, use_tuple_args, |
| return_tuple, shape_determination_fns, options); |
| if (failed(converter.Run())) return diag_handler.ConsumeStatus(); |
| auto hlo_module = converter.ConsumeMainProto(); |
| StringRef module_name = module.getName() ? *module.getName() : "main"; |
| hlo_module.set_name(module_name.str()); |
| hlo_proto->mutable_hlo_module()->Swap(&hlo_module); |
| return ::tensorflow::OkStatus(); |
| } |
| |
| Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, |
| llvm::ArrayRef<xla::XlaOp> xla_params, |
| std::vector<xla::XlaOp>& returns, |
| MlirToHloConversionOptions options) { |
| auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>(); |
| TF_RETURN_IF_ERROR(PrepareForExport(module)); |
| ConvertToHloModule converter(module, builder, |
| /*use_tuple_args=*/false, /*return_tuple=*/false, |
| /*shape_determination_fns=*/{}, options); |
| |
| ConvertToHloModule::ValueLoweringMap lowering; |
| // xla_params should only include non-constant parameters the block arguments |
| // correspond to. |
| if (xla_params.size() != block.getArguments().size()) |
| return tensorflow::errors::Internal("xla_params size (", xla_params.size(), |
| ") != block arguments size (", |
| block.getArguments().size(), ")"); |
| for (BlockArgument& arg : block.getArguments()) { |
| auto num = arg.getArgNumber(); |
| lowering[arg] = xla_params[num]; |
| } |
| |
| mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); |
| for (auto& inst : block) { |
| if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) { |
| returns.resize(inst.getNumOperands()); |
| for (OpOperand& ret : inst.getOpOperands()) { |
| unsigned index = ret.getOperandNumber(); |
| xla::XlaOp operand; |
| if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst))) |
| return diag_handler.ConsumeStatus(); |
| returns[index] = operand; |
| } |
| } else { |
| xla::XlaOp return_value; |
| if (failed(converter.Lower(&inst, /*is_entry_function=*/true, |
| /*ret_shardings=*/{}, &builder, &lowering, |
| &return_value))) |
| return diag_handler.ConsumeStatus(); |
| } |
| } |
| |
| return ::tensorflow::OkStatus(); |
| } |
| |
| } // namespace mlir |