Handle negative exponents in the lowering of hlo.pow
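
HLO leaves the semantics of integer `pow` with a negative exponent
undefined, and the previous lowering simply ran the product loop. This
change handles negative exponents explicitly: the result is 0 except
for bases 1 and -1, which stay integral for every exponent. As a
plain-C++ sketch of the intended scalar semantics (the standalone
`ipow` helper is illustrative only, not part of this change):

  #include <cstdint>

  // Reference model for the integer `pow` semantics the lowering
  // below implements.
  int32_t ipow(int32_t lhs, int32_t rhs) {
    if (rhs < 0) {
      if (lhs == 1) return 1;
      // Truncated remainder: even rhs -> 1, odd rhs -> -1.
      if (lhs == -1) return rhs % 2 == 0 ? 1 : -1;
      return 0;  // |lhs| > 1 (or 0): no integral result; return 0.
    }
    int32_t result = 1;  // there is no powi, so multiply in a loop
    for (int32_t i = 0; i < rhs; ++i) result *= lhs;
    return result;
  }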

PiperOrigin-RevId: 352382812
Change-Id: I36c3924a2cc77a49c53ecbb4225f3bedd14e85ba
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
index bc163a6..636cd8c 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h
@@ -556,23 +556,53 @@
   assert(result_type.isa<::mlir::IntegerType>() &&
          "only float and integer `pow` is supported right now");
 
-  // There is no powi, so lower to a simple product. Note that HLO does not
-  // define semantics of negative exponents.
-  Value init = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
+  // There is no powi, so lower to a simple product.
+  Value neg_one =
+      b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, -1));
+  Value zero = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 0));
+  Value one = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
+  Value two = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 2));
 
   Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
   Value upperBound =
       b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
   Value step = b->create<ConstantIndexOp>(loc, 1);
-  return b
-      ->create<scf::ForOp>(
-          loc, lowerBound, upperBound, step, llvm::makeArrayRef(init),
-          [&](OpBuilder& b, Location l, Value v, ValueRange iters) {
-            Value prod =
-                b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
-            b.create<scf::YieldOp>(l, prod);
-          })
-      .getResult(0);
+  Value for_result =
+      b->create<scf::ForOp>(
+           loc, lowerBound, upperBound, step, llvm::makeArrayRef(one),
+           [&](OpBuilder& b, Location l, Value v, ValueRange iters) {
+             Value prod =
+                 b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
+             b.create<scf::YieldOp>(l, prod);
+           })
+          .getResult(0);
+
+  Value rhs_is_even =
+      b->create<CmpIOp>(loc, CmpIPredicate::eq,
+                        b->create<SignedRemIOp>(loc, adaptor.rhs(), two), zero);
+  Value rhs_is_negative =
+      b->create<CmpIOp>(loc, CmpIPredicate::slt, adaptor.rhs(), zero);
+  Value lhs_is_one =
+      b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), one);
+  Value lhs_is_neg_one =
+      b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), neg_one);
+
+  // The for_result is correct when the rhs is non-negative. When rhs is
+  // negative, we return 0 for integers, with the exception of lhs values of
+  // 1 and -1, which have well-defined integer results for negative
+  // exponents. Specifically, the calculation is the following:
+  //
+  // - Return for_result if the rhs is not negative.
+  // - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
+  // - Return 1 if lhs is 1.
+  // - Else return 0.
+  Value if_lhs_is_one = b->create<::mlir::SelectOp>(loc, lhs_is_one, one, zero);
+  Value if_lhs_is_neg_one = b->create<::mlir::SelectOp>(
+      loc, lhs_is_neg_one,
+      b->create<::mlir::SelectOp>(loc, rhs_is_even, one, neg_one),
+      if_lhs_is_one);
+  return b->create<::mlir::SelectOp>(loc, rhs_is_negative, if_lhs_is_neg_one,
+                                     for_result);
 }
 
 template <>
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
index 6f59bf2..298ce7b 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
@@ -792,17 +792,26 @@
 // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @integer_pow
 func @integer_pow(%lhs: tensor<2x2xi32>,
-                %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
                     // CHECK: linalg.generic
   // CHECK: ^{{[a-z0-9_]*}}
   // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
   // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
   // CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
-  // CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
+  // CHECK: %[[FOR_RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
   // CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
   // CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
   //   CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
   //   CHECK: scf.yield %[[ACCUM]]
+  // CHECK: %[[RHS_PARITY:.*]] = remi_signed %[[ARG1]], %c2
+  // CHECK: %[[RHS_EVEN:.*]] = cmpi eq, %[[RHS_PARITY]], %c0
+  // CHECK: %[[RHS_NEG:.*]] = cmpi slt, %[[ARG1]], %c0
+  // CHECK: %[[LHS_ONE:.*]] = cmpi eq, %[[ARG0]], %c1
+  // CHECK: %[[LHS_NEG_ONE:.*]] = cmpi eq, %[[ARG0]], %c-1
+  // CHECK: %[[VAL5:.*]] = select %[[LHS_ONE]], %c1_i32, %c0
+  // CHECK: %[[VAL6:.*]] = select %[[RHS_EVEN]], %c1{{.*}}, %c-1
+  // CHECK: %[[VAL7:.*]] = select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]]
+  // CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]
   // CHECK: linalg.yield %[[RESULT]]
   %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
                                    tensor<2x2xi32>) -> tensor<2x2xi32>
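
A note on the parity check above: remi_signed uses truncated-division
remainder (the same rule as C++ `%`), so for a negative rhs an even
value yields remainder 0 and an odd value yields -1, which means the
`cmpi eq, ..., %c0` test classifies both signs correctly. A standalone
check (illustrative only, not part of the change):

  #include <cassert>

  int main() {
    assert(-4 % 2 == 0);   // even negative exponent -> parity test true
    assert(-3 % 2 == -1);  // odd negative exponent  -> parity test false
    assert(5 % 2 == 1);    // positive odd exponent  -> parity test false
  }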