Adjust LeakyRelu and LeakyReluGrad lowering to work for dynamic shapes.
Before, the lowering used mhlo.broadcast, which requires a ranked tensor with a static shape.
By using chlo::ConstantLikeOp, we can support fully dynamic (including
unranked) shapes.
PiperOrigin-RevId: 393084552
Change-Id: Ibc98e4522471b73584342f12c4432cb63bf19d57
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 036a267..0b6c28e 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1733,30 +1733,52 @@
// CHECK-LABEL: func @leaky_relu
func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} {
- // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32>
- // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%[[ALPHA]]) {broadcast_sizes = dense<[1, 4, 4, 3]> : tensor<4xi64>} : (tensor<f32>) -> tensor<1x4x4x3xf32>
- // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4x3xf32>
- // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[BCASTALPHA]] : tensor<1x4x4x3xf32>
- // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
- // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
- // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
+ // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+ // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+ // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32>
+ // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
+ // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
%0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
return %0 : tensor<1x4x4x3xf32>
}
+// CHECK-LABEL: func @leaky_relu_unranked
+func @leaky_relu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} {
+ // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<*xf32>
+ // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-NEXT: return %[[RES]] : tensor<*xf32>
+ %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// CHECK-LABEL: func @leaky_relu_grad
func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} {
- // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32>
- // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%0) {broadcast_sizes = dense<[1, 4, 4]> : tensor<3xi64>} : (tensor<f32>) -> tensor<1x4x4xf32>
- // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4xf32>
- // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[BCASTALPHA]] : tensor<1x4x4xf32>
- // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
- // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
- // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
+ // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+ // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+ // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32>
+ // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+ // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
%0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
return %0 : tensor<1x4x4xf32>
}
+// CHECK-LABEL: func @leaky_relu_grad_unranked
+func @leaky_relu_grad_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {tf.entry_function = {}} {
+ // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) {value = 2.000000e-01 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<*xf32>
+ // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-NEXT: return %[[RES]] : tensor<*xf32>
+ %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// CHECK-LABEL: func @softsign
func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> {
// CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<4x10xf32>) -> tensor<4x10xf32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index b433f04..e887b77 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -1860,27 +1860,16 @@
LogicalResult matchAndRewrite(TF::LeakyReluOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- float alpha = op.alpha().convertToFloat();
Value features = op.features();
- auto featureType = features.getType().cast<RankedTensorType>();
- ArrayRef<int64_t> featureShape = featureType.getShape();
- Type eltType = featureType.getElementType();
+ auto featureType = features.getType();
- auto alphaVal = rewriter.create<mhlo::ConstOp>(
- loc, rewriter.getFloatAttr(eltType, alpha));
-
- // Broadcast `alpha` to match the shape of feature.
- auto featureShapeAttr = DenseIntElementsAttr::get(
- RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
- featureShape);
- auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
- loc, featureType, alphaVal, featureShapeAttr);
-
- Attribute zeroAttr = rewriter.getZeroAttr(featureType);
- Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
+ // Use ConstantLike for `alpha` to match the shape of feature.
+ auto alphaVal = chlo::getConstantLike(
+ rewriter, loc, op.alpha().convertToFloat(), features);
+ Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
Value leakyActivationVal = rewriter.create<mhlo::MulOp>(
- loc, features.getType(), features, broadcastAlphaVal);
+ loc, features.getType(), features, alphaVal);
StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");
Value compareGtZero = rewriter.create<mhlo::CompareOp>(
@@ -1901,26 +1890,17 @@
LogicalResult matchAndRewrite(TF::LeakyReluGradOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- float alpha = op.alpha().convertToFloat();
Value gradients = op.gradients();
Value features = op.features();
- auto featureType = features.getType().cast<RankedTensorType>();
- ArrayRef<int64_t> featureShape = featureType.getShape();
- Type eltType = featureType.getElementType();
+ auto featureType = features.getType();
- auto alphaVal = rewriter.create<mhlo::ConstOp>(
- loc, rewriter.getFloatAttr(eltType, alpha));
- auto featureShapeAttr = DenseIntElementsAttr::get(
- RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
- featureShape);
- auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
- loc, featureType, alphaVal, featureShapeAttr);
-
- Attribute zeroAttr = rewriter.getZeroAttr(featureType);
- Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
+ // Use ConstantLike for `alpha` to match the shape of feature.
+ auto alphaVal = chlo::getConstantLike(
+ rewriter, loc, op.alpha().convertToFloat(), features);
+ Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
Value leakyGradientVal = rewriter.create<mhlo::MulOp>(
- loc, features.getType(), gradients, broadcastAlphaVal);
+ loc, features.getType(), gradients, alphaVal);
StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");