Support legalizing Einsum op with ellipsis in both lhs and rhs

The broadcasting semantics are handled by TF::BatchMatMulV2Op.

PiperOrigin-RevId: 383187306
Change-Id: I3f997b15c69e6228f85d594c484f028f6095ec98
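
For context, a minimal sketch of the newly legalized behavior (editorial illustration, not part of the change; assumes TF2 eager execution), using the same shapes as the new `einsum_ellipsis_in_both_sides` test:

```python
import numpy as np
import tensorflow as tf

# Ellipsis on both sides: the leading batch dims (1 vs. 7) broadcast.
lhs = np.random.rand(1, 11, 19).astype(np.float32)
rhs = np.random.rand(7, 11, 13, 19).astype(np.float32)

out = tf.einsum("...IJ,...INJ->...IN", lhs, rhs)   # shape (7, 11, 13)

# Equivalent lowering: reshape lhs, transpose rhs, then a broadcasting
# batch matmul (BatchMatMulV2 semantics), as the CHECK lines below assert.
ref = tf.matmul(lhs.reshape(1, 11, 1, 19),
                np.transpose(rhs, (0, 1, 3, 2)))   # shape (7, 11, 1, 13)
np.testing.assert_allclose(out, np.squeeze(ref, axis=2), rtol=1e-5)
```
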
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
index 973e46d..50332ef 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
@@ -200,3 +200,29 @@
 // CHECK-LABEL: einsum_ellipsis
 // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32>
 }
+
+func @einsum_ellipsis_in_both_sides(%arg0: tensor<1x11x19xf32>, %arg1: tensor<7x11x13x19xf32>) -> tensor<7x11x13xf32> {
+  %0 = "tf.Einsum"(%arg0, %arg1) {device = "", equation = "...IJ,...INJ->...IN"} : (tensor<1x11x19xf32>, tensor<7x11x13x19xf32>) -> tensor<7x11x13xf32>
+  return %0 : tensor<7x11x13xf32>
+  // CHECK-LABEL: einsum_ellipsis_in_both_sides
+  // CHECK: %[[cst:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
+  // CHECK: %[[cst_1:.*]] = constant dense<[1, 11, 1, 19]> : tensor<4xi64>
+  // CHECK: %[[cst_2:.*]] = constant dense<[7, 11, 13]> : tensor<3xi64>
+  // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<7x11x13x19xf32>, tensor<4xi32>) -> tensor<7x11x19x13xf32>
+  // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<1x11x19xf32>, tensor<4xi64>) -> tensor<1x11x1x19xf32>
+  // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) {adj_x = false, adj_y = false} : (tensor<1x11x1x19xf32>, tensor<7x11x19x13xf32>) -> tensor<7x11x1x13xf32>
+  // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<7x11x1x13xf32>, tensor<3xi64>) -> tensor<7x11x13xf32>
+  // CHECK: return %[[v3]] : tensor<7x11x13xf32>
+}
+
+func @einsum_ellipsis_with_broadcast(%arg0: tensor<5x4x3xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<4x2x5xf32> {
+  %0 = "tf.Einsum"(%arg0, %arg1) {device = "", equation = "...ij,j...->i..."} : (tensor<5x4x3xf32>, tensor<3x2x1xf32>) -> tensor<4x2x5xf32>
+  return %0 : tensor<4x2x5xf32>
+  // CHECK-LABEL: einsum_ellipsis_with_broadcast
+  // CHECK: %[[cst:.*]] = constant dense<[2, 0, 1]> : tensor<3xi32>
+  // CHECK: %[[cst_1:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
+  // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<3x2x1xf32>, tensor<3xi32>) -> tensor<1x3x2xf32>
+  // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<5x4x3xf32>, tensor<1x3x2xf32>) -> tensor<5x4x2xf32>
+  // CHECK: %[[v2:.*]] = "tf.Transpose"(%[[v1]], %[[cst_1]]) : (tensor<5x4x2xf32>, tensor<3xi32>) -> tensor<4x2x5xf32>
+  // CHECK: return %[[v2]] : tensor<4x2x5xf32>
+}
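
As an editorial cross-check (not part of the patch), the transpose/matmul/transpose sequence asserted by `einsum_ellipsis_with_broadcast` can be replayed in NumPy:

```python
import numpy as np

lhs = np.random.rand(5, 4, 3).astype(np.float32)  # "...ij": batch (5,), i=4, j=3
rhs = np.random.rand(3, 2, 1).astype(np.float32)  # "j...":  j=3, batch (2, 1)

want = np.einsum("...ij,j...->i...", lhs, rhs)    # batch dims broadcast to (2, 5)

# Mirror the CHECK lines: transpose rhs to (1, 3, 2), broadcasting matmul
# to (5, 4, 2), then transpose the result to (4, 2, 5).
got = np.transpose(lhs @ np.transpose(rhs, (2, 0, 1)), (1, 2, 0))
np.testing.assert_allclose(got, want, rtol=1e-5)
```
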
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
index a506373..0d87d19 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
@@ -21,6 +21,7 @@
 #include <cstdint>
 #include <string>
 #include <tuple>
+#include <utility>
 
 #include "absl/memory/memory.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -35,6 +36,7 @@
 #include "llvm/Support/Regex.h"
 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/Dialect/Traits.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
@@ -123,7 +125,6 @@
     labels.insert('a' + i);
     labels.insert('A' + i);
   }
-  bool ellipsis_observed = false;
 
   auto is_start_of_ellipsis = [](StringRef equation, int start_index) {
     if (equation.size() < (start_index + 3)) return false;
@@ -141,14 +142,12 @@
       ++lhs_count;
     } else if (label == '.') {
       if (!is_start_of_ellipsis(lhs, i)) return llvm::None;
-      ellipsis_observed = true;
       i += 2;
     } else {
       // Unsupported character in the equation.
       return llvm::None;
     }
   }
-
   *lhs_named_label_count = lhs_count;
 
   int rhs_count = 0;
@@ -159,11 +158,8 @@
       labels.remove(label);
       ++rhs_count;
     } else if (label == '.') {
-      if (!is_start_of_ellipsis(lhs, i)) return llvm::None;
-      // We do not support both lhs & rhs have ellipsis for now.
-      if (ellipsis_observed) return llvm::None;
+      if (!is_start_of_ellipsis(rhs, i)) return llvm::None;
       i += 2;
-      ellipsis_observed = true;
     } else {
       // Unsupported character in the equation.
       return llvm::None;
@@ -178,11 +174,11 @@
 // For example, if we have GenerateLabels(2, {'b', 'c', 'd'}) for "...xy"
 // We will have "dcxy" for the ellipsis expression since it's rank 4,
 // we will have dcbxy if it's rank 5.
-std::string GenerateLabels(int count, llvm::SetVector<char>* available_labels) {
-  std::string new_labels;
-  new_labels.reserve(count);
+std::string GenerateLabels(int count,
+                           const llvm::SetVector<char>& available_labels) {
+  std::string new_labels(count, 0);
   for (int i = 0; i < count; ++i) {
-    new_labels.push_back(available_labels->pop_back_val());
+    new_labels[count - 1 - i] = available_labels[i];
   }
 
   return new_labels;
@@ -191,7 +187,7 @@
 std::tuple<std::string, std::string, std::string> FlattenEllipsis(
     llvm::StringRef lhs, int lhs_named_label_count, llvm::StringRef rhs,
     int rhs_named_label_count, llvm::StringRef out, RankedTensorType lhs_ty,
-    RankedTensorType rhs_ty, llvm::SetVector<char>* available_labels) {
+    RankedTensorType rhs_ty, const llvm::SetVector<char>& available_labels) {
   std::string new_labels;
   std::string new_lhs;
   for (int i = 0; i < lhs.size(); ++i) {
@@ -208,7 +204,7 @@
     }
   }
 
-  std::string new_rhs;
+  std::string new_rhs, new_rhs_labels;
   for (int i = 0; i < rhs.size(); ++i) {
     const char label = rhs[i];
     if (std::isalpha(label)) {
@@ -216,10 +212,13 @@
     } else {
       // Encounter ellipsis: generate unnamed labels then insert to the new
       // labels.
-      new_labels = GenerateLabels(rhs_ty.getRank() - rhs_named_label_count,
-                                  available_labels);
-      new_rhs.append(new_labels);
+      new_rhs_labels = GenerateLabels(rhs_ty.getRank() - rhs_named_label_count,
+                                      available_labels);
+      new_rhs.append(new_rhs_labels);
       i += 2;
+      if (new_rhs_labels.size() > new_labels.size()) {
+        new_labels = new_rhs_labels;
+      }
     }
   }
 
@@ -254,10 +253,6 @@
   if (lhs.empty() || rhs.empty()) return llvm::None;
 
   // Try to flatten the "..." if possible.
-  // Currently we only support either lhs or the rhs has "..." but not both.
-  // Both usually will require a broadcasting semantics which is not supported
-  // by the batch_matmul.
-  // TODO(b/181244617): Consider handling the broadcasting scenario as well.
   int lhs_named_label, rhs_named_label;
   auto avaiable_labels =
       GetAvailableLabels(lhs, rhs, &lhs_named_label, &rhs_named_label);
@@ -265,7 +260,7 @@
 
   auto flattended_labels =
       FlattenEllipsis(lhs, lhs_named_label, rhs, rhs_named_label, out, lhs_ty,
-                      rhs_ty, &avaiable_labels.getValue());
+                      rhs_ty, avaiable_labels.getValue());
 
   lhs = std::get<0>(flattended_labels);
   rhs = std::get<1>(flattended_labels);
@@ -383,27 +378,35 @@
 // B0,...,Bn,L0,...,Ln,C0,...,Cn and B0,...,Bn,C0,...,Cn,R0,...,Rn respectively.
 LogicalResult reshapeForBatchMatmul(const Location& loc,
                                     EinsumDimensionNumbers& dnums, Value* lhs,
-                                    Value* rhs, std::vector<int64_t>* out_shape,
+                                    Value* rhs,
+                                    SmallVectorImpl<int64_t>* out_shape,
                                     PatternRewriter* rewriter) {
   RankedTensorType lhs_type = lhs->getType().cast<RankedTensorType>();
   RankedTensorType rhs_type = rhs->getType().cast<RankedTensorType>();
 
+  // Labels present in lhs, rhs, and output are the batch labels B0,...,Bn.
   std::vector<int64_t> lhs_shape;
   std::vector<int64_t> rhs_shape;
   lhs_shape.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + 1);
   rhs_shape.reserve(dnums.lhs_rhs_out.size() + 2);
   for (auto i : dnums.lhs_rhs_out) {
-    int64_t b = lhs_type.getShape()[std::get<0>(i)];
-    lhs_shape.push_back(b);
-    rhs_shape.push_back(b);
-    out_shape->push_back(b);
+    int64_t b1 = lhs_type.getShape()[std::get<0>(i)];
+    lhs_shape.push_back(b1);
+    int64_t b2 = rhs_type.getShape()[std::get<1>(i)];
+    rhs_shape.push_back(b2);
+  }
+  if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, *out_shape)) {
+    return failure();
   }
 
+  // Calculates the dimension for the label L from L0,...,Ln in lhs.
   if (dnums.lhs_out.empty()) {
     lhs_shape.push_back(1);
     out_shape->push_back(1);
     dnums.lhs_out.emplace_back(lhs_shape.size() - 1, out_shape->size() - 1);
   } else if (dnums.lhs_rhs_out.empty()) {
+    // If there are no batch labels B0,...,Bn, it is safe to use L0,...,Ln as
+    // the batch labels in lhs; the rhs will be broadcasted.
     for (auto i : dnums.lhs_out) {
       int64_t b = lhs_type.getShape()[std::get<0>(i)];
       lhs_shape.push_back(b);
@@ -418,14 +421,18 @@
     out_shape->push_back(lhs_out_size);
   }
 
-  int64_t lhs_rhs_size = 1;
+  // Calculates the dimension for the common label C from the labels
+  // C0,...,Cn that exist in both lhs and rhs.
+  int64_t lhs_size = 1, rhs_size = 1;
   for (auto i : dnums.lhs_rhs) {
-    lhs_rhs_size *= lhs_type.getShape()[std::get<0>(i)];
+    lhs_size *= lhs_type.getShape()[std::get<0>(i)];
+    rhs_size *= rhs_type.getShape()[std::get<1>(i)];
   }
-  lhs_shape.push_back(lhs_rhs_size);
-  rhs_shape.push_back(lhs_rhs_size);
+  lhs_shape.push_back(lhs_size);
+  rhs_shape.push_back(rhs_size);
 
-  int64_t rhs_size = 1;
+  // Calculates the dimension for the label R from R0,...,Rn in rhs.
+  rhs_size = 1;
   for (auto i : dnums.rhs_out) {
     rhs_size *= rhs_type.getShape()[std::get<0>(i)];
   }
@@ -469,7 +476,7 @@
                                      &out_transpose, &rewriter)))
     return failure();
 
-  std::vector<int64_t> matmul_shape;
+  llvm::SmallVector<int64_t, 4> matmul_shape;
   if (failed(reshapeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs,
                                    &matmul_shape, &rewriter)))
     return failure();
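
The heart of the change is that `GenerateLabels` no longer consumes the shared label pool: both ellipses draw from the same `available_labels`, filled right to left, so the rightmost (innermost) broadcast dimensions of lhs and rhs receive identical labels and line up for BatchMatMulV2. A hypothetical Python rendering of that indexing (names illustrative):

```python
def generate_labels(count, available_labels):
    # Fill right to left from a shared, unmutated pool, mirroring
    # new_labels[count - 1 - i] = available_labels[i] in einsum.cc.
    labels = [""] * count
    for i in range(count):
        labels[count - 1 - i] = available_labels[i]
    return "".join(labels)

pool = ["b", "c", "d"]           # unused letters, shared by both operands
print(generate_labels(1, pool))  # "b"  (e.g. an lhs ellipsis of rank 1)
print(generate_labels(2, pool))  # "cb" (e.g. an rhs ellipsis of rank 2)
# The trailing "b" is common to both sides, so those dims broadcast together.
```
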
diff --git a/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py b/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py
index 2feb575..0e4356b 100644
--- a/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py
+++ b/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py
@@ -32,17 +32,25 @@
   test_parameters = [
       {
           "dtype": [tf.float32],
-          "shapes": [((3, 4, 5), (3, 5, 6), "ijk,ikm->ijm"),
-                     ((3, 4, 5), (5, 6), "ijk,km->ijm"),
-                     ((2, 5, 7), (5, 2), "LBH,BL->BH"),
-                     ((2, 5, 7), (5, 3, 2), "LBH,BKL->BKH"),
-                     ((2, 5, 7, 3), (2, 4, 7, 3), "BFNH,BTNH->BNFT"),
-                     ((2, 5, 7, 3), (7, 3, 4), "BFND,NDH->BFH"),
-                     ((3, 4, 5), (5, 6, 2), "BFD,DNH->BFNH"),
-                     ((7, 11, 13), (7, 11, 13, 5), "BIN,BINJ->BIJ"),
-                     ((7, 11, 19), (7, 11, 13, 19), "BIJ,BINJ->BIN"),
-                     ((5, 13, 3, 11), (5, 11, 13, 8), "ACBE,AECD->ABCD"),
-                     ((5, 11, 7, 3), (5, 8, 7, 3), "AECD,ABCD->ACBE")],
+          "shapes": [
+              ((3, 4, 5), (3, 5, 6), "ijk,ikm->ijm"),
+              ((3, 4, 5), (5, 6), "ijk,km->ijm"),
+              ((2, 5, 7), (5, 2), "LBH,BL->BH"),
+              ((2, 5, 7), (5, 3, 2), "LBH,BKL->BKH"),
+              ((2, 5, 7, 3), (2, 4, 7, 3), "BFNH,BTNH->BNFT"),
+              ((2, 5, 7, 3), (7, 3, 4), "BFND,NDH->BFH"),
+              ((3, 4, 5), (5, 6, 2), "BFD,DNH->BFNH"),
+              ((7, 11, 13), (7, 11, 13, 5), "BIN,BINJ->BIJ"),
+              ((7, 11, 19), (7, 11, 13, 19), "BIJ,BINJ->BIN"),
+              ((5, 13, 3, 11), (5, 11, 13, 8), "ACBE,AECD->ABCD"),
+              ((5, 11, 7, 3), (5, 8, 7, 3), "AECD,ABCD->ACBE"),
+              ((5, 4, 3), (3, 2, 1), "...ij,j...->i..."),
+              ((5, 4, 3), (3, 2, 1), "...ij,j...->...i"),
+              ((1, 11, 19), (7, 11, 13, 19), "...IJ,...INJ->...IN"),
+              ((1, 11, 19), (7, 11, 13, 19), "...IJ,...INJ->IN..."),
+              ((4, 3, 2, 5), (3, 6, 1), "ij...,jk...->ik..."),
+              ((4, 3, 2, 5), (3, 6, 1), "ij...,jk...->...ik"),
+          ],
       },
   ]
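
The added TFLite test equations can be sanity-checked in plain TensorFlow; for example, the trailing-ellipsis case (editorial illustration, assumes TF2 eager execution):

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(4, 3, 2, 5).astype(np.float32)  # "ij...": i=4, j=3, batch (2, 5)
y = np.random.rand(3, 6, 1).astype(np.float32)     # "jk...": j=3, k=6, batch (1,)
# Contraction over j; the ellipsis dims (2, 5) and (1,) broadcast.
print(tf.einsum("ij...,jk...->ik...", x, y).shape)  # (4, 6, 2, 5)
```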