/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include <memory>
#include <string>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using ::stream_executor::port::StatusOr;
using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int64;
using ::tensorflow::int8;
using ::tensorflow::uint16;
using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;
constexpr char kPaddingMapAttr[] = "mhlo.padding_map";
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";
// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
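// For example, Convert_precision_config below returns a
// std::unique_ptr<xla::PrecisionConfig>, so the ConvOp export wraps the call
// as Unwrap(Convert_precision_config(op.precision_config())) to pass the raw
// pointer to xla::ConvGeneralDilated while the unique_ptr keeps ownership for
// the duration of the call.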
template <typename T>
T Unwrap(T t) {
return t;
}
template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
return t.get();
}
static mlir::LogicalResult GetXlaOp(
mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map,
xla::XlaOp* result, mlir::Operation* op) {
auto iter = val_map.find(val);
if (iter == val_map.end()) {
return op->emitOpError(
"requires all operands to be defined in the parent region for export");
}
*result = iter->second;
return mlir::success();
}
// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
static uint32_t Convertuint32_t(uint32_t i) { return i; }
static uint64_t Convertuint64_t(uint64_t i) { return i; }
// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
const auto& semantics = value.getSemantics();
bool losesInfo = false;
if (&semantics != &llvm::APFloat::IEEEdouble())
value.convert(llvm::APFloat::IEEEdouble(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return value.convertToDouble();
}
static inline bool Convertbool(bool value) { return value; }
static absl::string_view ConvertStringRef(mlir::StringRef value) {
return {value.data(), value.size()};
}
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
auto values = attr.getValues<int64>();
return {values.begin(), values.end()};
}
static std::vector<int64> ConvertDenseIntAttr(
llvm::Optional<mlir::DenseIntElementsAttr> attr) {
if (!attr) return {};
return ConvertDenseIntAttr(*attr);
}
// Converts the broadcast_dimensions attribute into a vector of dimension
// numbers (empty if the attribute is absent).
static std::vector<int64> Convert_broadcast_dimensions(
llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) {
if (!broadcast_dimensions.hasValue()) return {};
return ConvertDenseIntAttr(*broadcast_dimensions);
}
// Converts StringRef to xla FftType enum
static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
xla::FftType fft_type_enum;
// An illegal fft_type string would be caught by the verifier, so the
// 'FftType_Parse' call below should never return false.
if (!FftType_Parse(std::string(fft_type_str), &fft_type_enum))
return xla::FftType::FFT;
return fft_type_enum;
}
// Converts an (N, 2) dense attribute to a list of pairs. This is how padding
// and source-target pairs are defined in HLO.
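// For example, a dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> attribute is
// converted to the pairs (0, 1) and (2, 3).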
static std::vector<std::pair<int64, int64>> Convert_Nx2_attribute(
llvm::Optional<mlir::DenseIntElementsAttr> optional_attr) {
if (!optional_attr.hasValue()) return {};
mlir::DenseIntElementsAttr attr = *optional_attr;
auto it = attr.getValues<int64>().begin();
std::vector<std::pair<int64, int64>> out(attr.getNumElements() / 2);
for (auto& item : out) {
int64 first = *it;
++it;
int64 second = *it;
++it;
item = {first, second};
}
return out;
}
static std::vector<std::pair<int64, int64>> Convert_padding(
llvm::Optional<mlir::DenseIntElementsAttr> padding) {
return Convert_Nx2_attribute(padding);
}
static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
return Convert_Nx2_attribute(source_target_pairs);
}
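// Converts a dense integer attribute of shape (num_groups, group_size) into a
// list of xla::ReplicaGroup protos, one per row. For example,
// dense<[[0, 2], [1, 3]]> : tensor<2x2xi64> yields the groups {0, 2} and
// {1, 3}.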
static std::vector<xla::ReplicaGroup> Convert_replica_groups(
mlir::DenseIntElementsAttr groups) {
uint64_t num_groups = groups.getType().getDimSize(0);
uint64_t group_size = groups.getType().getDimSize(1);
std::vector<xla::ReplicaGroup> result;
result.reserve(num_groups);
for (uint64_t i = 0; i < num_groups; ++i) {
xla::ReplicaGroup group;
for (uint64_t j = 0; j < group_size; ++j) {
group.add_replica_ids(groups.getValue<int64_t>({i, j}));
}
result.push_back(group);
}
return result;
}
// Converts StringRef to xla Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
llvm::StringRef transpose_str) {
xla::TriangularSolveOptions::Transpose transpose_enum;
// An illegal transpose string would be caught by the verifier, so the
// 'Transpose_Parse' call below should never return false.
if (!xla::TriangularSolveOptions::Transpose_Parse(std::string(transpose_str),
&transpose_enum))
return xla::TriangularSolveOptions::NO_TRANSPOSE;
return transpose_enum;
}
#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \
static std::vector<int64> Convert_##attribute( \
llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
return ConvertDenseIntAttr(attribute); \
}
I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(strides);
I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);
#undef I64_ELEMENTS_ATTR_TO_VECTOR
static std::vector<int64> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
return {values.begin(), values.end()};
}
// Converts the precision config array of strings attribute into the
// corresponding XLA proto. All the strings are assumed to be valid names of the
// Precision enum. This should have been checked in the op verify method.
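// For example, ["HIGHEST", "DEFAULT"] yields a PrecisionConfig proto whose
// operand_precision list is {HIGHEST, DEFAULT}.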
static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
if (!optional_precision_config_attr.hasValue()) return nullptr;
auto precision_config = absl::make_unique<xla::PrecisionConfig>();
for (auto attr : optional_precision_config_attr.getValue()) {
xla::PrecisionConfig::Precision p;
auto operand_precision = attr.cast<mlir::StringAttr>().getValue().str();
// TODO(jpienaar): Update this to ensure this is captured by verify.
if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
precision_config->add_operand_precision(p);
} else {
auto* context = attr.getContext();
mlir::emitError(mlir::UnknownLoc::get(context))
<< "unexpected operand precision " << operand_precision;
return nullptr;
}
}
return precision_config;
}
static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) {
xla::DotDimensionNumbers dot_dimension_numbers;
auto rhs_contracting_dimensions =
dot_dimension_numbers_attr.rhs_contracting_dimensions()
.cast<mlir::DenseIntElementsAttr>();
auto lhs_contracting_dimensions =
dot_dimension_numbers_attr.lhs_contracting_dimensions()
.cast<mlir::DenseIntElementsAttr>();
auto rhs_batch_dimensions =
dot_dimension_numbers_attr.rhs_batching_dimensions()
.cast<mlir::DenseIntElementsAttr>();
auto lhs_batch_dimensions =
dot_dimension_numbers_attr.lhs_batching_dimensions()
.cast<mlir::DenseIntElementsAttr>();
for (const auto& val : rhs_contracting_dimensions) {
dot_dimension_numbers.add_rhs_contracting_dimensions(val.getSExtValue());
}
for (const auto& val : lhs_contracting_dimensions) {
dot_dimension_numbers.add_lhs_contracting_dimensions(val.getSExtValue());
}
for (const auto& val : rhs_batch_dimensions) {
dot_dimension_numbers.add_rhs_batch_dimensions(val.getSExtValue());
}
for (const auto& val : lhs_batch_dimensions) {
dot_dimension_numbers.add_lhs_batch_dimensions(val.getSExtValue());
}
return dot_dimension_numbers;
}
static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
mlir::mhlo::ConvDimensionNumbers input) {
return xla::ConvertConvDimensionNumbers(input);
}
xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
xla::ChannelHandle channel_handle;
channel_handle.set_handle(ConvertAPInt(attr.handle().getValue()));
channel_handle.set_type(static_cast<xla::ChannelHandle::ChannelType>(
ConvertAPInt(attr.type().getValue())));
return channel_handle;
}
// Converts the comparison_direction string attribute into the XLA enum. The
// string is assumed to correspond to exactly one of the allowed strings
// representing the enum. This should have been checked in the op verify method.
static xla::ComparisonDirection Convert_comparison_direction(
llvm::StringRef comparison_direction_string) {
return xla::StringToComparisonDirection(comparison_direction_string.str())
.ValueOrDie();
}
static xla::GatherDimensionNumbers Convert_dimension_numbers(
mlir::mhlo::GatherDimensionNumbers input) {
xla::GatherDimensionNumbers output;
auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
std::copy(offset_dims.begin(), offset_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_offset_dims()));
auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims());
std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_collapsed_slice_dims()));
auto start_index_map = ConvertDenseIntAttr(input.start_index_map());
std::copy(start_index_map.begin(), start_index_map.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_start_index_map()));
output.set_index_vector_dim(
ConvertAPInt(input.index_vector_dim().getValue()));
return output;
}
static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
mlir::mhlo::ScatterDimensionNumbers input) {
xla::ScatterDimensionNumbers output;
auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
std::copy(update_window_dims.begin(), update_window_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_update_window_dims()));
auto inserted_window_dims = ConvertDenseIntAttr(input.inserted_window_dims());
std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_inserted_window_dims()));
auto scatter_dims_to_operand_dims =
ConvertDenseIntAttr(input.scatter_dims_to_operand_dims());
std::copy(scatter_dims_to_operand_dims.begin(),
scatter_dims_to_operand_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_scatter_dims_to_operand_dims()));
output.set_index_vector_dim(
ConvertAPInt(input.index_vector_dim().getValue()));
return output;
}
// Extracts sharding from attribute string.
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
llvm::StringRef sharding) {
xla::OpSharding sharding_proto;
if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
return sharding_proto;
}
// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns absl::nullopt.
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
mlir::Operation* op) {
auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
if (!sharding) return absl::nullopt;
return CreateOpShardingFromStringRef(sharding.getValue());
}
// Returns a FrontendAttributes proto from the "frontend_attributes" attribute
// of the op. An empty FrontendAttributes proto is returned if an op does not
// have frontend attributes.
static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
mlir::Operation* op) {
xla::FrontendAttributes frontend_attributes;
auto frontend_attributes_dict =
op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
if (!frontend_attributes_dict) return frontend_attributes;
for (const auto& attr : frontend_attributes_dict)
if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
frontend_attributes.mutable_map()->insert(
{attr.first.str(), value_str_attr.getValue().str()});
return frontend_attributes;
}
// Returns an OpMetadata proto based on the location of the op. If the location
// is unknown, an empty proto is returned. `op_name` is populated with the op
// location (converted). For FileLineColLoc locations, `source_file` and
// `source_line` are populated with the file name and line number respectively.
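// For example, an op located at "model.py":42:7 gets source_file "model.py"
// and source_line 42.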
static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) {
xla::OpMetadata metadata;
if (op->getLoc().isa<mlir::UnknownLoc>()) return metadata;
std::string name = mlir::GetNameFromLoc(op->getLoc());
mlir::LegalizeNodeName(name);
metadata.set_op_name(name);
if (auto file_line_col_loc = op->getLoc().dyn_cast<mlir::FileLineColLoc>()) {
metadata.set_source_file(file_line_col_loc.getFilename().str());
metadata.set_source_line(file_line_col_loc.getLine());
}
return metadata;
}
// Checks if all shardings are set.
static bool AllOptionalShardingsAreSet(
llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
return llvm::all_of(shardings,
[](const absl::optional<xla::OpSharding>& sharding) {
return sharding.has_value();
});
}
// Extracts argument and result shardings from function.
static void ExtractShardingsFromFunction(
mlir::FuncOp function,
llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
arg_shardings->resize(function.getNumArguments(),
absl::optional<xla::OpSharding>());
for (int i = 0, end = function.getNumArguments(); i < end; ++i)
if (auto sharding =
function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
(*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
ret_shardings->resize(function.getNumResults(),
absl::optional<xla::OpSharding>());
for (int i = 0, end = function.getNumResults(); i < end; ++i)
if (auto sharding =
function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
(*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
}
namespace mlir {
namespace {
class ConvertToHloModule {
public:
using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
// If use_tuple_args is true, then the entry function's arguments are
// converted to a tuple and passed as a single parameter.
// Similarly, if return_tuple is true, then the entry function's return values
// are converted to a tuple even when there is only a single return value.
// Multiple return values are always converted to a tuple and returned as a
// single value.
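// For example, with use_tuple_args and return_tuple both true, a main
// function of type (tensor<2xf32>, tensor<i32>) -> tensor<2xf32> is exported
// as a computation taking a single (f32[2], s32[]) tuple parameter and
// returning a (f32[2]) tuple.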
explicit ConvertToHloModule(
mlir::ModuleOp module, xla::XlaBuilder& module_builder,
bool use_tuple_args, bool return_tuple,
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
MlirToHloConversionOptions options)
: module_(module),
module_builder_(module_builder),
use_tuple_args_(use_tuple_args),
return_tuple_(return_tuple),
shape_representation_fn_(shape_representation_fn),
options_(options) {
if (!shape_representation_fn_)
shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
}
// Perform the lowering to XLA. This function returns failure if an error was
// encountered.
//
// TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
LogicalResult Run() {
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
if (!main)
return module_.emitError(
"conversion requires module with `main` function");
for (auto func : module_.getOps<FuncOp>()) {
if (func.empty()) continue;
if (failed(RunOnFunction(func))) return failure();
}
return success();
}
// Lower a specific function to HLO.
LogicalResult RunOnFunction(mlir::FuncOp f);
// Lower a `mlir::Region` to an `XlaComputation`.
LogicalResult LowerRegionAsComputation(mlir::Region* region,
xla::XlaComputation* func);
// Lower a single `Block` to an `XlaComputation`.
LogicalResult LowerBasicBlockAsFunction(
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
const std::vector<bool>& entry_args_same_across_replicas,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaComputation* result);
::xla::HloModuleProto ConsumeMainProto() {
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
// This is an invariant check as Run returns failure if there is no main
// function and so the main proto shouldn't be consumed in that case.
CHECK(main) << "requires module to have main function"; // Crash Ok.
return lowered_computation_[main].proto();
}
// Lower a function call to an HLO call instruction.
LogicalResult LowerFunctionCall(
mlir::CallOp call_op, xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering);
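// Lower a single operation within a block. For return ops, `return_value` is
// set to the lowered return value (or tuple of values); for all other ops it
// is left as an invalid XlaOp.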
LogicalResult Lower(
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaOp* return_value);
private:
LogicalResult SetEntryTupleShapesAndLeafReplication(
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
std::vector<bool>* leaf_replication);
LogicalResult SetEntryTupleShardings(
Block* block, xla::XlaBuilder* builder,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes);
// The module being lowered.
mlir::ModuleOp module_;
// The top-level XlaBuilder.
xla::XlaBuilder& module_builder_;
// Map between function and lowered computation.
FunctionLoweringMap lowered_computation_;
// Whether the entry function should take a single tuple as input.
bool use_tuple_args_;
// Whether to always return a tuple.
bool return_tuple_;
// Shape representation function to determine entry function argument and
// result shapes.
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
// Unique suffix to give to the name of the next lowered region.
size_t region_id_ = 0;
MlirToHloConversionOptions options_;
};
} // namespace
} // namespace mlir
namespace {
struct OpLoweringContext {
llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
mlir::ConvertToHloModule* converter;
xla::XlaBuilder* builder;
};
llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
OpLoweringContext ctx) {
llvm::SmallVector<xla::XlaOp, 4> ops;
for (mlir::Value value : values) {
ops.push_back((*ctx.values)[value]);
}
return ops;
}
} // namespace
namespace mlir {
namespace mhlo {
namespace {
LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation computation;
if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
&computation))) {
return failure();
}
auto replica_groups = Convert_replica_groups(op.replica_groups());
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
if (!op.channel_id().hasValue()) {
value_map[op] = xla::AllReduce(operand, computation, replica_groups,
/*channel_id=*/absl::nullopt);
return success();
}
auto channel_id = Convert_channel_handle(op.channel_id().getValue());
value_map[op] =
xla::AllReduce(operand, computation, replica_groups, channel_id);
return success();
}
LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
value_map[op] = xla::BitcastConvertType(
operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
return success();
}
LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type) return failure();
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
value_map[op] =
BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
Convert_broadcast_dimensions(op.broadcast_dimensions()));
return success();
}
LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) {
// This op has no expression in the legacy export format.
return failure();
}
LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) {
// This op has no expression in the legacy export format.
return failure();
}
LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) {
// This op has no expression in the legacy export format.
return failure();
}
LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
xla::XlaComputation true_branch;
xla::XlaComputation false_branch;
auto& value_map = *ctx.values;
if (failed(ctx.converter->LowerRegionAsComputation(&op.true_branch(),
&true_branch)) ||
failed(ctx.converter->LowerRegionAsComputation(&op.false_branch(),
&false_branch))) {
return failure();
}
xla::XlaOp pred, true_arg, false_arg;
if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure();
if (failed(GetXlaOp(op.true_arg(), value_map, &true_arg, op)))
return failure();
if (failed(GetXlaOp(op.false_arg(), value_map, &false_arg, op)))
return failure();
value_map[op] =
xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch);
return success();
}
LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
OperandRange operands = op.branch_operands();
MutableArrayRef<Region> branches = op.branches();
llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
std::vector<xla::XlaComputation> computations(branches.size());
std::vector<xla::XlaComputation*> computations_p(branches.size());
for (unsigned i = 0; i < branches.size(); ++i) {
xla::XlaOp operand;
if (failed(GetXlaOp(operands[i], value_map, &operand, op)))
return failure();
branch_operands[i] = operand;
computations_p[i] = &computations[i];
if (failed(ctx.converter->LowerRegionAsComputation(&branches[i],
computations_p[i])))
return failure();
}
xla::XlaOp index;
if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure();
xla::XlaOp result = xla::Conditional(index, computations_p, branch_operands);
if (op.getNumResults() == 1) {
value_map[op.getResult(0)] = result;
} else {
for (auto item : llvm::enumerate(op.getResults())) {
value_map[item.value()] = xla::GetTupleElement(result, item.index());
}
}
return success();
}
// Specialize CompareOp export to set broadcast_dimensions argument.
mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp lhs, rhs;
if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
auto dir = Convert_comparison_direction(op.comparison_direction());
auto type_attr = op.compare_typeAttr();
xla::XlaOp xla_result;
if (type_attr) {
auto type =
xla::StringToComparisonType(type_attr.getValue().str()).ValueOrDie();
xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
} else {
xla_result = xla::Compare(lhs, rhs, dir);
}
value_map[op] = xla_result;
return mlir::success();
}
LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
return failure();
}
LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
// The XLA client builder API does not support generating convolution
// instructions with window reversal.
if (op.hasWindowReversal()) return failure();
auto& value_map = *ctx.values;
xla::XlaOp lhs, rhs;
if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
xla::XlaOp xla_result = xla::ConvGeneralDilated(
lhs, rhs, Convert_window_strides(op.window_strides()),
Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
Convert_rhs_dilation(op.rhs_dilation()),
Convert_dimension_numbers(op.dimension_numbers()),
Convertuint64_t(op.feature_group_count()),
Convertuint64_t(op.batch_group_count()),
Unwrap(Convert_precision_config(op.precision_config())));
value_map[op] = xla_result;
return mlir::success();
}
LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
value_map[op] = xla::ConvertElementType(
operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
return success();
}
LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
// The XLA client builder API does not support generating custom call
// instructions with side effects.
if (op.has_side_effect() || op.getNumResults() != 1) return failure();
Value result = op.getResult(0);
auto& value_map = *ctx.values;
value_map[result] = xla::CustomCall(
ctx.builder, std::string(op.call_target_name()), GetTuple(op.args(), ctx),
xla::TypeToShape(result.getType()), std::string(op.backend_config()));
return success();
}
LogicalResult ExportXlaOp(DequantizeOp op, OpLoweringContext ctx) {
xla::QuantizedRange range(ConvertAPFloat(op.min_range()),
ConvertAPFloat(op.max_range()));
auto& value_map = *ctx.values;
xla::XlaOp input;
if (failed(GetXlaOp(op.input(), value_map, &input, op))) return failure();
auto casted = xla::ConvertElementType(input, xla::U32);
if (op.is_16bits()) {
value_map[op] = xla::Dequantize<uint16>(
casted, range, ConvertStringRef(op.mode()), op.transpose_output());
} else {
value_map[op] = xla::Dequantize<uint8>(
casted, range, ConvertStringRef(op.mode()), op.transpose_output());
}
return success();
}
LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp token;
if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
// The shape argument expected by the xla client API is the type of the first
// element in the result tuple.
auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
value_map[op] = xla::InfeedWithToken(token, xla::TypeToShape(result_type),
std::string(op.infeed_config()));
return success();
}
LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
op.iota_dimension());
return success();
}
LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation computation;
if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
&computation))) {
return failure();
}
value_map[op] = xla::Map(ctx.builder, GetTuple(op.operands(), ctx),
computation, Convert_dimensions(op.dimensions()));
return success();
}
LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand, token;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
value_map[op] = xla::OutfeedWithToken(
operand, token, xla::TypeToShape(op.operand().getType()),
std::string(op.outfeed_config()));
return success();
}
LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::PaddingConfig padding_config;
auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) {
auto* dims = padding_config.add_dimensions();
dims->set_edge_padding_low(edge_padding_low[i]);
dims->set_edge_padding_high(edge_padding_high[i]);
dims->set_interior_padding(interior_padding[i]);
}
xla::XlaOp operand, padding_value;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op)))
return failure();
value_map[op] = xla::Pad(operand, padding_value, padding_config);
return success();
}
LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
xla::XlaOp token;
if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
if (op.is_host_transfer()) {
value_map[op] = xla::RecvFromHost(token, xla::TypeToShape(result_type),
Convert_channel_handle(op.channel_id()));
return success();
}
value_map[op] = xla::RecvWithToken(token, xla::TypeToShape(result_type),
Convert_channel_handle(op.channel_id()));
return success();
}
LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation body;
if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
return failure();
}
xla::XlaOp result =
xla::Reduce(ctx.builder, GetTuple(op.operands(), ctx),
GetTuple(op.init_values(), ctx), body,
Convert_broadcast_dimensions(op.dimensions()));
if (op.getNumResults() == 1) {
value_map[op.getResult(0)] = result;
} else {
for (auto item : llvm::enumerate(op.getResults())) {
value_map[item.value()] = xla::GetTupleElement(result, item.index());
}
}
return success();
}
LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation body;
if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
return failure();
}
xla::XlaOp operand, init_value;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
return failure();
value_map[op] = xla::ReduceWindowWithGeneralPadding(
operand, init_value, body, ConvertDenseIntAttr(op.window_dimensions()),
ConvertDenseIntAttr(op.window_strides()),
ConvertDenseIntAttr(op.base_dilations()),
ConvertDenseIntAttr(op.window_dilations()),
Convert_padding(op.padding()));
return success();
}
LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
value_map[op] =
xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions());
return success();
}
LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
// Failure on purpose because `mhlo::ReturnOp` will be handled by
// special purpose logic in `ConvertToHloModule::Lower`.
return failure();
}
LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
auto result = op.getResult();
auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
auto xla_result = xla::RngBitGenerator(
static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
xla::TypeToShape(result.getType()).tuple_shapes(1));
value_map[result] = xla_result;
return mlir::success();
}
LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp mu, sigma;
if (failed(GetXlaOp(op.mu(), value_map, &mu, op))) return failure();
if (failed(GetXlaOp(op.sigma(), value_map, &sigma, op))) return failure();
value_map[op] = xla::RngNormal(mu, sigma, xla::TypeToShape(op.getType()));
return success();
}
LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp a, b;
if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure();
if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure();
value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType()));
return success();
}
LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation update_computation;
if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
&update_computation))) {
return failure();
}
xla::ScatterDimensionNumbers dimension_numbers =
Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
xla::XlaOp operand, scatter_indices, updates;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op)))
return failure();
if (failed(GetXlaOp(op.updates(), value_map, &updates, op))) return failure();
value_map[op] = xla::Scatter(operand, scatter_indices, updates,
update_computation, dimension_numbers,
op.indices_are_sorted(), op.unique_indices());
return success();
}
LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation select;
xla::XlaComputation scatter;
if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) ||
failed(
ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) {
return failure();
}
xla::XlaOp operand, source, init_value;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure();
if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
return failure();
value_map[op] = xla::SelectAndScatterWithGeneralPadding(
operand, select, ConvertDenseIntAttr(op.window_dimensions()),
ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()),
source, init_value, scatter);
return success();
}
LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand, token;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
if (op.is_host_transfer()) {
value_map[op] = xla::SendToHost(operand, token,
xla::TypeToShape(op.operand().getType()),
Convert_channel_handle(op.channel_id()));
return success();
}
value_map[op] = xla::SendWithToken(operand, token,
Convert_channel_handle(op.channel_id()));
return success();
}
LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
return failure();
}
LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
xla::XlaComputation comparator;
if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
&comparator)))
return failure();
auto sorted = xla::Sort(GetTuple(op.operands(), ctx), comparator,
op.dimension(), op.is_stable());
auto& value_map = *ctx.values;
auto shape_or = sorted.builder()->GetShape(sorted);
if (!shape_or.ok()) {
return op.emitError(shape_or.status().ToString());
}
xla::Shape& shape = shape_or.ValueOrDie();
if (!shape.IsTuple()) {
value_map[op.getResult(0)] = sorted;
return success();
}
// MLIR's sort supports multiple results; untuple all the results of XLA's sort.
for (auto it : llvm::enumerate(op.getResults())) {
value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
}
return success();
}
LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
xla::Trace(std::string(op.tag()), operand);
return success();
}
LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
// Failure is intentional, as UnaryEinsumOp is always lowered to an EinsumOp
// with two operands.
return failure();
}
LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
xla::XlaComputation condition;
xla::XlaComputation body;
auto& value_map = *ctx.values;
if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body)) ||
failed(ctx.converter->LowerRegionAsComputation(&op.cond(), &condition))) {
return failure();
}
xla::XlaOp operand;
if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op)))
return failure();
value_map[op] = xla::While(condition, body, operand);
return success();
}
LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
if (!op.fusion_kind()) {
op.emitOpError() << "requires fusion kind for HLO translation";
return failure();
}
xla::XlaComputation fused_computation;
if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(),
&fused_computation)))
return failure();
auto& values = *ctx.values;
llvm::SmallVector<xla::XlaOp, 4> operands;
for (auto operand : op.operands()) operands.push_back(values[operand]);
xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
ctx.builder, operands,
absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()),
fused_computation);
if (op.getNumResults() == 1) {
values[op.getResult(0)] = fusion;
} else {
for (auto item : llvm::enumerate(op.getResults())) {
values[item.value()] = xla::GetTupleElement(fusion, item.index());
}
}
return success();
}
LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp operand;
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast(
ctx.builder, operand, xla::TypeToShape(op.getType()));
return success();
}
} // namespace
} // namespace mhlo
} // namespace mlir
#include "tensorflow/compiler/mlir/xla/operator_writers.inc"
namespace mlir {
namespace {
StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
xla::Layout layout) {
if (attr.isa<OpaqueElementsAttr>())
return tensorflow::errors::Unimplemented(
"Opaque elements attr not supported");
xla::Shape shape = xla::TypeToShape(attr.getType());
#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
case xla_type: { \
xla::Array<cpp_type> source_data(shape.dimensions()); \
source_data.SetValues(attr.getValues<cpp_type>()); \
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
}
switch (shape.element_type()) {
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16)
default:
return tensorflow::errors::Internal(absl::StrCat(
"Unsupported type: ", xla::PrimitiveType_Name(shape.element_type())));
}
#undef ELEMENTS_ATTR_TO_LITERAL
}
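// Returns the layout for `op` from its "minor_to_major" attribute if present,
// otherwise the default descending layout for the given rank. For example,
// rank 3 without the attribute gives layout {2, 1, 0}.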
xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
if (auto attr = GetLayoutFromMlirHlo(op)) {
llvm::SmallVector<int64, 4> minor_to_major;
DCHECK_EQ(rank, attr.size());
minor_to_major.reserve(attr.size());
for (const llvm::APInt& i : attr) {
minor_to_major.push_back(i.getZExtValue());
}
return xla::LayoutUtil::MakeLayout(minor_to_major);
}
return xla::LayoutUtil::MakeDescendingLayout(rank);
}
LogicalResult ConvertToHloModule::Lower(
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaOp* return_value) {
*return_value = xla::XlaOp();
// See MlirToHloConversionOptions for more about layouts.
auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
if (options_.propagate_layouts) {
auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
->mutable_shape();
if (shape->tuple_shapes().empty())
*shape->mutable_layout() =
ExtractLayout(inst, shape->dimensions().size()).ToProto();
}
};
if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
if (inst->getNumResults() == 1) {
auto iter = value_lowering->find(inst->getResult(0));
if (iter == value_lowering->end()) {
inst->emitOpError(
"inst has a result, but it's not found in value_lowering");
return failure();
}
propagate_layouts(inst, iter->second);
}
return success();
}
auto& value_map = *value_lowering;
ElementsAttr const_attr;
if (auto call_op = dyn_cast<mlir::CallOp>(inst)) {
return LowerFunctionCall(call_op, builder, &value_map);
}
if (auto op = dyn_cast<mlir::TensorCastOp>(inst)) {
Value operand = op.getOperand();
auto ty = operand.getType().dyn_cast<ShapedType>();
// If this was a cast from a statically shaped tensor, then it is a no-op for
// export to HLO and we can use the operand.
if (!ty || !ty.hasStaticShape()) {
inst->emitOpError()
<< "requires static shaped operand for HLO translation";
return failure();
}
xla::XlaOp xla_operand;
if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
return failure();
value_map[op.getResult()] = xla_operand;
propagate_layouts(inst, xla_operand);
return success();
}
if (matchPattern(inst, m_Constant(&const_attr))) {
xla::Layout layout = ExtractLayout(inst, const_attr.getType().getRank());
auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
if (!literal_or.ok())
return inst->emitError(literal_or.status().ToString());
auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
value_map[inst->getResult(0)] = constant;
return success();
}
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
// Construct the return value for the function. If there is a single value
// returned, then return it directly, else create a tuple and return.
unsigned num_return_values = inst->getNumOperands();
if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
const bool has_ret_shardings =
!ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
std::vector<xla::XlaOp> returns(num_return_values);
for (OpOperand& ret : inst->getOpOperands()) {
unsigned index = ret.getOperandNumber();
xla::XlaOp operand;
if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
return failure();
returns[index] = operand;
if (!is_entry_function || !has_ret_shardings) continue;
xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
StatusOr<xla::XlaOp> reshape =
tensorflow::ReshapeWithCorrectRepresentationAndSharding(
builder, returns[index], return_shape, shape_representation_fn_,
ret_shardings[index], /*fast_mem=*/false);
if (!reshape.ok())
return inst->emitError() << reshape.status().error_message();
returns[index] = reshape.ValueOrDie();
}
if (has_ret_shardings) {
xla::OpSharding sharding;
sharding.set_type(xla::OpSharding::TUPLE);
for (auto& ret_sharding : ret_shardings)
*sharding.add_tuple_shardings() = ret_sharding.value();
builder->SetSharding(sharding);
}
*return_value = xla::Tuple(builder, returns);
builder->ClearSharding();
} else if (num_return_values == 1) {
xla::XlaOp operand;
if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
return failure();
*return_value = operand;
}
return success();
}
inst->emitOpError() << "can't be translated to XLA HLO";
return failure();
}
LogicalResult ConvertToHloModule::LowerFunctionCall(
mlir::CallOp call_op, xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering) {
auto& value_map = *value_lowering;
mlir::FuncOp callee = module_.lookupSymbol<mlir::FuncOp>(call_op.callee());
if (failed(RunOnFunction(callee))) return failure();
std::vector<xla::XlaOp> operands;
for (auto operand : call_op.getOperands()) {
xla::XlaOp xla_operand;
if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op)))
return failure();
operands.push_back(xla_operand);
}
// Each call to xla::Call would insert a copy of the computation into
// the HLO. Thus each call site would have a unique callee in the
// exported HLO. HLO syntactically does not require all calls to have unique
// callees, but eventually, before lowering, the call graph is "flattened" to
// make that true. This is done before lowering because buffer assignment
// needs this invariant.
xla::XlaOp call_result =
xla::Call(builder, lowered_computation_[callee], operands);
// Use GetTupleElement for multiple outputs
unsigned num_results = call_op.getNumResults();
if (num_results > 1) {
for (unsigned i = 0; i != num_results; ++i) {
value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i);
}
} else if (num_results == 1) {
value_map[call_op.getResult(0)] = call_result;
}
return success();
}
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
if (lowered_computation_.count(f)) return success();
if (!llvm::hasSingleElement(f)) {
return f.emitError("only single block Function supported");
}
// Create a sub-builder if this is not the main function.
std::unique_ptr<xla::XlaBuilder> builder_up;
bool entry_function = f.getName() == "main";
if (!entry_function)
builder_up = module_builder_.CreateSubBuilder(f.getName().str());
auto& builder = entry_function ? module_builder_ : *builder_up;
xla::XlaComputation computation;
std::vector<bool> entry_args_same_across_replicas;
llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
if (entry_function) {
bool any_arg_replicated = false;
entry_args_same_across_replicas.reserve(f.getNumArguments());
for (int64_t i = 0; i < f.getNumArguments(); ++i) {
auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kReplicationAttr);
entry_args_same_across_replicas.push_back(attr != nullptr);
any_arg_replicated |= entry_args_same_across_replicas.back();
// Pass the alias info to the builder so that it will build the alias info
// into the resulting HloModule.
auto aliasing_output =
f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
if (!aliasing_output) continue;
if (use_tuple_args_) {
builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
/*param_number=*/0, /*param_index=*/{i});
} else {
builder.SetUpAlias(/*output_index=*/{aliasing_output.getInt()},
/*param_number=*/i, /*param_index=*/{});
}
}
// Do not populate this field when nothing is replicated, since an empty field
// means no replication. This avoids the need for unrelated tests to handle
// this field.
if (!any_arg_replicated) entry_args_same_across_replicas.clear();
ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
}
if (failed(LowerBasicBlockAsFunction(
&f.front(), &builder, entry_function, entry_args_same_across_replicas,
arg_shardings, ret_shardings, &computation))) {
return failure();
}
lowered_computation_[f] = std::move(computation);
return success();
}
LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
std::vector<bool>* leaf_replication) {
arg_shapes->reserve(block->getNumArguments());
leaf_replication->reserve(block->getNumArguments());
for (BlockArgument& arg : block->getArguments()) {
arg_shapes->push_back(xla::TypeToShape(arg.getType()));
xla::Shape& arg_shape = arg_shapes->back();
tensorflow::TensorShape arg_tensor_shape;
auto status =
tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
if (!status.ok())
return block->getParentOp()->emitError() << status.error_message();
tensorflow::DataType dtype;
status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
if (!status.ok())
return block->getParentOp()->emitError() << status.error_message();
auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
/*use_fast_memory=*/false);
if (!arg_shape_status.ok())
return block->getParentOp()->emitError()
<< arg_shape_status.status().error_message();
arg_shape = std::move(arg_shape_status.ValueOrDie());
if (entry_args_same_across_replicas.empty()) continue;
for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
leaf_replication->push_back(
entry_args_same_across_replicas[arg.getArgNumber()]);
}
return success();
}
LogicalResult ConvertToHloModule::SetEntryTupleShardings(
Block* block, xla::XlaBuilder* builder,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
xla::OpSharding sharding;
sharding.set_type(xla::OpSharding::TUPLE);
for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
auto hlo_sharding =
xla::HloSharding::FromProto(arg_sharding.value().value());
if (!hlo_sharding.ok())
return block->getParentOp()->emitError()
<< hlo_sharding.status().error_message();
auto status = tensorflow::RewriteLayoutWithShardedShape(
hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
if (!status.ok())
return block->getParentOp()->emitError() << status.error_message();
*sharding.add_tuple_shardings() = arg_sharding.value().value();
}
builder->SetSharding(sharding);
}
return success();
}
LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
const std::vector<bool>& entry_args_same_across_replicas,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaComputation* result) {
// Mapping from the Value to lowered XlaOp.
ValueLoweringMap lowering;
// If using tuples as input, then there is only one input parameter that is a
// tuple.
if (is_entry_function && use_tuple_args_) {
llvm::SmallVector<xla::Shape, 4> arg_shapes;
std::vector<bool> leaf_replication;
if (failed(SetEntryTupleShapesAndLeafReplication(
block, entry_args_same_across_replicas, &arg_shapes,
&leaf_replication)))
return failure();
if (failed(
SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
return failure();
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
auto tuple =
xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
builder->ClearSharding();
for (BlockArgument& arg : block->getArguments())
lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
} else {
for (BlockArgument& arg : block->getArguments()) {
auto num = arg.getArgNumber();
xla::Shape shape = xla::TypeToShape(arg.getType());
if (entry_args_same_across_replicas.empty()) {
lowering[arg] =
xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
} else {
lowering[arg] = xla::Parameter(
builder, num, shape, absl::StrCat("Arg_", num),
std::vector<bool>(xla::ShapeUtil::GetLeafCount(shape),
entry_args_same_across_replicas[num]));
}
}
}
xla::XlaOp return_value;
for (auto& inst : *block)
if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
&lowering, &return_value)))
return failure();
// Build the XlaComputation and check for failures.
auto computation_or =
return_value.valid() ? builder->Build(return_value) : builder->Build();
if (!computation_or.ok()) {
block->back().emitError(
llvm::Twine(computation_or.status().error_message()));
return failure();
}
*result = std::move(computation_or.ValueOrDie());
return success();
}
LogicalResult ConvertToHloModule::LowerRegionAsComputation(
mlir::Region* region, xla::XlaComputation* func) {
std::unique_ptr<xla::XlaBuilder> builder =
module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
return LowerBasicBlockAsFunction(
&region->front(), builder.get(),
/*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
/*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
}
std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) {
return llvm::formatv(
"requires '{0}' array attribute in '{1}' dict at arg {2}",
attr_name, kPaddingMapAttr, index)
.str();
}
std::string PaddingMapMismatchedArraySizeMsg(int arg_index,
int shape_indices_size,
int padding_arg_indices_size) {
return llvm::formatv(
"requires '{0}' and '{1}' array attributes in '{2}' dic at arg "
"{3} to be of the same size, got sizes {4} and {5}",
kShapeIndicesAttr, kPaddingArgIndicesAttr, kPaddingMapAttr,
arg_index, shape_indices_size, padding_arg_indices_size)
.str();
}
std::string PaddingMapBadIntAttrMsg(llvm::StringRef attr_name, int arg_index,
int element_index) {
return llvm::formatv(
"requires element {0} in '{1}' array of '{2}' dict at arg {3} "
"to be an int attribute",
element_index, attr_name, kPaddingMapAttr, arg_index)
.str();
}
std::string PaddingMapBadIndexMsg(llvm::StringRef attr_name, int arg_index,
int element_index, int max, int32_t value) {
return llvm::formatv(
"requires element {0} in '{1}' array of '{2}' dict at arg {3} "
"to be in range [0, {4}), got {5}",
element_index, attr_name, kPaddingMapAttr, arg_index, max, value)
.str();
}
std::string PaddingMapNegativeShapeIndexMsg(int arg_index, int element_index,
int32_t value) {
return llvm::formatv(
"requires element {0} in '{1}' array of '{2}' dict at arg {3} to "
"be non-negative, got {4}",
element_index, kShapeIndicesAttr, kPaddingMapAttr, arg_index,
value)
.str();
}
std::string PaddingMapUniqueShapeIndexMsg(int arg_index, int element_index,
int32_t value) {
return llvm::formatv(
"requires elements in '{0}' array of '{1}' dict at arg {2} to be "
"unique, got duplicate element {3} at index {4}",
kShapeIndicesAttr, kPaddingMapAttr, arg_index, value,
element_index)
.str();
}
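// Appends an entry to `binding` that binds dimension `shape_index` of entry
// argument `arg_index` to the scalar entry argument `padding_arg_index`. For
// example, with use_tuple_args=true, arg_index 1, shape_index 0, and
// padding_arg_index 3, the entry has target_param_num 0, target_param_index
// {1}, target_param_dim_num 0, dynamic_param_num 0, and dynamic_param_index
// {3}.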
void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
int arg_index, int32_t shape_index,
int32_t padding_arg_index,
bool use_tuple_args) {
auto* entry = binding->add_entries();
entry->set_target_param_dim_num(shape_index);
if (use_tuple_args) {
entry->set_target_param_num(0);
entry->add_target_param_index(arg_index);
entry->set_dynamic_param_num(0);
entry->add_dynamic_param_index(padding_arg_index);
} else {
entry->set_target_param_num(arg_index);
entry->set_dynamic_param_num(padding_arg_index);
}
}
// Validates and populates dynamic parameter bindings from the
// `mhlo.padding_map` argument attributes of a module's entry function into the
// `DynamicParameterBindingProto` of an `xla::HloModuleProto`.
LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
xla::HloModuleProto* hlo_module_proto,
bool use_tuple_args) {
auto entry_func = module.lookupSymbol<mlir::FuncOp>("main");
if (!entry_func) return success();
auto* dynamic_parameter_binding =
hlo_module_proto->mutable_dynamic_parameter_binding();
for (int i = 0, e = entry_func.getNumArguments(); i < e; ++i) {
auto padding_map_attr = entry_func.getArgAttr(i, kPaddingMapAttr);
if (!padding_map_attr) continue;
auto padding_map = padding_map_attr.dyn_cast<DictionaryAttr>();
if (!padding_map)
return entry_func.emitError() << "requires '" << kPaddingMapAttr
<< "' dict attribute at arg " << i;
auto shape_indices =
padding_map.get(kShapeIndicesAttr).dyn_cast_or_null<ArrayAttr>();
if (!shape_indices)
return entry_func.emitError(
PaddingMapBadArrayAttrMsg(kShapeIndicesAttr, i));
auto padding_arg_indices =
padding_map.get(kPaddingArgIndicesAttr).dyn_cast_or_null<ArrayAttr>();
if (!padding_arg_indices)
return entry_func.emitError(
PaddingMapBadArrayAttrMsg(kPaddingArgIndicesAttr, i));
if (shape_indices.size() != padding_arg_indices.size())
return entry_func.emitError(PaddingMapMismatchedArraySizeMsg(
i, shape_indices.size(), padding_arg_indices.size()));
llvm::SmallDenseSet<int32_t, 4> used_shape_indices;
auto arg_type =
entry_func.getArgument(i).getType().dyn_cast<RankedTensorType>();
for (auto shape_and_padding : llvm::enumerate(llvm::zip(
shape_indices.getValue(), padding_arg_indices.getValue()))) {
const int element_index = shape_and_padding.index();
auto shape_index_attr =
std::get<0>(shape_and_padding.value()).dyn_cast<IntegerAttr>();
if (!shape_index_attr)
return entry_func.emitError(
PaddingMapBadIntAttrMsg(kShapeIndicesAttr, i, element_index));
auto padding_arg_index_attr =
std::get<1>(shape_and_padding.value()).dyn_cast<IntegerAttr>();
if (!padding_arg_index_attr)
return entry_func.emitError(
PaddingMapBadIntAttrMsg(kPaddingArgIndicesAttr, i, element_index));
const int32_t shape_index = shape_index_attr.getInt();
if (arg_type && (shape_index < 0 || shape_index >= arg_type.getRank()))
return entry_func.emitError(
PaddingMapBadIndexMsg(kShapeIndicesAttr, i, element_index,
arg_type.getRank(), shape_index));
else if (shape_index < 0)
return entry_func.emitError(
PaddingMapNegativeShapeIndexMsg(i, element_index, shape_index));
if (!used_shape_indices.insert(shape_index).second)
return entry_func.emitError(
PaddingMapUniqueShapeIndexMsg(i, element_index, shape_index));
const int32_t padding_arg_index = padding_arg_index_attr.getInt();
if (padding_arg_index < 0 || padding_arg_index >= e)
return entry_func.emitError(PaddingMapBadIndexMsg(
kPaddingArgIndicesAttr, i, element_index, e, padding_arg_index));
Type padding_arg_type =
entry_func.getArgument(padding_arg_index).getType();
if (auto tensor_type = padding_arg_type.dyn_cast<RankedTensorType>())
if (tensor_type.getRank() != 0)
return entry_func.emitError()
<< "requires arg " << padding_arg_index
<< " to be a scalar for use as a dynamic parameter";
if (!mlir::getElementTypeOrSelf(padding_arg_type).isSignlessInteger())
return entry_func.emitError()
<< "requires arg " << padding_arg_index
<< " to be of an int type for use as a dynamic parameter";
AddDynamicParameterBindingEntry(dynamic_parameter_binding, i, shape_index,
padding_arg_index, use_tuple_args);
}
}
return success();
}
} // namespace
Status ConvertRegionToComputation(mlir::Region* region,
xla::XlaComputation* func,
MlirToHloConversionOptions options) {
mlir::ModuleOp module;
xla::XlaBuilder module_builder("main");
ConvertToHloModule converter(module, module_builder, true, true, {}, options);
if (failed(converter.LowerRegionAsComputation(region, func)))
return tensorflow::errors::Internal(
"failed to convert region to computation");
return Status::OK();
}
Status ConvertMlirHloToHlo(
mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
bool return_tuple,
const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
MlirToHloConversionOptions options) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
xla::XlaBuilder module_builder("main");
ConvertToHloModule converter(module, module_builder, use_tuple_args,
return_tuple, shape_representation_fn, options);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
if (failed(AddDynamicParameterBindings(
module, hlo_proto->mutable_hlo_module(), use_tuple_args)))
return diag_handler.ConsumeStatus();
return Status::OK();
}
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
MlirToHloConversionOptions options) {
auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
ConvertToHloModule converter(module, builder,
/*use_tuple_args=*/false, /*return_tuple=*/false,
/*shape_representation_fn=*/nullptr, options);
ConvertToHloModule::ValueLoweringMap lowering;
if (xla_params.size() != block.getArguments().size())
return tensorflow::errors::Internal(
"xla_params size != block arguments size");
for (BlockArgument& arg : block.getArguments()) {
auto num = arg.getArgNumber();
lowering[arg] = xla_params[num];
}
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
for (auto& inst : block) {
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
returns.resize(inst.getNumOperands());
for (OpOperand& ret : inst.getOpOperands()) {
unsigned index = ret.getOperandNumber();
xla::XlaOp operand;
if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
return diag_handler.ConsumeStatus();
returns[index] = operand;
}
} else {
xla::XlaOp return_value;
if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
/*ret_shardings=*/{}, &builder, &lowering,
&return_value)))
return diag_handler.ConsumeStatus();
}
}
return Status::OK();
}
DenseIntElementsAttr GetLayoutFromMlirHlo(mlir::Operation* op) {
return op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
}
} // namespace mlir