For ops with `hasCustomHLOConverter` set in tablegen, generate a call to
`ExportXlaOp(op, ...)`.

Also fill in lowerings for xla_hlo.reduce and xla_hlo.broadcast_in_dim as
examples.
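
For illustration, the generated dispatch in operator_writers.inc now
looks roughly like this (a sketch, with most cases elided):

  mlir::LogicalResult ExportXlaOperator(
      mlir::Operation* op, OpLoweringContext lowering_context) {
    if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::ReduceOp>(op)) {
      return mlir::xla_hlo::ExportXlaOp(xla_op, lowering_context);
    }
    // ... one case per HLO_Op; the body is auto-generated unless
    // hasCustomHLOConverter is set on the op definition ...
    return mlir::failure();
  }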

PiperOrigin-RevId: 274641888
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 43d3779..45bf845 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -73,6 +73,10 @@
   return ConvertDenseIntAttr(permutation);
 }
 
+static std::vector<int64> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
+  return {values.begin(), values.end()};
+}
+
 // Converts the precision config array of strings attribute into the
 // corresponding XLA proto. All the strings are assumed to be valid names of the
 // Precision enum. This should have been checked in the op verify method.
@@ -170,6 +174,211 @@
   return value.convertToDouble();
 }
 
+namespace mlir {
+namespace {
+class ConvertToHloModule {
+ public:
+  using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
+  using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
+
+  explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
+                              bool always_return_tuple)
+      : module_(module),
+        module_builder_("main"),
+        use_tuple_args_(use_tuple_args),
+        always_return_tuple_(always_return_tuple) {}
+
+  // Perform the lowering to XLA. This function returns failure if an error was
+  // encountered.
+  LogicalResult Run() {
+    for (auto func : module_.getOps<FuncOp>()) {
+      if (func.empty()) continue;
+      if (failed(RunOnFunction(func))) return failure();
+    }
+    return success();
+  }
+
+  // Lower a specific function to HLO.
+  LogicalResult RunOnFunction(mlir::FuncOp f);
+
+  // Lower a `mlir::Region` to an `xla::XlaComputation`.
+  LogicalResult LowerRegionAsComputation(mlir::Region* region,
+                                         xla::XlaComputation* func);
+
+  // Lower a single `Block` to an `xla::XlaComputation`.
+  LogicalResult LowerBasicBlockAsFunction(Block* block,
+                                          xla::XlaBuilder* builder,
+                                          xla::XlaComputation* result);
+
+  ::xla::HloModuleProto ConsumeMainProto() {
+    return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
+        .proto();
+  }
+
+ private:
+  LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
+                      ConvertToHloModule::ValueLoweringMap* value_lowering,
+                      xla::XlaComputation* result);
+
+  // The module being lowered.
+  mlir::ModuleOp module_;
+
+  // The top-level XlaBuilder.
+  xla::XlaBuilder module_builder_;
+
+  // Map between function and lowered computation.
+  FunctionLoweringMap lowered_computation_;
+
+  // Whether the entry function should take a single tuple as input.
+  bool use_tuple_args_;
+
+  // Whether to always return a tuple.
+  bool always_return_tuple_;
+
+  // Unique suffix to give to the name of the next lowered region.
+  size_t region_id_ = 0;
+};
+
+}  // namespace
+}  // namespace mlir
+
+namespace {
+
+struct OpLoweringContext {
+  llvm::DenseMap<mlir::Value*, xla::XlaOp>* values;
+  mlir::ConvertToHloModule* converter;
+  xla::XlaBuilder* builder;
+};
+
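+// Look up the lowered XlaOp for each MLIR value in `values`. Operands are
+// expected to have been lowered already, as lowering proceeds in program
+// order.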
+llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
+                                          OpLoweringContext ctx) {
+  llvm::SmallVector<xla::XlaOp, 4> ops;
+  for (mlir::Value* value : values) {
+    ops.push_back((*ctx.values)[value]);
+  }
+  return ops;
+}
+
+}  // namespace
+
+namespace mlir {
+namespace xla_hlo {
+namespace {
+
+LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
+  auto type = op.getType().dyn_cast<RankedTensorType>();
+  if (!type) return failure();
+  auto& value_map = *ctx.values;
+  value_map[op] =
+      BroadcastInDim(value_map[op.operand()], Convert_ArrayRef(type.getShape()),
+                     Convert_broadcast_dimensions(op.broadcast_dimensions()));
+  return success();
+}
+
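+// The remaining exporters are placeholders: these ops also set
+// `hasCustomHLOConverter`, but their lowerings have not been filled in yet,
+// so they fail for now.
+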
+LogicalResult ExportXlaOp(ConcatenateOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(ConvOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(CopyOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(DynamicSliceOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(DynamicUpdateSliceOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(GetTupleElementOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) { return failure(); }
+
+LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
+  auto& value_map = *ctx.values;
+  xla::XlaComputation body;
+  if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
+    return failure();
+  }
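+  // The `dimensions` attribute has the same form as `broadcast_dimensions`
+  // (a dense i64 elements attribute), so the same converter is reused.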
+  xla::XlaOp result =
+      xla::Reduce(ctx.builder, GetTuple(op.operands(), ctx),
+                  GetTuple(op.init_values(), ctx), body,
+                  Convert_broadcast_dimensions(op.dimensions()));
+  if (op.getNumResults() == 1) {
+    value_map[op.getResult(0)] = result;
+  } else {
+    for (auto item : llvm::enumerate(op.getResults())) {
+      value_map[item.value()] = xla::GetTupleElement(result, item.index());
+    }
+  }
+  return success();
+}
+
+LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
+  // Fail on purpose: `xla_hlo::ReturnOp` is handled by special-purpose
+  // logic in `ConvertToHloModule::Lower`.
+  return failure();
+}
+
+LogicalResult ExportXlaOp(ReverseOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
+  return failure();
+}
+
+}  // namespace
+}  // namespace xla_hlo
+}  // namespace mlir
+
 #include "tensorflow/compiler/mlir/xla/operator_writers.inc"
 
 namespace mlir {
@@ -205,61 +414,13 @@
 #undef ELEMENTS_ATTR_TO_LITERAL
 }
 
-class ConvertToHloModule {
- public:
-  using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
-  using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
-
-  explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
-                              bool always_return_tuple)
-      : module_(module),
-        module_builder_("main"),
-        use_tuple_args_(use_tuple_args),
-        always_return_tuple_(always_return_tuple) {}
-
-  // Perform the lowering to XLA. This function returns failure if an error was
-  // encountered.
-  LogicalResult Run() {
-    for (auto func : module_.getOps<FuncOp>()) {
-      if (func.empty()) continue;
-      if (failed(RunOnFunction(func))) return failure();
-    }
-    return success();
-  }
-
-  // Perform the lowering on a specific function. This function returns failure
-  // if an error was encountered.
-  LogicalResult RunOnFunction(mlir::FuncOp f);
-
-  ::xla::HloModuleProto ConsumeMainProto() {
-    return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
-        .proto();
-  }
-
- private:
-  LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
-                      ConvertToHloModule::ValueLoweringMap* value_lowering);
-
-  // The module being lowered.
-  mlir::ModuleOp module_;
-
-  // The top-level XlaBuilder.
-  xla::XlaBuilder module_builder_;
-
-  // Map between function and lowered computation.
-  FunctionLoweringMap lowered_computation_;
-
-  // Whether the entry function should take a single tuple as input.
-  bool use_tuple_args_;
-
-  // Whether to always return a tuple.
-  bool always_return_tuple_;
-};
-
 LogicalResult ConvertToHloModule::Lower(
     mlir::Operation* inst, xla::XlaBuilder* builder,
-    ConvertToHloModule::ValueLoweringMap* value_lowering) {
-  if (succeeded(ExportXlaOperator(inst, value_lowering))) return success();
+    ConvertToHloModule::ValueLoweringMap* value_lowering,
+    xla::XlaComputation* result) {
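+  // First try the exporters generated from the op definitions; ops marked
+  // `hasCustomHLOConverter` are dispatched to their ExportXlaOp overload.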
+  if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
+    return success();
+  }
 
   auto& value_map = *value_lowering;
   ElementsAttr const_attr;
@@ -273,19 +434,19 @@
     return success();
   }
 
-  if (auto ret = dyn_cast<mlir::ReturnOp>(inst)) {
+  if (isa<xla_hlo::ReturnOp>(inst) || isa<mlir::ReturnOp>(inst)) {
     // Construct the return value for the function. If there are multiple
     // values returned, then create a tuple, else return value directly.
     xla::XlaOp return_value;
-    unsigned num_return_values = ret.getNumOperands();
+    unsigned num_return_values = inst->getNumOperands();
     if (always_return_tuple_ || num_return_values > 1) {
       std::vector<xla::XlaOp> returns(num_return_values);
-      for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) {
-        returns[i] = value_map[ret.getOperand(i)];
+      for (unsigned i = 0, e = inst->getNumOperands(); i != e; ++i) {
+        returns[i] = value_map[inst->getOperand(i)];
       }
       return_value = xla::Tuple(builder, returns);
     } else if (num_return_values == 1) {
-      return_value = value_map[ret.getOperand(0)];
+      return_value = value_map[inst->getOperand(0)];
     }
 
     // Build the XlaComputation and check for failures.
@@ -295,10 +456,10 @@
       inst->emitError(llvm::Twine(computation_or.status().error_message()));
       return failure();
     }
-    auto f = inst->getParentOfType<mlir::FuncOp>();
-    lowered_computation_[f] = std::move(computation_or.ValueOrDie());
+    *result = std::move(computation_or.ValueOrDie());
     return success();
   }
+
   inst->emitError("unable to lower operation of type '" +
                   inst->getName().getStringRef().str() + '\'');
   return failure();
@@ -316,20 +477,30 @@
     builder_up = module_builder_.CreateSubBuilder(f.getName().str());
   auto& builder = entry_function ? module_builder_ : *builder_up;
 
+  xla::XlaComputation computation;
+  if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, &computation))) {
+    return failure();
+  }
+  lowered_computation_[f] = std::move(computation);
+  return success();
+}
+
+LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
+    Block* block, xla::XlaBuilder* builder, xla::XlaComputation* result) {
+  auto& bb = *block;
   // Mapping from the Value to lowered XlaOp. The code below lowers in
   // program order and will fail if an operand is unseen. This can be improved.
   ValueLoweringMap lowering;
-  auto& bb = f.front();
 
-  // If using tuples as input, then there is only one input
-  // parameter that is a tuple.
+  // If using tuples as input, then there is only one input parameter that is a
+  // tuple.
   if (use_tuple_args_) {
     std::vector<xla::Shape> arg_shapes;
     arg_shapes.reserve(bb.getNumArguments());
     for (auto& arg : bb.getArguments())
       arg_shapes.push_back(xla::TypeToShape(arg->getType()));
     xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
-    auto tuple = xla::Parameter(&builder, 0, input_shape, "arg_tuple");
+    auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple");
     for (auto& it : llvm::enumerate(bb.getArguments())) {
       lowering[it.value()] = xla::GetTupleElement(tuple, it.index());
     }
@@ -339,16 +510,23 @@
       auto num = it.index();
       xla::Shape shape = xla::TypeToShape(arg->getType());
       lowering[arg] =
-          xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
+          xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
     }
   }
 
   for (auto& inst : bb)
-    if (failed(Lower(&inst, &builder, &lowering))) return failure();
+    if (failed(Lower(&inst, builder, &lowering, result))) return failure();
 
   return success();
 }
 
+LogicalResult ConvertToHloModule::LowerRegionAsComputation(
+    mlir::Region* region, xla::XlaComputation* func) {
+  std::unique_ptr<xla::XlaBuilder> builder =
+      module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
+  return LowerBasicBlockAsFunction(&region->front(), builder.get(), func);
+}
+
 }  // namespace
 
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
index 00b7cd0..4a9555a 100644
--- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
+++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
@@ -49,8 +49,8 @@
 
 static void BuildOperator(const Operator& op, raw_ostream* output) {
   auto& os = *output;
-  os << "  auto& value_map = *value_lowering;\n"
-     << "  auto result = xla_op.getResult();\n";
+  os << "    auto& value_map = *lowering_context.values;\n"
+     << "    auto result = xla_op.getResult();\n";
 
   // Build a conversion for each of the arguments.
   int operand_number = 0;
@@ -61,36 +61,39 @@
     if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
       // Handle a non-variadic operand.
       if (!operand_cst->isVariadic()) {
-        os << "auto xla_arg_" << index << " = value_map[*xla_op.getODSOperands("
-           << operand_number++ << ").begin()];\n";
+        os << "    auto xla_arg_" << index
+           << " = value_map[*xla_op.getODSOperands(" << operand_number++
+           << ").begin()];\n";
         continue;
       }
 
       // Otherwise, this is a variadic operand list.
-      os << "  std::vector<xla::XlaOp> xla_arg_" << index << ";"
-         << "  for (auto operand : xla_op.getODSOperands(" << operand_number++
-         << "))\n    xla_arg_" << index << ".push_back(value_map[operand]);\n";
+      os << "    std::vector<xla::XlaOp> xla_arg_" << index << ";"
+         << "    for (auto operand : xla_op.getODSOperands(" << operand_number++
+         << "))\n      xla_arg_" << index
+         << ".push_back(value_map[operand]);\n";
       continue;
     }
 
     // Otherwise, this is an attribute.
     auto named_attr = arg.get<NamedAttribute*>();
-    os << "auto xla_arg_" << index << " = " << GetDefaultAttrExport(*named_attr)
-       << "(xla_op." << op.getArgName(index) << "());\n";
+    os << "    auto xla_arg_" << index << " = "
+       << GetDefaultAttrExport(*named_attr) << "(xla_op."
+       << op.getArgName(index) << "());\n";
   }
 
   // Assumes that the client builder method names closely follow the op names
   // in the dialect, e.g. AddOp maps to the xla::Add method.
   StringRef op_name = op.getCppClassName();
-  os << "  auto xla_result = xla::" << op_name.drop_back(2) << "(";
+  os << "    auto xla_result = xla::" << op_name.drop_back(2) << "(";
 
   // Emit each of the arguments.
   interleaveComma(llvm::seq<int>(0, op.getNumArgs()), os,
                   [&](int i) { os << "Unwrap(xla_arg_" << i << ')'; });
   os << ");\n";
 
-  os << "  value_map[result] = xla_result;\n";
-  os << "  return mlir::success();\n";
+  os << "    value_map[result] = xla_result;\n";
+  os << "    return mlir::success();\n";
 }
 
 // The function below has a non-constant reference as that is required by LLVM's
@@ -102,20 +105,23 @@
   // Emit a function to generate an XLA operation for the operations with
   // auto-generated builders.
   os << "mlir::LogicalResult ExportXlaOperator(\n"
-        "mlir::Operation* op, llvm::DenseMap<mlir::Value*, xla::XlaOp> "
-        "*value_lowering) {\n";
+        "mlir::Operation* op, OpLoweringContext lowering_context) {\n";
 
   // Retrieve all the definitions derived from HLO_Op and sort by record name.
   for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
-    // Skip operations that have a custom exporter.
-    if (def->getValueAsBit("hasCustomHLOConverter")) continue;
+    // Operations with a custom exporter are no longer skipped; they are
+    // dispatched to mlir::xla_hlo::ExportXlaOp below.
     Operator op(def);
 
     // Cast to the current operation and build the exporter.
     os << "  if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
        << op.getCppClassName() << ">(op)) {\n";
-    BuildOperator(op, &os);
-    os << "}\n";
+    if (def->getValueAsBit("hasCustomHLOConverter")) {
+      os << "    return mlir::xla_hlo::ExportXlaOp(xla_op, "
+            "lowering_context);\n";
+    } else {
+      BuildOperator(op, &os);
+    }
+    os << "  }\n";
   }
 
   os << "  return mlir::failure();\n"
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir
new file mode 100644
index 0000000..ac53ba9
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.mlir
@@ -0,0 +1,12 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
+  %result = "xla_hlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<1xf32>) -> tensor<1x10xf32>
+  return %result : tensor<1x10xf32>
+}
+
+// CHECK: ENTRY %main.3 ([[ARG0:.*]]: f32[1]) -> f32[1,10] {
+// CHECK:  %[[ARG0]] = f32[1] parameter(0)
+// CHECK:  ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] %[[ARG0]]), dimensions={0}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir b/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir
new file mode 100644
index 0000000..db16a22
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/reduce.mlir
@@ -0,0 +1,24 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
+  %result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
+    ^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>):   // no predecessors
+      %fmax = "xla_hlo.max"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      %imax = "xla_hlo.max"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      "xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
+    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
+  return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
+}
+
+// CHECK: %[[REGION:region_[0-9]+]]
+// CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[])
+// CHECK:  %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]])
+// CHECK:  %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]])
+// CHECK:  ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]])
+
+// CHECK: ENTRY %main
+// CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG1:.*]]: s32[1,10], [[ARG2:.*]]: f32[], [[ARG3:.*]]: s32[]) -> (f32[1], s32[1])
+// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %Arg_0.1, s32[1,10] %Arg_1.2, f32[] %Arg_2.3, s32[] %Arg_3.4), dimensions={1}, to_apply=%[[REGION]]
+// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0
+// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1
+// CHECK: ROOT %[[RESULT_TUPLE:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]])