For ops with `hasCustomHLOConverter` set in tablegen, generate a call to
`ExportXlaOp(op, ...)`.
Also fill in lowerings for xla_hlo.reduce and xla_hlo.broadcast_in_dim as
examples.
PiperOrigin-RevId: 274641888
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 43d3779..45bf845 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -73,6 +73,10 @@
return ConvertDenseIntAttr(permutation);
}
+static std::vector<int64> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
+ return {values.begin(), values.end()};
+}
+
// Converts the precision config array of strings attribute into the
// corresponding XLA proto. All the strings are assumed to be valid names of the
// Precision enum. This should have been checked in the op verify method.
@@ -170,6 +174,211 @@
return value.convertToDouble();
}
+namespace mlir {
+namespace {
+class ConvertToHloModule {
+ public:
+ using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
+ using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
+
+ explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
+ bool always_return_tuple)
+ : module_(module),
+ module_builder_("main"),
+ use_tuple_args_(use_tuple_args),
+ always_return_tuple_(always_return_tuple) {}
+
+ // Perform the lowering to XLA. This function returns failure if an error was
+ // encountered.
+ LogicalResult Run() {
+ for (auto func : module_.getOps<FuncOp>()) {
+ if (func.empty()) continue;
+ if (failed(RunOnFunction(func))) return failure();
+ }
+ return success();
+ }
+
+ // Lower a specific function to HLO.
+ LogicalResult RunOnFunction(mlir::FuncOp f);
+
+ // Lower a `mlir::Region` to a `XlaComputation`
+ LogicalResult LowerRegionAsComputation(mlir::Region* region,
+ xla::XlaComputation* func);
+
+ // Lower a single `Block` to a `XlaComputation`
+ LogicalResult LowerBasicBlockAsFunction(Block* block,
+ xla::XlaBuilder* builder,
+ xla::XlaComputation* result);
+
+ ::xla::HloModuleProto ConsumeMainProto() {
+ return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
+ .proto();
+ }
+
+ private:
+ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
+ ConvertToHloModule::ValueLoweringMap* value_lowering,
+ xla::XlaComputation* result);
+
+ // The module being lowered.
+ mlir::ModuleOp module_;
+
+ // The top-level XlaBuilder.
+ xla::XlaBuilder module_builder_;
+
+ // Map between function and lowered computation.
+ FunctionLoweringMap lowered_computation_;
+
+ // Whether the entry function should take a single tuple as input.
+ bool use_tuple_args_;
+
+ // Whether to always return a tuple.
+ bool always_return_tuple_;
+
+ // Unique suffix to give to the name of the next lowered region.
+ size_t region_id_ = 0;
+};
+
+} // namespace
+} // namespace mlir
+
+namespace {
+
+struct OpLoweringContext {
+ llvm::DenseMap<mlir::Value*, xla::XlaOp>* values;
+ mlir::ConvertToHloModule* converter;
+ xla::XlaBuilder* builder;
+};
+
+llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
+ OpLoweringContext ctx) {
+ llvm::SmallVector<xla::XlaOp, 4> ops;
+ for (mlir::Value* value : values) {
+ ops.push_back((*ctx.values)[value]);
+ }
+ return ops;
+}
+
+} // namespace
+
+namespace mlir {
+namespace xla_hlo {
+namespace {
+
+LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
+ auto type = op.getType().dyn_cast<RankedTensorType>();
+ if (!type) return failure();
+ auto& value_map = *ctx.values;
+ value_map[op] =
+ BroadcastInDim(value_map[op.operand()], Convert_ArrayRef(type.getShape()),
+ Convert_broadcast_dimensions(op.broadcast_dimensions()));
+ return success();
+}
+
+LogicalResult ExportXlaOp(ConcatenateOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(ConvOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(CopyOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(DynamicSliceOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(DynamicUpdateSliceOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(GetTupleElementOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) { return failure(); }
+
+LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
+ auto& value_map = *ctx.values;
+ xla::XlaComputation body;
+ if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
+ return failure();
+ }
+ xla::XlaOp result =
+ xla::Reduce(ctx.builder, GetTuple(op.operands(), ctx),
+ GetTuple(op.init_values(), ctx), body,
+ Convert_broadcast_dimensions(op.dimensions()));
+ if (op.getNumResults() == 1) {
+ value_map[op.getResult(0)] = result;
+ } else {
+ for (auto item : llvm::enumerate(op.getResults())) {
+ value_map[item.value()] = xla::GetTupleElement(result, item.index());
+ }
+ }
+ return success();
+}
+
+LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
+ // Failure on purpose because `xla_hlo::ReturnOp` will be handled by
+ // special purpose logic in `ConvertToHloModule::Lower`.
+ return failure();
+}
+
+LogicalResult ExportXlaOp(ReverseOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
+ return failure();
+}
+
+} // namespace
+} // namespace xla_hlo
+} // namespace mlir
+
#include "tensorflow/compiler/mlir/xla/operator_writers.inc"
namespace mlir {
@@ -205,61 +414,13 @@
#undef ELEMENTS_ATTR_TO_LITERAL
}
-class ConvertToHloModule {
- public:
- using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
- using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
-
- explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
- bool always_return_tuple)
- : module_(module),
- module_builder_("main"),
- use_tuple_args_(use_tuple_args),
- always_return_tuple_(always_return_tuple) {}
-
- // Perform the lowering to XLA. This function returns failure if an error was
- // encountered.
- LogicalResult Run() {
- for (auto func : module_.getOps<FuncOp>()) {
- if (func.empty()) continue;
- if (failed(RunOnFunction(func))) return failure();
- }
- return success();
- }
-
- // Perform the lowering on a specific function. This function returns failure
- // if an error was encountered.
- LogicalResult RunOnFunction(mlir::FuncOp f);
-
- ::xla::HloModuleProto ConsumeMainProto() {
- return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
- .proto();
- }
-
- private:
- LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
- ConvertToHloModule::ValueLoweringMap* value_lowering);
-
- // The module being lowered.
- mlir::ModuleOp module_;
-
- // The top-level XlaBuilder.
- xla::XlaBuilder module_builder_;
-
- // Map between function and lowered computation.
- FunctionLoweringMap lowered_computation_;
-
- // Whether the entry function should take a single tuple as input.
- bool use_tuple_args_;
-
- // Whether to always return a tuple.
- bool always_return_tuple_;
-};
-
LogicalResult ConvertToHloModule::Lower(
mlir::Operation* inst, xla::XlaBuilder* builder,
- ConvertToHloModule::ValueLoweringMap* value_lowering) {
- if (succeeded(ExportXlaOperator(inst, value_lowering))) return success();
+ ConvertToHloModule::ValueLoweringMap* value_lowering,
+ xla::XlaComputation* result) {
+ if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
+ return success();
+ }
auto& value_map = *value_lowering;
ElementsAttr const_attr;
@@ -273,19 +434,19 @@
return success();
}
- if (auto ret = dyn_cast<mlir::ReturnOp>(inst)) {
+ if (isa<xla_hlo::ReturnOp>(inst) || isa<mlir::ReturnOp>(inst)) {
// Construct the return value for the function. If there are multiple
// values returned, then create a tuple, else return value directly.
xla::XlaOp return_value;
- unsigned num_return_values = ret.getNumOperands();
+ unsigned num_return_values = inst->getNumOperands();
if (always_return_tuple_ || num_return_values > 1) {
std::vector<xla::XlaOp> returns(num_return_values);
- for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) {
- returns[i] = value_map[ret.getOperand(i)];
+ for (unsigned i = 0, e = inst->getNumOperands(); i != e; ++i) {
+ returns[i] = value_map[inst->getOperand(i)];
}
return_value = xla::Tuple(builder, returns);
} else if (num_return_values == 1) {
- return_value = value_map[ret.getOperand(0)];
+ return_value = value_map[inst->getOperand(0)];
}
// Build the XlaComputation and check for failures.
@@ -295,10 +456,10 @@
inst->emitError(llvm::Twine(computation_or.status().error_message()));
return failure();
}
- auto f = inst->getParentOfType<mlir::FuncOp>();
- lowered_computation_[f] = std::move(computation_or.ValueOrDie());
+ *result = std::move(computation_or.ValueOrDie());
return success();
}
+
inst->emitError("unable to lower operation of type '" +
inst->getName().getStringRef().str() + '\'');
return failure();
@@ -316,20 +477,30 @@
builder_up = module_builder_.CreateSubBuilder(f.getName().str());
auto& builder = entry_function ? module_builder_ : *builder_up;
+ xla::XlaComputation computation;
+ if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, &computation))) {
+ return failure();
+ }
+ lowered_computation_[f] = std::move(computation);
+ return success();
+}
+
+LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
+ Block* block, xla::XlaBuilder* builder, xla::XlaComputation* result) {
+ auto& bb = *block;
// Mapping from the Value to lowered XlaOp. The code below lowers in
// program order and will fail if an operand is unseen. This can be improved.
ValueLoweringMap lowering;
- auto& bb = f.front();
- // If using tuples as input, then there is only one input
- // parameter that is a tuple.
+ // If using tuples as input, then there is only one input parameter that is a
+ // tuple.
if (use_tuple_args_) {
std::vector<xla::Shape> arg_shapes;
arg_shapes.reserve(bb.getNumArguments());
for (auto& arg : bb.getArguments())
arg_shapes.push_back(xla::TypeToShape(arg->getType()));
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
- auto tuple = xla::Parameter(&builder, 0, input_shape, "arg_tuple");
+ auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple");
for (auto& it : llvm::enumerate(bb.getArguments())) {
lowering[it.value()] = xla::GetTupleElement(tuple, it.index());
}
@@ -339,16 +510,23 @@
auto num = it.index();
xla::Shape shape = xla::TypeToShape(arg->getType());
lowering[arg] =
- xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
+ xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
}
}
for (auto& inst : bb)
- if (failed(Lower(&inst, &builder, &lowering))) return failure();
+ if (failed(Lower(&inst, builder, &lowering, result))) return failure();
return success();
}
+LogicalResult ConvertToHloModule::LowerRegionAsComputation(
+ mlir::Region* region, xla::XlaComputation* func) {
+ std::unique_ptr<xla::XlaBuilder> builder =
+ module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
+ return LowerBasicBlockAsFunction(&region->front(), builder.get(), func);
+}
+
} // namespace
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
index 00b7cd0..4a9555a 100644
--- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
+++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
@@ -49,8 +49,8 @@
static void BuildOperator(const Operator& op, raw_ostream* output) {
auto& os = *output;
- os << " auto& value_map = *value_lowering;\n"
- << " auto result = xla_op.getResult();\n";
+ os << " auto& value_map = *lowering_context.values;\n"
+ << " auto result = xla_op.getResult();\n";
// Build a conversion for each of the arguments.
int operand_number = 0;
@@ -61,36 +61,39 @@
if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
// Handle a non-variadic operand.
if (!operand_cst->isVariadic()) {
- os << "auto xla_arg_" << index << " = value_map[*xla_op.getODSOperands("
- << operand_number++ << ").begin()];\n";
+ os << " auto xla_arg_" << index
+ << " = value_map[*xla_op.getODSOperands(" << operand_number++
+ << ").begin()];\n";
continue;
}
// Otherwise, this is a varidiac operand list.
- os << " std::vector<xla::XlaOp> xla_arg_" << index << ";"
- << " for (auto operand : xla_op.getODSOperands(" << operand_number++
- << "))\n xla_arg_" << index << ".push_back(value_map[operand]);\n";
+ os << " std::vector<xla::XlaOp> xla_arg_" << index << ";"
+ << " for (auto operand : xla_op.getODSOperands(" << operand_number++
+ << "))\n xla_arg_" << index
+ << ".push_back(value_map[operand]);\n";
continue;
}
// Otherwise, this is an attribute.
auto named_attr = arg.get<NamedAttribute*>();
- os << "auto xla_arg_" << index << " = " << GetDefaultAttrExport(*named_attr)
- << "(xla_op." << op.getArgName(index) << "());\n";
+ os << " auto xla_arg_" << index << " = "
+ << GetDefaultAttrExport(*named_attr) << "(xla_op."
+ << op.getArgName(index) << "());\n";
}
// Assumes that the client builder method names closely follow the op names
// in the dialect. For e.g., AddOp -> xla::Add method.
StringRef op_name = op.getCppClassName();
- os << " auto xla_result = xla::" << op_name.drop_back(2) << "(";
+ os << " auto xla_result = xla::" << op_name.drop_back(2) << "(";
// Emit each of the arguments.
interleaveComma(llvm::seq<int>(0, op.getNumArgs()), os,
[&](int i) { os << "Unwrap(xla_arg_" << i << ')'; });
os << ");\n";
- os << " value_map[result] = xla_result;\n";
- os << " return mlir::success();\n";
+ os << " value_map[result] = xla_result;\n";
+ os << " return mlir::success();\n";
}
// The function below has a non-constant reference as that is required by LLVM's
@@ -102,20 +105,23 @@
// Emit a function to generate an XLA operation for the operations with
// auto-generated builders.
os << "mlir::LogicalResult ExportXlaOperator(\n"
- "mlir::Operation* op, llvm::DenseMap<mlir::Value*, xla::XlaOp> "
- "*value_lowering) {\n";
+ "mlir::Operation* op, OpLoweringContext lowering_context) {\n";
// Retrieve all the definitions derived from HLO_Op and sort by record name.
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
// Skip operations that have a custom exporter.
- if (def->getValueAsBit("hasCustomHLOConverter")) continue;
Operator op(def);
// Cast to the current operation and build the exporter.
os << " if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
<< op.getCppClassName() << ">(op)) {\n";
- BuildOperator(op, &os);
- os << "}\n";
+ if (def->getValueAsBit("hasCustomHLOConverter")) {
+ os << " return mlir::xla_hlo::ExportXlaOp(xla_op, "
+ "lowering_context);\n";
+ } else {
+ BuildOperator(op, &os);
+ }
+ os << " }\n";
}
os << " return mlir::failure();\n"
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir
new file mode 100644
index 0000000..ac53ba9
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir
@@ -0,0 +1,12 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
+ %result = "xla_hlo.broadcast_in_dim"(%arg0) {
+ broadcast_dimensions = dense<0> : tensor<1xi64>
+ } : (tensor<1xf32>) -> tensor<1x10xf32>
+ return %result : tensor<1x10xf32>
+}
+
+// CHECK: ENTRY %main.3 ([[ARG0:.*]]: f32[1]) -> f32[1,10] {
+// CHECK: %[[ARG0]] = f32[1] parameter(0)
+// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] %[[ARG0]]), dimensions={0}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir
new file mode 100644
index 0000000..db16a22
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir
@@ -0,0 +1,24 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
+ %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
+ ^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
+ %fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ "xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
+ }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
+ return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
+}
+
+// CHECK: %[[REGION:region_[0-9]+]]
+// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
+// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
+// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
+// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
+
+// CHECK: ENTRY %main
+// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG0:.*]]: s32[1,10], [[ARG0:.*]]: f32[], [[ARG0:.*]]: s32[]) -> (f32[1], s32[1])
+// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %Arg_0.1, s32[1,10] %Arg_1.2, f32[] %Arg_2.3, s32[] %Arg_3.4), dimensions={1}, to_apply=%[[REGION]]
+// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
+// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
+// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])