Adjust LeakyRelu and LeakyReluGrad lowering to work for dynamic shapes.

Previously, the lowering used mhlo::BroadcastOp, which only works for static
shapes. By using chlo::ConstantLikeOp instead, we can support fully dynamic
(including unranked) shapes.

PiperOrigin-RevId: 393084552
Change-Id: Ibc98e4522471b73584342f12c4432cb63bf19d57
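For context, both IR forms appear in the updated test expectations below:
mhlo.broadcast carries a static broadcast_sizes attribute, so the old
lowering could not be emitted for a dynamically shaped input, whereas
chlo.constant_like derives its shape from its operand. A minimal sketch of
the two forms (illustrative only, condensed from the tests):

    // Before: broadcast_sizes is a static attribute, so this form requires
    // a statically shaped input such as tensor<1x4x4x3xf32>.
    %alpha = mhlo.constant dense<2.000000e-01> : tensor<f32>
    %bcast = "mhlo.broadcast"(%alpha) {broadcast_sizes = dense<[1, 4, 4, 3]> : tensor<4xi64>} : (tensor<f32>) -> tensor<1x4x4x3xf32>

    // After: the constant's shape follows its operand, so the same pattern
    // also covers dynamic and unranked inputs such as tensor<*xf32>.
    %alpha2 = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32>
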
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 036a267..0b6c28e 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1733,30 +1733,52 @@
 
 // CHECK-LABEL: func @leaky_relu
 func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} {
-   // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32>
-   // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%[[ALPHA]]) {broadcast_sizes = dense<[1, 4, 4, 3]> : tensor<4xi64>} : (tensor<f32>) -> tensor<1x4x4x3xf32>
-   // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4x3xf32>
-   // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[BCASTALPHA]] : tensor<1x4x4x3xf32>
-   // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
-   // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
-   // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
+    // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+    // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+    // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32>
+    // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
+    // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+    // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
     %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
     return %0 : tensor<1x4x4x3xf32>
 }
 
+// CHECK-LABEL: func @leaky_relu_unranked
+func @leaky_relu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} {
+    // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+    // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+    // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<*xf32>
+    // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
+    // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    // CHECK-NEXT: return %[[RES]] : tensor<*xf32>
+    %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>) -> tensor<*xf32>
+    return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL: func @leaky_relu_grad
 func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} {
-   // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32>
-   // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%0) {broadcast_sizes = dense<[1, 4, 4]> : tensor<3xi64>} : (tensor<f32>) -> tensor<1x4x4xf32>
-   // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4xf32>
-   // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[BCASTALPHA]] : tensor<1x4x4xf32>
-   // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
-   // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
-   // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
+    // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+    // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+    // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32>
+    // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
+    // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+    // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
     %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
     return %0 : tensor<1x4x4xf32>
 }
 
+// CHECK-LABEL: func @leaky_relu_grad_unranked
+func @leaky_relu_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} {
+    // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+    // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+    // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<*xf32>
+    // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
+    // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    // CHECK-NEXT: return %[[RES]] : tensor<*xf32>
+    %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL: func @softsign
 func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> {
     // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<4x10xf32>) -> tensor<4x10xf32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index b433f04..e887b77 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -1860,27 +1860,16 @@
   LogicalResult matchAndRewrite(TF::LeakyReluOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    float alpha = op.alpha().convertToFloat();
     Value features = op.features();
-    auto featureType = features.getType().cast<RankedTensorType>();
-    ArrayRef<int64_t> featureShape = featureType.getShape();
-    Type eltType = featureType.getElementType();
+    auto featureType = features.getType();
 
-    auto alphaVal = rewriter.create<mhlo::ConstOp>(
-        loc, rewriter.getFloatAttr(eltType, alpha));
-
-    // Broadcast `alpha` to match the shape of feature.
-    auto featureShapeAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
-        featureShape);
-    auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
-        loc, featureType, alphaVal, featureShapeAttr);
-
-    Attribute zeroAttr = rewriter.getZeroAttr(featureType);
-    Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
+    // Use ConstantLike for `alpha` to match the shape of `features`.
+    auto alphaVal = chlo::getConstantLike(
+        rewriter, loc, op.alpha().convertToFloat(), features);
+    Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
 
     Value leakyActivationVal = rewriter.create<mhlo::MulOp>(
-        loc, features.getType(), features, broadcastAlphaVal);
+        loc, features.getType(), features, alphaVal);
 
     StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");
     Value compareGtZero = rewriter.create<mhlo::CompareOp>(
@@ -1901,26 +1890,17 @@
   LogicalResult matchAndRewrite(TF::LeakyReluGradOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    float alpha = op.alpha().convertToFloat();
     Value gradients = op.gradients();
     Value features = op.features();
-    auto featureType = features.getType().cast<RankedTensorType>();
-    ArrayRef<int64_t> featureShape = featureType.getShape();
-    Type eltType = featureType.getElementType();
+    auto featureType = features.getType();
 
-    auto alphaVal = rewriter.create<mhlo::ConstOp>(
-        loc, rewriter.getFloatAttr(eltType, alpha));
-    auto featureShapeAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
-        featureShape);
-    auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
-        loc, featureType, alphaVal, featureShapeAttr);
-
-    Attribute zeroAttr = rewriter.getZeroAttr(featureType);
-    Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
+    // Use ConstantLike for `alpha` to match the shape of `features`.
+    auto alphaVal = chlo::getConstantLike(
+        rewriter, loc, op.alpha().convertToFloat(), features);
+    Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
 
     Value leakyGradientVal = rewriter.create<mhlo::MulOp>(
-        loc, features.getType(), gradients, broadcastAlphaVal);
+        loc, features.getType(), gradients, alphaVal);
 
     StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");