[MLIR][HLO] Add cast when needed to simplify `shape.bcast` ops
PiperOrigin-RevId: 437832615
diff --git a/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc b/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc
index 03fd24a..008dc2a 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc
@@ -118,8 +118,19 @@
return rewriter.create<tensor::ExtractOp>(loc, operand, operand_dim)
.getResult();
}));
- rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(
- op, op->getResultTypes().front(), elements);
+ Type index_ty = rewriter.getIndexType();
+ Type concrete_result_ty = RankedTensorType::get(
+ {static_cast<int64_t>(elements.size())}, index_ty);
+ Value result = rewriter.create<tensor::FromElementsOp>(
+ loc, concrete_result_ty, elements);
+
+ // Insert cast, if needed.
+ Type expected_ty = op.getResult().getType();
+ if (result.getType() != expected_ty) {
+ result = rewriter.create<tensor::CastOp>(loc, expected_ty, result);
+ }
+
+ rewriter.replaceOp(op, result);
return success();
}
};
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir
index 7aa2212..e6ba7e5 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir
@@ -689,3 +689,31 @@
-> tensor<1xindex>
func.return %2 : tensor<1xindex>
}
+
+// -----
+
+// CHECK-LABEL: @broadcast_w_dyn_ty
+// CHECK-SAME: %[[ARG:.*]]: tensor<1xindex>
+func.func @broadcast_w_dyn_ty(%arg0: tensor<1xindex>) -> tensor<?xindex>{
+ // CHECK: %[[C0:.*]] = arith.constant 0
+ // CHECK: %[[D0:.*]] = tensor.extract %[[ARG]][%[[C0]]]
+ // CHECK: %[[UNCAST:.*]] = tensor.from_elements %[[D0]]
+ // CHECK: %[[CAST:.*]] = tensor.cast %[[UNCAST]] : tensor<1xindex> to tensor<?xindex>
+ // CHECK: return %[[CAST]]
+ %0 = shape.broadcast %arg0, %arg0
+ : tensor<1xindex>, tensor<1xindex> -> tensor<?xindex>
+ return %0 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_scalar_w_dyn_ty
+// CHECK-SAME: %[[ARG:.*]]: tensor<0xindex>
+func.func @broadcast_scalar_w_dyn_ty(%arg0: tensor<0xindex>) -> tensor<?xindex>{
+ // CHECK: %[[UNCAST:.*]] = arith.constant dense<> : tensor<0xindex>
+ // CHECK: %[[CAST:.*]] = tensor.cast %[[UNCAST]] : tensor<0xindex> to tensor<?xindex>
+ // CHECK: return %[[CAST]]
+ %0 = shape.broadcast %arg0, %arg0
+ : tensor<0xindex>, tensor<0xindex> -> tensor<?xindex>
+ return %0 : tensor<?xindex>
+}