[mhlo] Implement DotGeneralOp::reifyReturnTypeShapes and use it in linalg lowering

This makes the tests noisier, but it also fixes bugs in the lowering when the
dot dimensions are permuted: the removed helper hardcoded the dynamic sizes to
lhs dims 0 and 1 and rhs dim 2, which is only correct for the canonical
batch-matmul layout.
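
For illustration, consider a dot_general whose batching dimension is not the
leading one (the dimension numbers below are illustrative, chosen to be
consistent with the updated @dot_general CHECK lines rather than copied from
the test source):

  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [1],
      rhs_batching_dimensions = [2],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [1]
    >
  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>

Here the result's batch size lives in dim 1 of the lhs, so reading dims 0 and 1
of the lhs and dim 2 of the rhs yields the wrong init tensor shape.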

PiperOrigin-RevId: 446932742
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 0dd0bd3..6f54fb0 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1672,7 +1672,7 @@
   }];
 }
 
-def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]> {
+def HLO_DotGeneralOp: HLO_ShapedInterfaceOp<"dot_general", [NoSideEffect]> {
   let summary = "General Dot operator";
   let description = [{
     Performs general dot products between vectors, vector/matrix and
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 7ff8a7b..a7d174a 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -914,6 +914,43 @@
   return success();
 }
 
+LogicalResult DotGeneralOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  auto lhs_type = lhs().getType().dyn_cast<ShapedType>();
+  auto rhs_type = rhs().getType().dyn_cast<ShapedType>();
+  if (!lhs_type || !rhs_type) {
+    return failure();
+  }
+
+  Adaptor adaptor(operands);
+  auto dim_numbers = dot_dimension_numbers();
+  SmallVector<Value> dimensions;
+  for (const int64_t lhs_dim : dim_numbers.getLhsBatchingDimensions()) {
+    dimensions.push_back(
+        builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), lhs_dim));
+  }
+
+  for (int64_t i = 0; i < lhs_type.getRank(); i++) {
+    if (!llvm::is_contained(dim_numbers.getLhsContractingDimensions(), i) &&
+        !llvm::is_contained(dim_numbers.getLhsBatchingDimensions(), i)) {
+      dimensions.push_back(
+          builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), i));
+    }
+  }
+  for (int64_t i = 0; i < rhs_type.getRank(); i++) {
+    if (!llvm::is_contained(dim_numbers.getRhsContractingDimensions(), i) &&
+        !llvm::is_contained(dim_numbers.getRhsBatchingDimensions(), i)) {
+      dimensions.push_back(
+          builder.create<tensor::DimOp>(getLoc(), adaptor.rhs(), i));
+    }
+  }
+
+  reifiedReturnShapes.push_back(
+      builder.create<tensor::FromElementsOp>(getLoc(), dimensions));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FftOp
 //===----------------------------------------------------------------------===//
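
As a sketch of what the new hook emits for the example above (still assuming
the illustrative dimension numbers): the reified shape lists the lhs batching
dimensions first, then the free lhs dimensions, then the free rhs dimensions,
e.g.

  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>   // lhs batching dim
  %c0 = arith.constant 0 : index
  %d1 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>   // lhs free dim
  %c0_0 = arith.constant 0 : index
  %d2 = tensor.dim %arg1, %c0_0 : tensor<?x?x?xf32> // rhs free dim
  %shape = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>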
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 50a8af9..ceb98e6 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1605,18 +1605,6 @@
   }
 };
 
-SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
-    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
-  SmallVector<Value, 8> dyn_shape;
-  if (result_type.isDynamicDim(0))
-    dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
-  if (result_type.isDynamicDim(1))
-    dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 1));
-  if (result_type.isDynamicDim(2))
-    dyn_shape.push_back(b.create<tensor::DimOp>(loc, rhs, 2));
-  return dyn_shape;
-}
-
 class DotGeneralBatchMatMulOpConversion
     : public OpConversionPattern<mhlo::DotGeneralOp> {
  public:
@@ -1653,11 +1641,10 @@
     Location loc = op.getLoc();
     auto output_type = op.getType().cast<ShapedType>();
     auto output_el_type = output_type.getElementType();
-    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
-        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
     auto zero_attr = rewriter.getZeroAttr(output_el_type);
     Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
-    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
+    auto init_tensor =
+        GetInitTensorFor(rewriter, loc, output_type, op, adaptor.getOperands());
     Value zero_tensor =
         rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
     Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
@@ -2926,11 +2913,10 @@
 
     Location loc = op.getLoc();
     auto output_el_type = output_type.getElementType();
-    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
-        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
     auto zero_attr = rewriter.getZeroAttr(output_el_type);
     Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
-    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
+    auto init_tensor =
+        GetInitTensorFor(rewriter, loc, output_type, op, adaptor.getOperands());
     Value zero_tensor =
         rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
     SmallVector<AffineMap, 3> indexing_maps;
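
On the consumer side, GetInitTensorFor (already present in this file)
materializes the init tensor from the reified shape, presumably by invoking
the op's reifyReturnTypeShapes hook and extracting each element; for the
running example it would produce roughly:

  %c0 = arith.constant 0 : index
  %e0 = tensor.extract %shape[%c0] : tensor<3xindex>
  %c1 = arith.constant 1 : index
  %e1 = tensor.extract %shape[%c1] : tensor<3xindex>
  %c2 = arith.constant 2 : index
  %e2 = tensor.extract %shape[%c2] : tensor<3xindex>
  %init = linalg.init_tensor [%e0, %e1, %e2] : tensor<?x?x?xf32>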
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index eed9048..3b0fbf5 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -1809,6 +1809,13 @@
 // CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
 // CHECK: linalg.batch_matmul
@@ -1839,6 +1846,13 @@
 // CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
 // CHECK: linalg.batch_matmul
@@ -1868,6 +1882,13 @@
 // CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
 // CHECK: linalg.batch_matmul
@@ -4262,12 +4283,19 @@
 // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK: func @dot_general(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[D0]], %[[D1]], %[[D2]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[D0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]]
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]]
 // CHECK: linalg.generic