[mhlo] Implement DotGeneralOp::reifyReturnTypeShapes and use it in linalg lowering
This makes the tests noisier, but it also fixes bugs in the lowering when the
dot dimension numbers are permuted, i.e. when they deviate from the fixed
layout the removed helper assumed (lhs batch at dim 0, lhs free dim at 1,
rhs free dim at 2).
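
A sketch of the kind of case the old shape computation got wrong (the SSA
names and attribute values here are illustrative; the updated @dot_general
test below checks an equivalent layout):

  %0 = "mhlo.dot_general"(%lhs, %rhs) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [1],
      rhs_batching_dimensions = [2],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [1]
    >
  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>

The result dims here are [dim 1 of %lhs, dim 0 of %lhs, dim 0 of %rhs], but
the removed GetDotGeneralOpInitTensorDynSizes helper unconditionally read
dims 0 and 1 of the lhs and dim 2 of the rhs. reifyReturnTypeShapes derives
the sizes from the dimension numbers instead.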
PiperOrigin-RevId: 446932742
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 0dd0bd3..6f54fb0 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1672,7 +1672,7 @@
}];
}
-def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]> {
+def HLO_DotGeneralOp: HLO_ShapedInterfaceOp<"dot_general", [NoSideEffect]> {
let summary = "General Dot operator";
let description = [{
Performs general dot products between vectors, vector/matrix and
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 7ff8a7b..a7d174a 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -914,6 +914,43 @@
return success();
}
+LogicalResult DotGeneralOp::reifyReturnTypeShapes(
+ OpBuilder& builder, ValueRange operands,
+ SmallVectorImpl<Value>& reifiedReturnShapes) {
+ auto lhs_type = lhs().getType().dyn_cast<ShapedType>();
+ auto rhs_type = rhs().getType().dyn_cast<ShapedType>();
+ if (!lhs_type || !rhs_type) {
+ return failure();
+ }
+
+ Adaptor adaptor(operands);
+ auto dim_numbers = dot_dimension_numbers();
+ SmallVector<Value> dimensions;
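+  // The reified shape is the batching dimensions (shared by both operands,
+  // so read off the lhs), then the free dimensions of the lhs and rhs.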
+ for (const int64_t lhs_dim : dim_numbers.getLhsBatchingDimensions()) {
+ dimensions.push_back(
+ builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), lhs_dim));
+ }
+
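+  // Free (non-batching, non-contracting) dimensions of the lhs.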
+ for (int64_t i = 0; i < lhs_type.getRank(); i++) {
+ if (!llvm::is_contained(dim_numbers.getLhsContractingDimensions(), i) &&
+ !llvm::is_contained(dim_numbers.getLhsBatchingDimensions(), i)) {
+ dimensions.push_back(
+ builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), i));
+ }
+ }
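+  // Free (non-batching, non-contracting) dimensions of the rhs.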
+ for (int64_t i = 0; i < rhs_type.getRank(); i++) {
+ if (!llvm::is_contained(dim_numbers.getRhsContractingDimensions(), i) &&
+ !llvm::is_contained(dim_numbers.getRhsBatchingDimensions(), i)) {
+ dimensions.push_back(
+ builder.create<tensor::DimOp>(getLoc(), adaptor.rhs(), i));
+ }
+ }
+
+ reifiedReturnShapes.push_back(
+ builder.create<tensor::FromElementsOp>(getLoc(), dimensions));
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 50a8af9..ceb98e6 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1605,18 +1605,6 @@
}
};
-SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
- OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
- SmallVector<Value, 8> dyn_shape;
- if (result_type.isDynamicDim(0))
- dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
- if (result_type.isDynamicDim(1))
- dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 1));
- if (result_type.isDynamicDim(2))
- dyn_shape.push_back(b.create<tensor::DimOp>(loc, rhs, 2));
- return dyn_shape;
-}
-
class DotGeneralBatchMatMulOpConversion
: public OpConversionPattern<mhlo::DotGeneralOp> {
public:
@@ -1653,11 +1641,10 @@
Location loc = op.getLoc();
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
- SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
- rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
auto zero_attr = rewriter.getZeroAttr(output_el_type);
Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
- auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
+ auto init_tensor =
+ GetInitTensorFor(rewriter, loc, output_type, op, adaptor.getOperands());
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
@@ -2926,11 +2913,10 @@
Location loc = op.getLoc();
auto output_el_type = output_type.getElementType();
- SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
- rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
auto zero_attr = rewriter.getZeroAttr(output_el_type);
Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
- auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
+ auto init_tensor =
+ GetInitTensorFor(rewriter, loc, output_type, op, adaptor.getOperands());
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
SmallVector<AffineMap, 3> indexing_maps;
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index eed9048..3b0fbf5 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -1809,6 +1809,13 @@
// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.batch_matmul
@@ -1839,6 +1846,13 @@
// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.batch_matmul
@@ -1868,6 +1882,13 @@
// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.batch_matmul
@@ -4262,12 +4283,19 @@
// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: func @dot_general(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
// CHECK: linalg.generic