Support legalizing the Einsum op with an ellipsis in both lhs and rhs
The broadcasting semantics are handled by TF::BatchMatMulV2Op, which broadcasts the batch dimensions of its operands.
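
For example, the following einsum, which previously failed to legalize
because both operands use "...", now lowers to a
Transpose/Reshape/BatchMatMulV2 sequence. A minimal Python sketch
mirroring the first new test case (the tensor names and random inputs
are illustrative):

    import tensorflow as tf

    x = tf.random.normal([1, 11, 19])      # "...IJ":  ellipsis covers the dim of size 1
    y = tf.random.normal([7, 11, 13, 19])  # "...INJ": ellipsis covers the dim of size 7
    z = tf.einsum("...IJ,...INJ->...IN", x, y)  # shape (7, 11, 13)

    # Equivalent lowering: matmul (BatchMatMulV2) broadcasts the batch
    # dimensions, here 1 against 7.
    yt = tf.transpose(y, [0, 1, 3, 2])  # (7, 11, 19, 13)
    x4 = tf.reshape(x, [1, 11, 1, 19])  # align ranks for the batch matmul
    z2 = tf.reshape(tf.linalg.matmul(x4, yt), [7, 11, 13])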
PiperOrigin-RevId: 383187306
Change-Id: I3f997b15c69e6228f85d594c484f028f6095ec98
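
The second new test exercises ellipses that appear in different
positions and broadcast against each other; conceptually (an
illustrative sketch, not the generated IR):

    import tensorflow as tf

    a = tf.random.normal([5, 4, 3])  # "...ij": ellipsis = (5,)
    b = tf.random.normal([3, 2, 1])  # "j...":  ellipsis = (2, 1)
    # The ellipsis dims broadcast as (5,) vs (2, 1) -> (2, 5),
    # so "i..." yields shape (4, 2, 5).
    c = tf.einsum("...ij,j...->i...", a, b)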
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
index 973e46d..50332ef 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir
@@ -200,3 +200,29 @@
// CHECK-LABEL: einsum_ellipsis
// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32>
}
+
+func @einsum_ellipsis_in_both_sides(%arg0: tensor<1x11x19xf32>, %arg1: tensor<7x11x13x19xf32>) -> tensor<7x11x13xf32> {
+ %0 = "tf.Einsum"(%arg0, %arg1) {device = "", equation = "...IJ,...INJ->...IN"} : (tensor<1x11x19xf32>, tensor<7x11x13x19xf32>) -> tensor<7x11x13xf32>
+ return %0 : tensor<7x11x13xf32>
+ // CHECK-LABEL: einsum_ellipsis_in_both_sides
+ // CHECK: %[[cst:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32>
+ // CHECK: %[[cst_1:.*]] = constant dense<[1, 11, 1, 19]> : tensor<4xi64>
+ // CHECK: %[[cst_2:.*]] = constant dense<[7, 11, 13]> : tensor<3xi64>
+ // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<7x11x13x19xf32>, tensor<4xi32>) -> tensor<7x11x19x13xf32>
+ // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<1x11x19xf32>, tensor<4xi64>) -> tensor<1x11x1x19xf32>
+ // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) {adj_x = false, adj_y = false} : (tensor<1x11x1x19xf32>, tensor<7x11x19x13xf32>) -> tensor<7x11x1x13xf32>
+ // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<7x11x1x13xf32>, tensor<3xi64>) -> tensor<7x11x13xf32>
+ // CHECK: return %[[v3]] : tensor<7x11x13xf32>
+}
+
+func @einsum_ellipsis_with_broadcast(%arg0: tensor<5x4x3xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<4x2x5xf32> {
+ %0 = "tf.Einsum"(%arg0, %arg1) {device = "", equation = "...ij,j...->i..."} : (tensor<5x4x3xf32>, tensor<3x2x1xf32>) -> tensor<4x2x5xf32>
+ return %0 : tensor<4x2x5xf32>
+ // CHECK-LABEL: einsum_ellipsis_with_broadcast
+ // CHECK: %[[cst:.*]] = constant dense<[2, 0, 1]> : tensor<3xi32>
+ // CHECK: %[[cst_1:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32>
+ // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<3x2x1xf32>, tensor<3xi32>) -> tensor<1x3x2xf32>
+ // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<5x4x3xf32>, tensor<1x3x2xf32>) -> tensor<5x4x2xf32>
+ // CHECK: %[[v2:.*]] = "tf.Transpose"(%[[v1]], %[[cst_1]]) : (tensor<5x4x2xf32>, tensor<3xi32>) -> tensor<4x2x5xf32>
+ // CHECK: return %[[v2]] : tensor<4x2x5xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
index a506373..0d87d19 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
@@ -21,6 +21,7 @@
#include <cstdint>
#include <string>
#include <tuple>
+#include <utility>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
@@ -35,6 +36,7 @@
#include "llvm/Support/Regex.h"
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
@@ -123,7 +125,6 @@
labels.insert('a' + i);
labels.insert('A' + i);
}
- bool ellipsis_observed = false;
auto is_start_of_ellipsis = [](StringRef equation, int start_index) {
if (equation.size() < (start_index + 3)) return false;
@@ -141,14 +142,12 @@
++lhs_count;
} else if (label == '.') {
if (!is_start_of_ellipsis(lhs, i)) return llvm::None;
- ellipsis_observed = true;
i += 2;
} else {
// Unsupported character in the equation.
return llvm::None;
}
}
-
*lhs_named_label_count = lhs_count;
int rhs_count = 0;
@@ -159,11 +158,8 @@
labels.remove(label);
++rhs_count;
} else if (label == '.') {
- if (!is_start_of_ellipsis(lhs, i)) return llvm::None;
- // We do not support both lhs & rhs have ellipsis for now.
- if (ellipsis_observed) return llvm::None;
+ if (!is_start_of_ellipsis(rhs, i)) return llvm::None;
i += 2;
- ellipsis_observed = true;
} else {
// Unsupported character in the equation.
return llvm::None;
@@ -178,11 +174,11 @@
// For example, if we have GenerateLabels(2, {'b', 'c', 'd'}) for "...xy"
// We will have "dcxy" for the ellipsis expression since it's rank 4,
// we will have dcbxy if it's rank 5.
-std::string GenerateLabels(int count, llvm::SetVector<char>* available_labels) {
- std::string new_labels;
- new_labels.reserve(count);
+std::string GenerateLabels(int count,
+ const llvm::SetVector<char>& available_labels) {
+ std::string new_labels(count, 0);
for (int i = 0; i < count; ++i) {
- new_labels.push_back(available_labels->pop_back_val());
+ new_labels[count - 1 - i] = available_labels[i];
}
return new_labels;
@@ -191,7 +187,7 @@
std::tuple<std::string, std::string, std::string> FlattenEllipsis(
llvm::StringRef lhs, int lhs_named_label_count, llvm::StringRef rhs,
int rhs_named_label_count, llvm::StringRef out, RankedTensorType lhs_ty,
- RankedTensorType rhs_ty, llvm::SetVector<char>* available_labels) {
+ RankedTensorType rhs_ty, const llvm::SetVector<char>& available_labels) {
std::string new_labels;
std::string new_lhs;
for (int i = 0; i < lhs.size(); ++i) {
@@ -208,7 +204,7 @@
}
}
- std::string new_rhs;
+ std::string new_rhs, new_rhs_labels;
for (int i = 0; i < rhs.size(); ++i) {
const char label = rhs[i];
if (std::isalpha(label)) {
@@ -216,10 +212,13 @@
} else {
// Encounter ellipsis: generate unnamed labels then insert to the new
// labels.
- new_labels = GenerateLabels(rhs_ty.getRank() - rhs_named_label_count,
- available_labels);
- new_rhs.append(new_labels);
+ new_rhs_labels = GenerateLabels(rhs_ty.getRank() - rhs_named_label_count,
+ available_labels);
+ new_rhs.append(new_rhs_labels);
i += 2;
+ if (new_rhs_labels.size() > new_labels.size()) {
+ new_labels = new_rhs_labels;
+ }
}
}
@@ -254,10 +253,6 @@
if (lhs.empty() || rhs.empty()) return llvm::None;
// Try to flatten the "..." if possible.
- // Currently we only support either lhs or the rhs has "..." but not both.
- // Both usually will require a broadcasting semantics which is not supported
- // by the batch_matmul.
- // TODO(b/181244617): Consider handling the broadcasting scenario as well.
int lhs_named_label, rhs_named_label;
auto avaiable_labels =
GetAvailableLabels(lhs, rhs, &lhs_named_label, &rhs_named_label);
@@ -265,7 +260,7 @@
auto flattended_labels =
FlattenEllipsis(lhs, lhs_named_label, rhs, rhs_named_label, out, lhs_ty,
- rhs_ty, &avaiable_labels.getValue());
+ rhs_ty, avaiable_labels.getValue());
lhs = std::get<0>(flattended_labels);
rhs = std::get<1>(flattended_labels);
@@ -383,27 +378,35 @@
// B0,...,Bn,L0,...,Ln,C0,...,Cn and B0,...,Bn,C0,...,Cn,R0,...,Rn respectively.
LogicalResult reshapeForBatchMatmul(const Location& loc,
EinsumDimensionNumbers& dnums, Value* lhs,
- Value* rhs, std::vector<int64_t>* out_shape,
+ Value* rhs,
+ SmallVectorImpl<int64_t>* out_shape,
PatternRewriter* rewriter) {
RankedTensorType lhs_type = lhs->getType().cast<RankedTensorType>();
RankedTensorType rhs_type = rhs->getType().cast<RankedTensorType>();
+  // Labels that exist in all of lhs, rhs and the output are the batch labels B0,...,Bn.
std::vector<int64_t> lhs_shape;
std::vector<int64_t> rhs_shape;
lhs_shape.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + 1);
rhs_shape.reserve(dnums.lhs_rhs_out.size() + 2);
for (auto i : dnums.lhs_rhs_out) {
- int64_t b = lhs_type.getShape()[std::get<0>(i)];
- lhs_shape.push_back(b);
- rhs_shape.push_back(b);
- out_shape->push_back(b);
+ int64_t b1 = lhs_type.getShape()[std::get<0>(i)];
+ lhs_shape.push_back(b1);
+ int64_t b2 = rhs_type.getShape()[std::get<1>(i)];
+ rhs_shape.push_back(b2);
+ }
+ if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, *out_shape)) {
+ return failure();
}
+  // Calculates the dimension for the flattened label L from L0,...,Ln in lhs.
if (dnums.lhs_out.empty()) {
lhs_shape.push_back(1);
out_shape->push_back(1);
dnums.lhs_out.emplace_back(lhs_shape.size() - 1, out_shape->size() - 1);
} else if (dnums.lhs_rhs_out.empty()) {
+    // If there are no batch labels B0,...,Bn, it is safe to use L0,...,Ln as
+    // the batch labels in lhs; the rhs will be broadcast.
for (auto i : dnums.lhs_out) {
int64_t b = lhs_type.getShape()[std::get<0>(i)];
lhs_shape.push_back(b);
@@ -418,14 +421,18 @@
out_shape->push_back(lhs_out_size);
}
- int64_t lhs_rhs_size = 1;
+  // Calculates the dimension for the common label C from the labels
+  // C0,...,Cn that exist in both lhs and rhs.
+ int64_t lhs_size = 1, rhs_size = 1;
for (auto i : dnums.lhs_rhs) {
- lhs_rhs_size *= lhs_type.getShape()[std::get<0>(i)];
+ lhs_size *= lhs_type.getShape()[std::get<0>(i)];
+ rhs_size *= rhs_type.getShape()[std::get<1>(i)];
}
- lhs_shape.push_back(lhs_rhs_size);
- rhs_shape.push_back(lhs_rhs_size);
+ lhs_shape.push_back(lhs_size);
+ rhs_shape.push_back(rhs_size);
- int64_t rhs_size = 1;
+  // Calculates the dimension for the label R from R0,...,Rn in rhs.
+ rhs_size = 1;
for (auto i : dnums.rhs_out) {
rhs_size *= rhs_type.getShape()[std::get<0>(i)];
}
@@ -469,7 +476,7 @@
&out_transpose, &rewriter)))
return failure();
- std::vector<int64_t> matmul_shape;
+ llvm::SmallVector<int64_t, 4> matmul_shape;
if (failed(reshapeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs,
&matmul_shape, &rewriter)))
return failure();
diff --git a/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py b/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py
index 2feb575..0e4356b 100644
--- a/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py
+++ b/tensorflow/lite/experimental/mlir/testing/op_tests/einsum.py
@@ -32,17 +32,25 @@
test_parameters = [
{
"dtype": [tf.float32],
- "shapes": [((3, 4, 5), (3, 5, 6), "ijk,ikm->ijm"),
- ((3, 4, 5), (5, 6), "ijk,km->ijm"),
- ((2, 5, 7), (5, 2), "LBH,BL->BH"),
- ((2, 5, 7), (5, 3, 2), "LBH,BKL->BKH"),
- ((2, 5, 7, 3), (2, 4, 7, 3), "BFNH,BTNH->BNFT"),
- ((2, 5, 7, 3), (7, 3, 4), "BFND,NDH->BFH"),
- ((3, 4, 5), (5, 6, 2), "BFD,DNH->BFNH"),
- ((7, 11, 13), (7, 11, 13, 5), "BIN,BINJ->BIJ"),
- ((7, 11, 19), (7, 11, 13, 19), "BIJ,BINJ->BIN"),
- ((5, 13, 3, 11), (5, 11, 13, 8), "ACBE,AECD->ABCD"),
- ((5, 11, 7, 3), (5, 8, 7, 3), "AECD,ABCD->ACBE")],
+ "shapes": [
+ ((3, 4, 5), (3, 5, 6), "ijk,ikm->ijm"),
+ ((3, 4, 5), (5, 6), "ijk,km->ijm"),
+ ((2, 5, 7), (5, 2), "LBH,BL->BH"),
+ ((2, 5, 7), (5, 3, 2), "LBH,BKL->BKH"),
+ ((2, 5, 7, 3), (2, 4, 7, 3), "BFNH,BTNH->BNFT"),
+ ((2, 5, 7, 3), (7, 3, 4), "BFND,NDH->BFH"),
+ ((3, 4, 5), (5, 6, 2), "BFD,DNH->BFNH"),
+ ((7, 11, 13), (7, 11, 13, 5), "BIN,BINJ->BIJ"),
+ ((7, 11, 19), (7, 11, 13, 19), "BIJ,BINJ->BIN"),
+ ((5, 13, 3, 11), (5, 11, 13, 8), "ACBE,AECD->ABCD"),
+ ((5, 11, 7, 3), (5, 8, 7, 3), "AECD,ABCD->ACBE"),
+ ((5, 4, 3), (3, 2, 1), "...ij,j...->i..."),
+ ((5, 4, 3), (3, 2, 1), "...ij,j...->...i"),
+ ((1, 11, 19), (7, 11, 13, 19), "...IJ,...INJ->...IN"),
+ ((1, 11, 19), (7, 11, 13, 19), "...IJ,...INJ->IN..."),
+ ((4, 3, 2, 5), (3, 6, 1), "ij...,jk...->ik..."),
+ ((4, 3, 2, 5), (3, 6, 1), "ij...,jk...->...ik"),
+ ],
},
]