Add XLA lowering pattern for TensorFlow Softmax op

PiperOrigin-RevId: 267740332
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 1b020c0..3b94f7b 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -216,6 +216,64 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Softmax op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @simple_softmax
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
+func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: %[[NEG_INF:.*]] = constant dense<0xFF800000> : tensor<f32>
+  // CHECK: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<f32>
+
+  // Verify reduce op for max computation and its body.
+  // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[ARG0]], %[[NEG_INF]])
+  // CHECK:  xla_hlo.max
+  // CHECK: "xla_hlo.return"
+  // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
+
+  // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
+  // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]])
+
+  // Verify reduce op for summation and its body.
+  // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[EXP]], %[[ZERO]])
+  // CHECK:  xla_hlo.add
+  // CHECK: "xla_hlo.return"
+  // CHECK: {dimensions = dense<1> : tensor<1xi64>}
+
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
+  // CHECK: return %[[RESULT]]
+
+  %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %0: tensor<2x3xf32>
+}
+
+// CHECK-LABEL: func @bf16_softmax
+func @bf16_softmax(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> {
+  // Verify that conversion to f32 and then back to bf16 are introduced.
+
+  // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2x3xbf16>) -> tensor<2x3xf32>
+  // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2xf32>) -> tensor<2xbf16>
+
+  %0 = "tf.Softmax"(%arg0) : (tensor<2x3xbf16>) -> tensor<2x3xbf16>
+  return %0: tensor<2x3xbf16>
+}
+
+// CHECK-LABEL: func @rank4_softmax
+func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> {
+  // Verify that reduce op dimensions and broadcast dimensions are correct.
+
+  // CHECK: "xla_hlo.reduce"
+  // CHECK: dimensions = dense<3>
+
+  // CHECK: "xla_hlo.reduce"
+  // CHECK: dimensions = dense<3>
+
+  // CHECK: "xla_hlo.div"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
+  %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16>
+  return %0: tensor<2x3x4x5xf16>
+}
+
+//===----------------------------------------------------------------------===//
 // Unary op legalizations.
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 1e2b403..b715b52 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -19,8 +19,11 @@
 
 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "mlir/IR/Operation.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
@@ -168,6 +171,34 @@
       b.getTensorType({minRank}, b.getIntegerType(64)), broadcastDimensions);
 }
 
+//===----------------------------------------------------------------------===//
+// Softmax op utilities.
+//===----------------------------------------------------------------------===//
+
+// Builds a rank-1 i64 elements attribute holding the consecutive integers in
+// the half-open range [start, end).
+static ElementsAttr GetI64AttrForSeq(int start, int end, Builder *builder) {
+  // Fill the sequence start, start + 1, ..., end - 1.
+  const int num_elements = end - start;
+  SmallVector<int64_t, 4> sequence(num_elements);
+  std::iota(sequence.begin(), sequence.end(), start);
+
+  TensorType attr_type =
+      builder->getTensorType({num_elements}, builder->getIntegerType(64));
+  return DenseIntElementsAttr::get<int64_t>(attr_type, sequence);
+}
+
+// Returns the element type in which values of type `ty` are accumulated: f16
+// and bf16 widen to f32 to limit precision loss from repeated additions.
+static Type GetAccumulationType(Type ty) {
+  if (!ty.isF16() && !ty.isBF16()) return ty;
+  return FloatType::getF32(ty.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Op converters.
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 namespace xla {
 namespace {
@@ -217,6 +248,101 @@
   }
 };
 
+// Converts Softmax op to HLO ops computing softmax with the following formula:
+//
+//     softmax = div(exp(logits), sum(exp(logits)))
+//
+// Sample result with 2-d f16 inputs with B batches of N elements each.
+//
+//    // Subtract each element by their batches' max to improve numerical
+//    // stability.
+//    %neg_infinity = constant dense<0xFC00> : tensor<f16>
+//    %max = "xla_hlo.reduce"(%input, %neg_infinity) ["xla_hlo.max"]
+//             {dimensions = 1}
+//           : (tensor<BxNxf16>, tensor<f16>) -> tensor<Bxf16>
+//    %sub = "xla_hlo.sub"(%input, %max) {broadcast_dimensions = 0}
+//            : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
+//
+//    %exp = "xla_hlo.exp"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
+//
+//    // Cast to f32 to avoid precision loss in summation.
+//    %exp_f32 = "xla_hlo.convert"(%exp) : (tensor<BxNxf16>) -> tensor<BxNxf32>
+//    %zero = constant dense<0.000000e+00> : tensor<f32>
+//    %sum = "xla_hlo.reduce"(%exp_f32, %zero) ["xla_hlo.add"] {dimensions = 1}
+//            : (tensor<BxNxf32>, tensor<f32>) -> tensor<Bxf32>
+//
+//    %sum_f16 = "xla_hlo.convert"(%sum) : (tensor<Bxf32>) -> tensor<Bxf16>
+//    %softmax = "xla_hlo.div"(%exp, %sum_f16) {broadcast_dimensions = 0}
+//            : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
+class ConvertSoftmaxOp : public OpRewritePattern<TF::SoftmaxOp> {
+ public:
+  explicit ConvertSoftmaxOp(MLIRContext *context)
+      : OpRewritePattern<TF::SoftmaxOp>(context, 1) {}
+
+  PatternMatchResult matchAndRewrite(TF::SoftmaxOp op,
+                                     PatternRewriter &rewriter) const override {
+    Value *logits = op.logits();
+
+    // Softmax converter requires ranked type because the XLA reduce ops used
+    // while lowering requires dimensions attribute to reduce along.
+    RankedTensorType type = logits->getType().dyn_cast<RankedTensorType>();
+    if (!type) return matchFailure();
+    int rank = type.getRank();
+
+    // The TensorFlow Softmax op verifies that the input rank is greater than
+    // or equal to one, so both of the following sequences are valid.
+    ElementsAttr batch_dims = GetI64AttrForSeq(0, rank - 1, &rewriter);
+    ElementsAttr reduce_dim = GetI64AttrForSeq(rank - 1, rank, &rewriter);
+    Location loc = op.getLoc();
+
+    // Exponential of input values and then their sum can be very large here.
+    // Division with large denominator is numerically unstable. To improve
+    // numerical stability, subtract each batch with their max element so that
+    // the maximum input value is zero. It can be shown that softmax computed
+    // after adding or subtracting all inputs in a batch using a common value
+    // gives mathematically equivalent result.
+    Type element_type = type.getElementType();
+    ArrayRef<int64_t> reduce_shape = type.getShape().drop_back();
+    RankedTensorType reduce_out_type =
+        rewriter.getTensorType(reduce_shape, element_type);
+    auto init = GetMinValueForType(element_type, loc, &rewriter);
+    auto max_logits = rewriter.create<xla_hlo::ReduceOp>(
+        loc, reduce_out_type, ArrayRef<Value *>{logits, init}, reduce_dim);
+    BuildReduceBody<xla_hlo::MaxOp>(element_type, &max_logits.body(),
+                                    &rewriter);
+    auto shifted_logits = rewriter.create<xla_hlo::SubOp>(
+        loc, type, logits, max_logits.getResult(0), batch_dims);
+
+    // Exponentiate the inputs.
+    Value *exp = rewriter.create<xla_hlo::ExpOp>(loc, type, shifted_logits);
+
+    // Cast the exponentials to the appropriate accumulation type to avoid
+    // precision loss during summation.
+    Type sum_element_type = GetAccumulationType(element_type);
+    Type sum_type = rewriter.getTensorType(type.getShape(), sum_element_type);
+    auto casted_exp = rewriter.create<xla_hlo::ConvertOp>(loc, sum_type, exp);
+
+    // Compute summation of the exponentials in the accumulation type.
+    init = rewriter.create<ConstantOp>(
+        loc,
+        DenseElementsAttr::get(rewriter.getTensorType({}, sum_element_type),
+                               rewriter.getZeroAttr(sum_element_type)));
+    Type sum_out_type = rewriter.getTensorType(reduce_shape, sum_element_type);
+    auto exp_sum = rewriter.create<xla_hlo::ReduceOp>(
+        loc, sum_out_type, ArrayRef<Value *>{casted_exp, init}, reduce_dim);
+    BuildReduceBody<xla_hlo::AddOp>(sum_element_type, &exp_sum.body(),
+                                    &rewriter);
+    Value *sum = exp_sum.getResult(0);
+
+    // Convert the summation result back to the original element type and divide
+    // exponentials by the summations.
+    sum = rewriter.create<xla_hlo::ConvertOp>(loc, reduce_out_type, sum);
+    rewriter.replaceOpWithNewOp<xla_hlo::DivOp>(op.getOperation(), op.getType(),
+                                                exp, sum, batch_dims);
+    return matchSuccess();
+  }
+};
+
 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
 }  // end anonymous namespace
 }  // end namespace xla
@@ -227,6 +353,7 @@
   OwningRewritePatternList patterns;
   xla::populateWithGenerated(op->getContext(), &patterns);
   patterns.insert<mlir::xla::ConvertMaxPoolOp>(op->getContext());
+  patterns.insert<mlir::xla::ConvertSoftmaxOp>(op->getContext());
 
   // Recursively applies rewrite patterns to nested operations.
   applyPatternsGreedily(op, patterns);