[TF:TRT] Refactor einsum converter

[TF:TRT] Make Einsum converter use DimsAdapter

[TF:TRT] Reactivate einsum test case for TRT >= 8.0
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index a52c67e..f692ffd 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -610,6 +610,7 @@
         "convert/convert_graph.cc",
         "convert/convert_nodes.cc",
         "convert/ops/data_format_vec_permute.cc",
+        "convert/ops/einsum.cc",
         "convert/ops/quantization_ops.cc",
         "convert/ops/slice_ops.cc",
         "convert/trt_optimization_pass.cc",
diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.h b/tensorflow/compiler/tf2tensorrt/common/utils.h
index fcdd440..b4040d3 100644
--- a/tensorflow/compiler/tf2tensorrt/common/utils.h
+++ b/tensorflow/compiler/tf2tensorrt/common/utils.h
@@ -41,6 +41,19 @@
 #include "tensorflow/core/platform/status.h"
 #include "third_party/tensorrt/NvInfer.h"
 
+#define TFTRT_INTERNAL_ERROR_AT_NODE(node)                           \
+  do {                                                               \
+    return errors::Internal("TFTRT::", __FUNCTION__, ":", __LINE__,  \
+                            " failed to add TRT layer, at: ", node); \
+  } while (0)
+
+#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
+  do {                                           \
+    if (ptr == nullptr) {                        \
+      TFTRT_INTERNAL_ERROR_AT_NODE(node);        \
+    }                                            \
+  } while (0)
+
 // Use this macro within functions that return a Status or StatusOR<T> to check
 // boolean conditions. If the condition fails, it returns an
 // errors::Internal message with the file and line number.
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 730271f..4bb0f0c 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -55,7 +55,6 @@
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
-#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/numbers.h"
@@ -79,19 +78,6 @@
 // would work!
 #define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
 
-#define TFTRT_INTERNAL_ERROR_AT_NODE(node)                           \
-  do {                                                               \
-    return errors::Internal("TFTRT::", __FUNCTION__, ":", __LINE__,  \
-                            " failed to add TRT layer, at: ", node); \
-  } while (0)
-
-#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
-  do {                                           \
-    if (ptr == nullptr) {                        \
-      TFTRT_INTERNAL_ERROR_AT_NODE(node);        \
-    }                                            \
-  } while (0)
-
 namespace tensorflow {
 namespace tensorrt {
 namespace convert {
@@ -4989,578 +4975,6 @@
                              transpose_b);
 }
 
-// Finds the indices of elements in [begin, end) in array
-// [array_begin, array_end), and appends the indices to permute. This is used to
-// construct the permutation sequence for the operand with input labels
-// [array_begin, array_end) to the desired permuted labels [begin, end).
-template <typename Iterator>
-Status FindIndices(Iterator begin, Iterator end, Iterator array_begin,
-                   Iterator array_end, std::vector<int>* permute) {
-  const int n = array_end - array_begin;
-  if (n < end - begin) {
-    return errors::Internal("Incorrect array size");
-  }
-  for (auto i = begin; i < end; i++) {
-    int idx = std::find(array_begin, array_end, *i) - array_begin;
-    if (idx >= n) {
-      return errors::Internal("Label not found");
-    }
-    permute->push_back(idx);
-  }
-  return Status::OK();
-}
-
-#if IS_TRT_VERSION_GE(7, 1, 3, 0)
-// Layout of the einsum dimensions: Batch, Free and Contraction indices.
-// Example: abcd,adef -> abde. The first tensor has layout BFC, the second BCF.
-enum class EinsumLayout { BFC, BCF, MIX };
-
-// Describes an operand: input shape, number of batch, free and contract
-// dimensions, and the permutation that is needed to bring it to a matmul
-// compatible form.
-struct EinsumDescriptor {
-  EinsumDescriptor() : b(0), f(0), c(0) {}
-
-  // Deduces the number of batch, free, contract dimensions from the input
-  // labels, decides what layout to use, and determines permutation indices for
-  // that layout.
-  Status InitDescriptor(const TRT_TensorOrWeights& operand, Labels input_labels,
-                        std::vector<EinsumHelper::DimensionType>& label_types,
-                        EinsumLayout preferred_layout,
-                        EinsumDescriptor* other = nullptr) {
-    if (preferred_layout == EinsumLayout::MIX)
-      return errors::Internal("Preferred einsum layout cannot be MIX");
-    const EinsumHelper::DimensionType kBatch =
-        EinsumHelper::DimensionType::kBatch;
-    const EinsumHelper::DimensionType kFree =
-        EinsumHelper::DimensionType::kFree;
-    const EinsumHelper::DimensionType kContract =
-        EinsumHelper::DimensionType::kContract;
-
-    // Map label indices to label types.
-    std::vector<EinsumHelper::DimensionType> types;  // Input label types.
-    std::transform(input_labels.begin(), input_labels.end(),
-                   std::back_inserter(types),
-                   [&label_types, kBatch](int i) { return label_types.at(i); });
-
-    using label_t_iterator = std::vector<EinsumHelper::DimensionType>::iterator;
-    auto count_labels = [](label_t_iterator begin, label_t_iterator end,
-                           EinsumHelper::DimensionType val) {
-      return std::count_if(begin, end, [val](EinsumHelper::DimensionType t) {
-        return t == val;
-      });
-    };
-
-    b = count_labels(types.begin(), types.end(), kBatch);
-    f = count_labels(types.begin(), types.end(), kFree);
-    c = count_labels(types.begin(), types.end(), kContract);
-
-    if (c == 0 || f == 0) {
-      VLOG(2) << "Einsum equation needs to have at least one free and one "
-                 "contract dimension";
-      return errors::Unimplemented("No conversion for einsum equation.");
-    }
-
-    // Checks whether input_labels[offset:offset+m] matches labels from other.
-    auto order_matches = [other, &input_labels, kBatch, kFree, kContract](
-                             int offset, int m,
-                             EinsumHelper::DimensionType dim_type) {
-      if (!other) return true;
-      int offset_other = 0;
-      if (dim_type == kFree)
-        offset = other->offset_f;
-      else if (dim_type == kContract)
-        offset = other->offset_c;
-      return std::equal(input_labels.begin() + offset,
-                        input_labels.begin() + offset + m,
-                        other->permuted_labels.begin() + offset_other);
-    };
-
-    // Check if the current layout is BFC or BCF. In that case we could avoid
-    // transpose.
-    layout = EinsumLayout::MIX;
-    if (count_labels(types.begin(), types.begin() + b, kBatch) == b &&
-        order_matches(0, b, kBatch)) {
-      // Batch dims are the leading dims. They have the same order as other.
-      if (count_labels(types.begin() + b, types.begin() + b + f, kFree) == f) {
-        // All the free dims are placed consecutively after the batch dims.
-        // Their order is arbitrary. The final transpose will ensure that the
-        // output has correct order. We still have to check that the contract
-        // indices have correct order.
-        if (order_matches(b + f, c, kContract)) {
-          layout = EinsumLayout::BFC;
-        }
-      } else if (count_labels(types.begin() + b, types.begin() + b + c,
-                              kContract) == c) {
-        // All the contract dims are placed consecutively after the batch
-        // dims. Check whether the contract dims have the same order as the
-        // contract dims in other.
-        if (order_matches(b, c, kContract)) {
-          layout = EinsumLayout::BCF;
-        }
-      }
-    }
-
-    if (layout == EinsumLayout::MIX) {
-      // Input label types are mixed. Calculate a permutation that maps them
-      // to the preferred layout (BCF or BFC).
-      layout = preferred_layout;
-      if (!other) {
-        AppendMatchingIndicesToPermute(types, kBatch);
-      } else {
-        TF_RETURN_IF_ERROR(
-            FindIndices(other->permuted_labels.begin(),
-                        other->permuted_labels.begin() + other->b,
-                        input_labels.begin(), input_labels.end(), &permute));
-      }
-      if (layout == EinsumLayout::BFC) {
-        AppendMatchingIndicesToPermute(types, kFree);
-        if (!other) {
-          AppendMatchingIndicesToPermute(types, kContract);
-        } else {
-          TF_RETURN_IF_ERROR(FindIndices(
-              other->permuted_labels.begin() + other->offset_c,
-              other->permuted_labels.begin() + other->offset_c + other->c,
-              input_labels.begin(), input_labels.end(), &permute));
-        }
-      } else {
-        if (!other) {
-          AppendMatchingIndicesToPermute(types, kContract);
-        } else {
-          TF_RETURN_IF_ERROR(FindIndices(
-              other->permuted_labels.begin() + other->offset_c,
-              other->permuted_labels.begin() + other->offset_c + other->c,
-              input_labels.begin(), input_labels.end(), &permute));
-        }
-        AppendMatchingIndicesToPermute(types, kFree);
-      }
-    }
-
-    if (layout == EinsumLayout::BFC) {
-      offset_f = b;
-      offset_c = f + b;
-    } else {
-      offset_f = b + c;
-      offset_c = b;
-    }
-
-    dims = operand.GetTrtDims();
-    for (int i = 0; i < b; i++) {
-      // Set unknown batch dims to zero. These dims will be used in reshape op,
-      // where zero is a special value for retaining the original dim size.
-      if (dims.d[i] == -1) dims.d[i] = 0;
-    }
-    permuted_labels = input_labels;
-    if (!permute.empty()) {
-      // Apply the permutation on the dimension array.
-      nvinfer1::Dims orig_dims = dims;
-      for (int i = 0; i < permute.size(); i++) {
-        dims.d[i] = orig_dims.d[permute[i]];
-        permuted_labels[i] = input_labels[permute[i]];
-      }
-    }
-    size_tensors.resize(dims.nbDims, nullptr);
-
-    VLOG(2) << "Set up descriptor with "
-            << (layout == EinsumLayout::BFC ? "BFC" : "BCF")
-            << " layout, b=" << b << ", f=" << f << ", c=" << c;
-    return Status::OK();
-  }
-
-  // Appends indices where types maches value.
-  void AppendMatchingIndicesToPermute(
-      const std::vector<EinsumHelper::DimensionType>& types,
-      EinsumHelper::DimensionType val) {
-    for (int i = 0; i < types.size(); i++) {
-      if (types[i] == val) {
-        permute.push_back(i);
-      }
-    }
-  }
-
-  // Returns whether the free and contract dimension have static shape.
-  bool HasStaticShape() {
-    return !std::any_of(dims.d + b, dims.d + dims.nbDims,
-                        [](int k) { return k == -1; });
-  }
-
-  nvinfer1::Permutation GetPermutation() {
-    nvinfer1::Permutation p;
-    std::copy(permute.begin(), permute.end(), p.order);
-    return p;
-  }
-
-  Status SetDynamicSize(OpConverterParams* params,
-                        const TRT_TensorOrWeights& operand) {
-    if (operand.GetTrtDims().nbDims != dims.nbDims)
-      return errors::Internal("Operand dims must agree with descirptor dims");
-
-    if (operand.is_weights()) {
-      for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
-        // dims.d stores the permuted dims.
-        TF_RETURN_IF_ERROR(
-            CreateScalarConstant(params, dims.d[i], &size_tensors[i]));
-      }
-      return Status::OK();
-    }
-    auto* shape_layer =
-        params->converter->network()->addShape(*operand.tensor()->trt_tensor());
-    TFTRT_RETURN_ERROR_IF_NULLPTR(shape_layer, params->node_def.name());
-    ITensorProxyPtr shape = shape_layer->getOutput(0);
-    for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
-      int idx = permute.empty() ? i : permute.at(i);
-      auto* layer = params->converter->network()->addSlice(
-          *shape->trt_tensor(), {1, {idx}}, {1, {1}}, {1, {1}});
-      TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
-      size_tensors[i] = layer->getOutput(0);
-      TFTRT_RETURN_ERROR_IF_NULLPTR(size_tensors[i], "error, slice is nullptr");
-    }
-    return Status::OK();
-  }
-
-  EinsumLayout layout;
-  int b;  // number of batch dims
-  int f;  // number of free dims
-  int c;  // number of conraction dims
-  int offset_f;
-  int offset_c;
-  nvinfer1::Dims dims;
-  std::vector<int> permute;
-  std::vector<ITensorProxyPtr> size_tensors;
-  Labels permuted_labels;
-};
-
-Status GetDimsProd(nvinfer1::Dims dims, int offset, int n, int32_t* out) {
-  size_t prod = std::accumulate(dims.d + offset, dims.d + offset + n, size_t(1),
-                                std::multiplies<size_t>());
-  if (prod > std::numeric_limits<int32_t>::max()) {
-    return errors::Internal("Matrix too large");
-  } else {
-    *out = prod;
-  }
-  return Status::OK();
-}
-
-Status GetDimsProdDynamic(OpConverterParams* params,
-                          std::vector<ITensorProxyPtr>::const_iterator begin,
-                          std::vector<ITensorProxyPtr>::const_iterator end,
-                          ITensorProxyPtr* out) {
-  *out = *begin;
-  begin++;
-  while (begin != end) {
-    nvinfer1::IElementWiseLayer* layer =
-        params->converter->network()->addElementWise(
-            *(*out)->trt_tensor(), *(*begin)->trt_tensor(),
-            nvinfer1::ElementWiseOperation::kPROD);
-    TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
-    *out = layer->getOutput(0);
-    begin++;
-  }
-  return Status::OK();
-}
-
-Status ConcatenateShape(OpConverterParams* params,
-                        const std::vector<ITensorProxyPtr> size_tensors,
-                        ITensorProxyPtr* new_shape) {
-  std::vector<nvinfer1::ITensor*> trt_size_tensors;
-  for (const auto& t : size_tensors) {
-    trt_size_tensors.push_back(t->trt_tensor());
-  }
-  nvinfer1::IConcatenationLayer* layer =
-      params->converter->network()->addConcatenation(
-          static_cast<nvinfer1::ITensor* const*>(trt_size_tensors.data()),
-          size_tensors.size());
-  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
-  layer->setAxis(0);
-  *new_shape = layer->getOutput(0);
-  return Status::OK();
-}
-
-// Reshapes operand so that the free dimensions are combined into a single dim,
-// and the contract dimensions are combined into another single dim.
-Status GetEinsumNewDynamicShape(OpConverterParams* params,
-                                const EinsumDescriptor& desc,
-                                ITensorProxyPtr* new_shape) {
-  std::vector<ITensorProxyPtr> size(desc.size_tensors.begin(),
-                                    desc.size_tensors.begin() + desc.b + 2);
-
-  int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
-  int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;
-
-  TF_RETURN_IF_ERROR(GetDimsProdDynamic(
-      params, desc.size_tensors.begin() + desc.offset_f,
-      desc.size_tensors.begin() + desc.offset_f + desc.f, &size[idx_f]));
-
-  TF_RETURN_IF_ERROR(GetDimsProdDynamic(
-      params, desc.size_tensors.begin() + desc.offset_c,
-      desc.size_tensors.begin() + desc.offset_c + desc.c, &size[idx_c]));
-
-  TF_RETURN_IF_ERROR(ConcatenateShape(params, size, new_shape));
-  return Status::OK();
-}
-
-// Reshapes operand so that the free dimensions are combined into a single dim,
-// and the contract dimensions are combined into another single dim.
-Status GetEinsumNewStaticShape(const EinsumDescriptor& desc,
-                               nvinfer1::Dims* new_dims) {
-  new_dims->nbDims = desc.b + 2;
-  // Copy batch dims.
-  std::copy(desc.dims.d, desc.dims.d + desc.b, new_dims->d);
-  // Combine free dims and contract dims.
-  int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
-  int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;
-  TF_RETURN_IF_ERROR(
-      GetDimsProd(desc.dims, desc.offset_f, desc.f, new_dims->d + idx_f));
-  TF_RETURN_IF_ERROR(
-      GetDimsProd(desc.dims, desc.offset_c, desc.c, new_dims->d + idx_c));
-  return Status::OK();
-}
-
-// Adds shuffle layer (if needed) to bring einsum operand to a matmul compatible
-// format.
-Status ShuffleEinsumTensor(OpConverterParams* params,
-                           std::unique_ptr<TRT_TensorOrWeights>* operand,
-                           EinsumDescriptor* desc, int op_instance) {
-  if (params->validation_only) return Status::OK();
-  TF_RETURN_IF_ERROR(desc->SetDynamicSize(params, **operand));
-  bool need_reshape = (desc->f != 1 || desc->c != 1);
-  bool need_transpose = !desc->permute.empty();
-  if ((*operand)->is_weights()) {
-    nvinfer1::Dims new_dims;
-    TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(*desc, &new_dims));
-    if (!need_transpose) {
-      TRT_ShapedWeights weights((*operand)->weights());
-      TF_RETURN_IF_ERROR(weights.SetShape(new_dims));
-      operand->reset(new TRT_TensorOrWeights(weights));
-      return Status::OK();
-    }
-    // TODO(tfeher): Instead of creating a tensor that will be transposed,
-    // transpose the weight itself. Keeping it weight could enable FC layer.
-    ITensorProxyPtr tensor = params->converter->CreateConstantLayer(
-        (*operand)->weights(), (*operand)->GetTrtDims());
-    operand->reset(new TRT_TensorOrWeights(tensor));
-  }
-
-  if (!need_transpose && !need_reshape) return Status::OK();
-  ITensorProxyPtr operand_tensor = (*operand)->tensor();
-  TFTRT_RETURN_ERROR_IF_NULLPTR(operand_tensor, "Null tensor at Einsum");
-  nvinfer1::IShuffleLayer* layer =
-      params->converter->network()->addShuffle(*operand_tensor->trt_tensor());
-
-  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
-  params->converter->SetLayerName(layer, params->node_def, "shuffle",
-                                  /*op_instance=*/op_instance);
-  // Set new shape.
-  if (need_reshape) {
-    if (desc->HasStaticShape()) {
-      nvinfer1::Dims new_dims;
-      TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(*desc, &new_dims));
-      layer->setReshapeDimensions(new_dims);
-    } else {
-      ITensorProxyPtr new_shape;
-      TF_RETURN_IF_ERROR(GetEinsumNewDynamicShape(params, *desc, &new_shape));
-      layer->setInput(1, *new_shape->trt_tensor());
-    }
-  }
-
-  if (need_transpose) {
-    layer->setFirstTranspose(desc->GetPermutation());
-  }
-  operand->reset(new TRT_TensorOrWeights(layer->getOutput(0)));
-  return Status::OK();
-}
-
-// Combines output dims/labels by copying batch and free dims/labels from input
-// A, and concatenating free values from input B.
-template <typename InputIterator, typename OutputIterator>
-void AssembleOutput(InputIterator begin_a, InputIterator begin_b,
-                    const EinsumDescriptor& desc_a,
-                    const EinsumDescriptor& desc_b, OutputIterator out) {
-  std::copy(begin_a, begin_a + desc_a.b, out);
-  begin_a += desc_a.offset_f;
-  std::copy(begin_a, begin_a + desc_a.f, out + desc_a.b);
-  begin_b += desc_b.offset_f;
-  std::copy(begin_b, begin_b + desc_b.f, out + desc_a.b + desc_a.f);
-}
-
-// Restores free dimensions and sets final index order. Consider C = A * B,
-// batched MatMul op, where A.shape = [B, x, k] and B.shape = [B, k, y]. Then
-// C.shape = [B, x, y]. Here B can denote multiple batch indices while x, y, k
-// are single indices. The original inputs to Einsum can have multiple free
-// indices. These were combined into a singe free dimension x and y, for example
-// x = f_a1 * f_a2 * f_a3, y = f_b1 * f_b2. This routine creates a shuffle layer
-// to expand x into and y the original free dims, e.g. C is reshaped to
-// [B, f_a1, f_a2, f_a3, f_b1, f_b2]. Finally, a permutation is applied to
-// transform the shape to the shape of the original Einsum output.
-Status ShuffleEinsumOutput(OpConverterParams* params, EinsumDescriptor desc_a,
-                           EinsumDescriptor desc_b,
-                           const std::vector<int>& permutation,
-                           ITensorProxyPtr* output) {
-  if (permutation.empty() && (desc_a.f == 1 && desc_b.f == 1))
-    return Status::OK();
-
-  nvinfer1::IShuffleLayer* layer =
-      params->converter->network()->addShuffle(*(*output)->trt_tensor());
-  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
-  params->converter->SetLayerName(layer, params->node_def, "shuffle",
-                                  /*op_instance=*/2);
-
-  int output_rank = desc_a.b + desc_a.f + desc_b.f;
-  if (desc_a.f != 1 || desc_b.f != 1) {
-    if (desc_a.HasStaticShape() && desc_b.HasStaticShape()) {
-      nvinfer1::Dims dims_out = {output_rank, {}};
-      AssembleOutput(desc_a.dims.d, desc_b.dims.d, desc_a, desc_b, dims_out.d);
-      layer->setReshapeDimensions(dims_out);
-    } else {
-      std::vector<ITensorProxyPtr> size_tensors(output_rank);
-      AssembleOutput(desc_a.size_tensors.begin(), desc_b.size_tensors.begin(),
-                     desc_a, desc_b, size_tensors.begin());
-      ITensorProxyPtr new_shape;
-      TF_RETURN_IF_ERROR(ConcatenateShape(params, size_tensors, &new_shape));
-      layer->setInput(1, *new_shape->trt_tensor());
-    }
-  }
-
-  if (!permutation.empty()) {
-    nvinfer1::Permutation p;
-    std::copy(permutation.begin(), permutation.end(), p.order);
-    layer->setSecondTranspose(p);
-  }
-  *output = layer->getOutput(0);
-  return Status::OK();
-}
-
-// Prepares EinsumDescriptors after parsing the equation and determines the
-// final transpose.
-Status ParseEquation(OpConverterParams* params,
-                     std::unique_ptr<TRT_TensorOrWeights>* input_a,
-                     std::unique_ptr<TRT_TensorOrWeights>* input_b,
-                     EinsumDescriptor* descriptor_a,
-                     EinsumDescriptor* descriptor_b,
-                     std::vector<int>* final_transpose) {
-  string equation;
-  TF_RETURN_IF_ERROR(
-      GetNodeAttr(AttrSlice(params->node_def), "equation", &equation));
-  VLOG(2) << "Einsum equation " << equation;
-
-  OperandLabels input_labels;
-  Labels output_labels;
-  std::vector<EinsumHelper::DimensionType> label_types;
-  OperandLabelCounts input_label_counts;
-  LabelCounts output_label_counts;
-  absl::InlinedVector<bool, 2> input_has_ellipsis;
-  bool output_has_ellipsis;
-  TF_RETURN_IF_ERROR(EinsumHelper::ParseEquation(
-      equation, &input_labels, &output_labels, &label_types,
-      &input_label_counts, &output_label_counts, &input_has_ellipsis,
-      &output_has_ellipsis));
-
-  VLOG(2) << "Output has ellipsis: " << output_has_ellipsis;
-
-  if (input_has_ellipsis[0] || input_has_ellipsis[1] || output_has_ellipsis) {
-    // TODO(tfeher): Handle ellipsis like EinsumHelper::ProcessDimensions.
-    // Note: ProcessDimensions would introduce kBroadcasting labels, which we
-    // need to replace with kBatch before we call InitDescriptor.
-    VLOG(2) << "Ellipsis not yet supported";
-    return errors::Unimplemented("No conversion for einsum equation.");
-  }
-  if (absl::c_any_of(label_types, [](auto l) {
-        return l == EinsumHelper::DimensionType::kReduce ||
-               l == EinsumHelper::DimensionType::kBroadcasting;
-      })) {
-    VLOG(2) << "Einsum reductions not implemented";
-    return errors::Unimplemented("No conversion for einsum equation.");
-  }
-
-  auto no_duplicated_labels = [](const LabelCounts& label_counts) {
-    return absl::c_any_of(label_counts, [](int i) { return i > 1; });
-  };
-  if (no_duplicated_labels(input_label_counts[0]) ||
-      no_duplicated_labels(input_label_counts[1]) ||
-      no_duplicated_labels(output_label_counts)) {
-    VLOG(2) << "Einsum invalid label count";
-    return errors::Unimplemented("No conversion for einsum equation.");
-  }
-
-  if ((*input_a)->is_weights() && (*input_b)->is_tensor()) {
-    // We prefer to use FC layer, needs A as tensor and B as weight.
-    std::swap(*input_a, *input_b);
-    std::swap(input_labels[0], input_labels[1]);
-    std::swap(input_label_counts[0], input_label_counts[1]);
-  }
-
-  TF_RETURN_IF_ERROR(descriptor_a->InitDescriptor(
-      **input_a, input_labels[0], label_types, EinsumLayout::BFC));
-  TF_RETURN_IF_ERROR(
-      descriptor_b->InitDescriptor(**input_b, input_labels[1], label_types,
-                                   EinsumLayout::BCF, descriptor_a));
-  // TODO(tfeher): Update the permutation in the descriptors to avoid final
-  // transpose (if possible). Consider swapping the input if it eliminates
-  // final transpose.
-
-  // Get final transpose.
-  Labels matmul_output_labels(descriptor_a->b + descriptor_a->f +
-                              descriptor_b->f);
-  AssembleOutput(descriptor_a->permuted_labels.begin(),
-                 descriptor_b->permuted_labels.begin(), *descriptor_a,
-                 *descriptor_b, matmul_output_labels.begin());
-  TF_RETURN_IF_ERROR(FindIndices(output_labels.begin(), output_labels.end(),
-                                 matmul_output_labels.begin(),
-                                 matmul_output_labels.end(), final_transpose));
-  // Clear identity transpose.
-  bool identity_transpose = true;
-  for (int i = 0; i < final_transpose->size() && identity_transpose; i++) {
-    identity_transpose &= final_transpose->at(i) == i;
-  }
-  if (identity_transpose) {
-    final_transpose->clear();
-  }
-  return Status::OK();
-}
-
-Status ConvertEinsum(OpConverterParams* params) {
-  const auto& inputs = params->inputs;
-  const auto& node_def = params->node_def;
-  if (params->use_implicit_batch) {
-    return errors::Unimplemented(
-        "Einsum converter requires dynamic shape mode");
-  }
-
-  if (inputs.size() != 2) {
-    VLOG(2) << "Einsum converter supports two operands at " << node_def.name()
-            << " got " << inputs.size();
-    return errors::Unimplemented("No conversion for einsum equation.");
-  }
-  TF_RETURN_IF_ERROR(
-      AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
-
-  auto input_a = std::make_unique<TRT_TensorOrWeights>(inputs.at(0));
-  auto input_b = std::make_unique<TRT_TensorOrWeights>(inputs.at(1));
-  EinsumDescriptor descriptor_a;
-  EinsumDescriptor descriptor_b;
-  std::vector<int> final_transpose;
-  TF_RETURN_IF_ERROR(ParseEquation(params, &input_a, &input_b, &descriptor_a,
-                                   &descriptor_b, &final_transpose));
-
-  TF_RETURN_IF_ERROR(ShuffleEinsumTensor(params, &input_a, &descriptor_a,
-                                         /*op_instance=*/0));
-  TF_RETURN_IF_ERROR(ShuffleEinsumTensor(params, &input_b, &descriptor_b,
-                                         /*op_instance=*/1));
-  if (params->validation_only) return Status::OK();
-
-  StatusOr<ITensorProxyPtr> result = ConvertMatMulImpl(
-      params, *input_a, *input_b, descriptor_a.layout == EinsumLayout::BCF,
-      descriptor_b.layout == EinsumLayout::BFC);
-  TF_RETURN_IF_ERROR(result.status());
-  ITensorProxyPtr output = result.ValueOrDie();
-
-  TF_RETURN_IF_ERROR(ShuffleEinsumOutput(params, descriptor_a, descriptor_b,
-                                         final_transpose, &output));
-  params->outputs->push_back(TRT_TensorOrWeights(output));
-  return Status::OK();
-}
-#endif  // IS_TRT_VERSION_GE(7, 1, 3, 0)
-
 Status ConvertSoftmax(OpConverterParams* params) {
   const auto& inputs = params->inputs;
   const auto& node_def = params->node_def;
@@ -6586,9 +6000,7 @@
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertDepthSpaceShuffle, "DepthToSpace");
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConv2DDepthwise,
                                   "DepthwiseConv2dNative");
-#if IS_TRT_VERSION_GE(7, 1, 3, 0)
-REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertEinsum, "Einsum");
-#endif
+
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertExpandDims, "ExpandDims");
 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertFusedConv2DBiasActivation,
                                   "FusedConv2DBiasActivation");
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index 6d152f6..f59bce9 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -516,6 +516,14 @@
         {"Pow", nvinfer1::ElementWiseOperation::kPOW},
     }};
 
+// Adds a matrix multiplication operation to the TensorRT graph. The "params"
+// pointer is only used to access the TRT network builder. The inputs and
+// parameters for the op are fully specified by input_[a|b] and transpose_[a|b].
+StatusOr<ITensorProxyPtr> ConvertMatMulImpl(OpConverterParams* params,
+                                            TRT_TensorOrWeights input_a,
+                                            TRT_TensorOrWeights input_b,
+                                            bool transpose_a, bool transpose_b);
+
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 36ea402..4e1d25a 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -28,6 +28,7 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+
 #include "absl/algorithm/container.h"
 #include "absl/base/call_once.h"
 #include "absl/strings/match.h"
@@ -36,8 +37,6 @@
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/nn_ops_internal.h"
@@ -67,6 +66,8 @@
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
 #include "tensorflow/core/public/session.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "third_party/tensorrt/NvInfer.h"
 
 namespace tensorflow {
@@ -2700,124 +2701,142 @@
     Status conv_status;
   };
 
-  Status unimplemented_eq =
-      errors::Unimplemented("No conversion for einsum equation.");
+  Status unimplemented_eq = errors::Unimplemented("");
+  Status internal_eq = errors::Internal("");
 
-  std::vector<TestParams> params {
-    // Dot product.
-    TestParams{"i,i->", {2}, {2, 3}, {2}, {1, 2}, {1}, {8}, unimplemented_eq},
-        // Outer product.
-        TestParams{"i,k->ik",
-                   {2},
-                   {1, 2},
-                   {3},
-                   {1, 2, 3},
-                   {2, 3},
-                   {1, 2, 3, 2, 4, 6},
-                   unimplemented_eq},
-        // Transpose.
-        TestParams{"ik->ki", {2, 3}, {0, 1, 2, 3, 4, 5}, {},
-                   {},       {3, 2}, {0, 3, 1, 4, 2, 5}, unimplemented_eq},
-        // Diag.
-        TestParams{"ii->i",
-                   {3, 3},
-                   {0, 1, 2, 3, 4, 5, 6, 7, 8},
-                   {},
-                   {},
-                   {3},
-                   {0, 4, 8},
-                   unimplemented_eq},
-        // Trace.
-        TestParams{
-            "ii", {3, 3},          {0, 1, 2, 3, 4, 5, 6, 7, 8}, {}, {}, {},
-            {12}, unimplemented_eq},
-        // MatMul with reduction.
-        TestParams{"abbc,dc->ad",
-                   {1, 2, 2, 3},
-                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
-                   {2, 3},
-                   {1, 2, 3, 4, 5, 6},
-                   {2, 3},
-                   {1, 2, 3, 2, 4, 6},
-                   unimplemented_eq},
-        // Ellipsis with broadcast.
-        TestParams{"...ik,...jk->...ij",
-                   {1, 3, 1, 4},
-                   {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
-                   {2, 1, 1, 4},
-                   {1, 2, 3, 4, 5, 6, 7, 8},
-                   {2, 3, 1, 1},
-                   {20, 60, 100, 44, 148, 252},
-                   unimplemented_eq},
-        // MatMul and Batched MatMul.
-        TestParams{"ab,bc->ac",        {2, 3}, {0, 1, 2, 3, 4, 5}, {3, 2},
-                   {1, 2, 3, 4, 5, 6}, {2, 2}, {13, 16, 40, 52}},
-        TestParams{"abc,cde->abde",
-                   {1, 2, 3},
-                   {0, 1, 2, 3, 4, 5},
-                   {3, 2, 2},
-                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
-                   {1, 2, 2, 2},
-                   {23, 26, 29, 32, 68, 80, 92, 104}},
-        TestParams{"abcd,cde->abe",
-                   {1, 2, 2, 3},
-                   {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
-                   {2, 3, 2},
-                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
-                   {1, 2, 2},
-                   {125, 140, 341, 392}},
-        TestParams{"abc,cd->abd",      {1, 2, 3}, {0, 1, 2, 3, 4, 5}, {3, 2},
-                   {1, 2, 3, 4, 5, 6}, {1, 2, 2}, {13, 16, 40, 52}},
-        TestParams{"acbe,aecd->abcd",
-                   {1, 2, 3, 4},
-                   {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
-                    12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-                   {1, 4, 2, 3},
-                   {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
-                    13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24},
-                   {1, 3, 2, 3},
-                   {90, 96, 102, 732, 786, 840, 250, 272, 294, 940, 1010, 1080,
-                    410, 448, 486, 1148, 1234, 1320}},
-        TestParams{
-            "aecd,abcd->acbe",
-            {1, 2, 3, 4},
-            {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
-             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
-            {1, 2, 3, 4},
-            {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
-             13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24},
-            {1, 3, 2, 2},
-            {20, 140, 92, 788, 148, 460, 412, 1300, 404, 908, 860, 1940}},
-        TestParams{"acd,dce->ae",
-                   {1, 2, 3},
-                   {0, 1, 2, 3, 4, 5},
-                   {3, 2, 2},
-                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
-                   {1, 2},
-                   {115, 130}},
-        TestParams{"abcd,bace->bade",
-                   {2, 3, 2, 1},
-                   {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
-                   {3, 2, 2, 1},
-                   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
-                   {3, 2, 1, 1},
-                   {2, 46, 28, 128, 86, 242}},
-#if !IS_TRT_VERSION_GE(8, 0, 0, 0)
-        // Deactivating buggy test case for TRT8 per nvbug 3322485.
-        TestParams{"cebfad,fageb->abcdg",
-                   {1, 1, 3, 3, 2, 2},
-                   {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
-                    12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
-                    24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35},
-                   {3, 2, 2, 1, 3},
-                   {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
-                    13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
-                    25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36},
-                   {2, 3, 1, 2, 2},
-                   {252,  288,  291,  336,  768,  912,  810,  963,
-                    1356, 1608, 1401, 1662, 438,  492,  495,  558,
-                    1176, 1338, 1236, 1407, 1986, 2256, 2049, 2328}},
-#endif
+  std::vector<TestParams> params{
+      // Dot product.
+      TestParams{"i,i->", {2}, {2, 3}, {2}, {1, 2}, {1}, {8}, unimplemented_eq},
+      // Outer product.
+      TestParams{"i,k->ik",
+                 {2},
+                 {1, 2},
+                 {3},
+                 {1, 2, 3},
+                 {2, 3},
+                 {1, 2, 3, 2, 4, 6},
+                 unimplemented_eq},
+      // Transpose.
+      TestParams{"ik->ki",
+                 {2, 3},
+                 {0, 1, 2, 3, 4, 5},
+                 {},
+                 {},
+                 {3, 2},
+                 {0, 3, 1, 4, 2, 5},
+                 internal_eq},
+      // Diag.
+      TestParams{"ii->i",
+                 {3, 3},
+                 {0, 1, 2, 3, 4, 5, 6, 7, 8},
+                 {},
+                 {},
+                 {3},
+                 {0, 4, 8},
+                 internal_eq},
+      // Trace.
+      TestParams{"ii",
+                 {3, 3},
+                 {0, 1, 2, 3, 4, 5, 6, 7, 8},
+                 {},
+                 {},
+                 {},
+                 {12},
+                 internal_eq},
+      // MatMul with reduction.
+      TestParams{"abbc,dc->ad",
+                 {1, 2, 2, 3},
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                 {2, 3},
+                 {1, 2, 3, 4, 5, 6},
+                 {2, 3},
+                 {1, 2, 3, 2, 4, 6},
+                 unimplemented_eq},
+      // Ellipsis with broadcast.
+      TestParams{"...ik,...jk->...ij",
+                 {1, 3, 1, 4},
+                 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
+                 {2, 1, 1, 4},
+                 {1, 2, 3, 4, 5, 6, 7, 8},
+                 {2, 3, 1, 1},
+                 {20, 60, 100, 44, 148, 252},
+                 unimplemented_eq},
+      // MatMul
+      TestParams{"ab,bc->ac",
+                 {2, 3},
+                 {0, 1, 2, 3, 4, 5},
+                 {3, 2},
+                 {1, 2, 3, 4, 5, 6},
+                 {2, 2},
+                 {13, 16, 40, 52}},
+      // Batched MatMul
+      TestParams{"abc,cde->abde",
+                 /*shape_a=*/{1, 2, 3},
+                 /*values_a=*/{0, 1, 2, 3, 4, 5},
+                 /*shape_b=*/{3, 2, 2},
+                 /*values_v=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                 /*expected_shape=*/{1, 2, 2, 2},
+                 /*expected_output=*/{23, 26, 29, 32, 68, 80, 92, 104}},
+      TestParams{"abcd,cde->abe",
+                 {1, 2, 2, 3},
+                 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
+                 {2, 3, 2},
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                 {1, 2, 2},
+                 {125, 140, 341, 392}},
+      TestParams{"abc,cd->abd",
+                 {1, 2, 3},
+                 {0, 1, 2, 3, 4, 5},
+                 {3, 2},
+                 {1, 2, 3, 4, 5, 6},
+                 {1, 2, 2},
+                 {13, 16, 40, 52}},
+      TestParams{"acbe,aecd->abcd",
+                 {1, 2, 3, 4},
+                 {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+                 {1, 4, 2, 3},
+                 {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24},
+                 {1, 3, 2, 3},
+                 {90, 96, 102, 732, 786, 840, 250, 272, 294, 940, 1010, 1080,
+                  410, 448, 486, 1148, 1234, 1320}},
+      TestParams{"aecd,abcd->acbe",
+                 {1, 2, 3, 4},
+                 {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+                 {1, 2, 3, 4},
+                 {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24},
+                 {1, 3, 2, 2},
+                 {20, 140, 92, 788, 148, 460, 412, 1300, 404, 908, 860, 1940}},
+      TestParams{"acd,dce->ae",
+                 {1, 2, 3},
+                 {0, 1, 2, 3, 4, 5},
+                 {3, 2, 2},
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                 {1, 2},
+                 {115, 130}},
+      TestParams{"abcd,bace->bade",
+                 {2, 3, 2, 1},
+                 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
+                 {3, 2, 2, 1},
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                 {3, 2, 1, 1},
+                 {2, 46, 28, 128, 86, 242}},
+      TestParams{
+          "cebfad,fageb->abcdg",
+          {1, 1, 3, 3, 2, 2},
+          {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+           24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35},
+          {3, 2, 2, 1, 3},
+          {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
+           25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36},
+          {2, 3, 1, 2, 2},
+          {252, 288, 291, 336, 768,  912,  810,  963,  1356, 1608, 1401, 1662,
+           438, 492, 495, 558, 1176, 1338, 1236, 1407, 1986, 2256, 2049, 2328}},
   };
 
   for (auto p : params) {
diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc
new file mode 100644
index 0000000..9617622
--- /dev/null
+++ b/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc
@@ -0,0 +1,745 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+#include <iterator>
+#include <limits>
+#include <memory>
+
+#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
+#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "third_party/tensorrt/NvInfer.h"
+
+#if IS_TRT_VERSION_GE(7, 1, 3, 0)
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+namespace {
+// Finds the indices of elements in [begin, end) in array
+// [array_begin, array_end), and appends the indices to permute. This is used to
+// construct the permutation sequence for the operand with input labels
+// [array_begin, array_end) to the desired permuted labels [begin, end).
+template <typename T>
+Status FindIndicesoOfAllValuesInSrc(absl::Span<const T> values,
+                                    absl::Span<const T> src,
+                                    std::vector<int>* indices) {
+  if (src.size() < values.size()) {
+    return errors::Internal(
+        "Span 'src' cannot contain all elements of 'values'");
+  }
+  for (auto i = 0; i < values.size(); i++) {
+    auto iter = absl::c_find(src, values[i]);
+    if (iter == src.end()) {
+      return errors::Internal("Label ", values[i], " not found");
+    }
+    int idx = std::distance(src.begin(), iter);
+    indices->push_back(idx);
+  }
+  return Status::OK();
+}
+
+// Layout of the einsum dimensions: Batch, Free and Contraction indices.
+// Example: adbc,adce -> adbe. The first tensor has layout BFC, the second BCF.
+enum class EinsumLayout { BFC, BCF, MIX };
+
+using DimType = EinsumHelper::DimensionType;
+constexpr auto kBatch = DimType::kBatch;
+constexpr auto kFree = DimType::kFree;
+constexpr auto kContract = DimType::kContract;
+
+// Describes an operand: input shape, number of batch, free and contract
+// dimensions, and the permutation that is needed to bring it to a matmul
+// compatible form.
+class EinsumDescriptor {
+ private:
+  // Checks whether input_labels[offset:offset+m] matches labels from other.
+  static bool OrderMatches(const Labels& input_labels, int offset, int m,
+                           EinsumHelper::DimensionType dim_type,
+                           const std::unique_ptr<EinsumDescriptor>& other) {
+    if (other == nullptr) {
+      return true;
+    }
+    int offset_other = 0;
+    if (dim_type == kFree) {
+      offset = other->offset_f;
+    } else if (dim_type == kContract) {
+      offset = other->offset_c;
+    }
+    return std::equal(input_labels.begin() + offset,
+                      input_labels.begin() + offset + m,
+                      other->permuted_labels.begin() + offset_other);
+  };
+
+  using label_t_iterator =
+      std::vector<EinsumHelper::DimensionType>::const_iterator;
+  static int32_t CountLabels(label_t_iterator begin, label_t_iterator end,
+                             EinsumHelper::DimensionType val) {
+    return static_cast<int32_t>(std::count_if(
+        begin, end, [val](EinsumHelper::DimensionType t) { return t == val; }));
+  };
+
+  // Appends indices to the "permute" vector where types maches value.
+  void AppendMatchingIndicesToPermute(
+      const std::vector<EinsumHelper::DimensionType>& types,
+      EinsumHelper::DimensionType val) {
+    for (int i = 0; i < types.size(); i++) {
+      if (types[i] == val) {
+        permute.push_back(i);
+      }
+    }
+  }
+
+  Status DetermineLayout(const Labels& input_labels,
+                         const std::vector<EinsumHelper::DimensionType>& types,
+                         const std::unique_ptr<EinsumDescriptor>& other) {
+    // Check if the current layout is BFC or BCF. In that case we could avoid
+    // transpose.
+    layout = EinsumLayout::MIX;
+    if (CountLabels(types.begin(), types.begin() + b, kBatch) == b &&
+        OrderMatches(input_labels, 0, b, kBatch, other)) {
+      // Batch dims are the leading dims. They have the same order as other.
+      if (CountLabels(types.begin() + b, types.begin() + b + f, kFree) == f) {
+        // All the free dims are placed consecutively after the batch dims.
+        // Their order is arbitrary. The final transpose will ensure that the
+        // output has correct order. We still have to check that the contract
+        // indices have correct order.
+        if (OrderMatches(input_labels, b + f, c, kContract, other)) {
+          layout = EinsumLayout::BFC;
+        }
+      } else if (CountLabels(types.begin() + b, types.begin() + b + c,
+                             kContract) == c) {
+        // All the contract dims are placed consecutively after the batch
+        // dims. Check whether the contract dims have the same order as the
+        // contract dims in other.
+        if (OrderMatches(input_labels, b, c, kContract, other)) {
+          layout = EinsumLayout::BCF;
+        }
+      }
+    }
+    return Status::OK();
+  }
+
+  Status CalculateMixedLayoutPermutation(
+      const EinsumLayout preferred_layout, const Labels& input_labels,
+      const std::vector<EinsumHelper::DimensionType>& types,
+      const std::unique_ptr<EinsumDescriptor>& other) {
+    // Input label types are mixed. Calculate a permutation that maps them
+    // to the preferred layout (BCF or BFC).
+    layout = preferred_layout;
+    if (other == nullptr) {
+      AppendMatchingIndicesToPermute(types, kBatch);
+    } else {
+      TF_RETURN_IF_ERROR(
+          FindIndicesoOfAllValuesInSrc(/*values=*/
+                                       absl::MakeConstSpan(
+                                           other->permuted_labels.begin(),
+                                           other->b),
+                                       /*src=*/
+                                       absl::MakeConstSpan(input_labels.begin(),
+                                                           input_labels.size()),
+                                       /*indices=*/&permute));
+    }
+    if (layout == EinsumLayout::BFC) {
+      AppendMatchingIndicesToPermute(types, kFree);
+      if (other == nullptr) {
+        AppendMatchingIndicesToPermute(types, kContract);
+      } else {
+        TF_RETURN_IF_ERROR(FindIndicesoOfAllValuesInSrc(
+            /*values=*/absl::MakeConstSpan(
+                other->permuted_labels.begin() + other->offset_c, other->c),
+            /*src=*/
+            absl::MakeConstSpan(input_labels.begin(), input_labels.size()),
+            /*indices=*/&permute));
+      }
+      return Status::OK();
+    }
+    if (other == nullptr) {
+      AppendMatchingIndicesToPermute(types, kContract);
+    } else {
+      TF_RETURN_IF_ERROR(FindIndicesoOfAllValuesInSrc(
+          /*values=*/absl::MakeConstSpan(
+              other->permuted_labels.begin() + other->offset_c, other->c),
+          /*src=*/absl::MakeConstSpan(input_labels.begin(), input_labels.end()),
+          /*indices=*/&permute));
+    }
+    AppendMatchingIndicesToPermute(types, kFree);
+    return Status::OK();
+  }
+
+  Status Initialize(const TRT_TensorOrWeights& operand, Labels input_labels,
+                    std::vector<EinsumHelper::DimensionType>& label_types,
+                    EinsumLayout preferred_layout,
+                    const std::unique_ptr<EinsumDescriptor>& other = nullptr) {
+    if (preferred_layout == EinsumLayout::MIX) {
+      return errors::Internal("Preferred einsum layout cannot be MIX");
+    }
+    // Map label indices to label types.
+    std::vector<EinsumHelper::DimensionType> types;  // Input label types.
+    std::transform(input_labels.begin(), input_labels.end(),
+                   std::back_inserter(types),
+                   [&label_types](int i) { return label_types.at(i); });
+
+    b = CountLabels(types.begin(), types.end(), kBatch);
+    f = CountLabels(types.begin(), types.end(), kFree);
+    c = CountLabels(types.begin(), types.end(), kContract);
+
+    if (c == 0 || f == 0) {
+      VLOG(2) << "Einsum equation needs to have at least one free and one "
+                 "contract dimension";
+      return errors::Unimplemented("No conversion for einsum equation.");
+    }
+
+    TF_RETURN_IF_ERROR(DetermineLayout(input_labels, types, other));
+    if (layout == EinsumLayout::MIX) {
+      TF_RETURN_IF_ERROR(CalculateMixedLayoutPermutation(
+          preferred_layout, input_labels, types, other));
+    }
+
+    if (layout == EinsumLayout::BFC) {
+      offset_f = b;
+      offset_c = f + b;
+    } else {
+      offset_f = b + c;
+      offset_c = b;
+    }
+
+    dims = operand.GetTrtDims();
+    for (int i = 0; i < b; i++) {
+      // Set unknown batch dims to zero. These dims will be used in reshape op,
+      // where zero is a special value for retaining the original dim size.
+      if (dims.d[i] == -1) {
+        dims.d[i] = 0;
+      }
+    }
+    permuted_labels = input_labels;
+    if (!permute.empty()) {
+      // Apply the permutation on the dimension array.
+      nvinfer1::Dims orig_dims = dims;
+      for (int i = 0; i < permute.size(); i++) {
+        dims.d[i] = orig_dims.d[permute[i]];
+        permuted_labels[i] = input_labels[permute[i]];
+      }
+    }
+    size_tensors.resize(dims.nbDims, nullptr);
+    return Status::OK();
+  }
+
+ public:
+  EinsumDescriptor() : b(0), f(0), c(0) {}
+
+  // Deduces the number of batch, free, contract dimensions from the input
+  // labels, decides what layout to use, and determines permutation indices for
+  // that layout.
+  static StatusOr<std::unique_ptr<EinsumDescriptor>> Create(
+      const TRT_TensorOrWeights& operand, Labels input_labels,
+      std::vector<EinsumHelper::DimensionType>& label_types,
+      EinsumLayout preferred_layout,
+      const std::unique_ptr<EinsumDescriptor>& other = nullptr) {
+    auto desc = std::make_unique<EinsumDescriptor>();
+    TF_RETURN_IF_ERROR(desc->Initialize(operand, input_labels, label_types,
+                                        preferred_layout, other));
+    VLOG(2) << desc->DebugString();
+    return desc;
+  }
+
+  int NumBatchDims() const { return b; }
+  int NumContractDims() const { return c; }
+  int NumFreeDims() const { return f; }
+  int ContractDimOffset() const { return offset_c; }
+  const Labels& PermutedLabels() const { return permuted_labels; }
+
+  std::string DebugString() const {
+    return absl::StrCat("Descriptor with ",
+                        (layout == EinsumLayout::BFC ? "BFC" : "BCF"),
+                        " layout, b=", b, ", f=", f, ", c=", c);
+  }
+
+  // Returns whether the free and contract dimension have static shape.
+  bool HasStaticShape() const {
+    return !std::any_of(dims.d + b, dims.d + dims.nbDims,
+                        [](int k) { return k == -1; });
+  }
+
+  nvinfer1::Permutation GetPermutation() const {
+    nvinfer1::Permutation p;
+    std::copy(permute.begin(), permute.end(), p.order);
+    return p;
+  }
+
+  std::vector<int> PermuteVector() const { return permute; }
+
+  // Sets the "size_tensors" vector to be filled with scalar constant tensors
+  // representing the shape of the operand.
+  Status SetDynamicSize(TRTNetworkBuilder* builder,
+                        const TRT_TensorOrWeights& operand) {
+    TRT_ENSURE(operand.GetTrtDims().nbDims == dims.nbDims);
+    if (operand.is_weights()) {
+      // Generate constants for each dimension of the constant weight tensor's
+      // shape.
+      for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
+        StatusOr<nvinfer1::IConstantLayer*> size_tensor =
+            builder->Constant<int32_t>(dims.d[i], 1);
+        TRT_ENSURE_PTR_OK(size_tensor);
+        size_tensors[i] = (*size_tensor)->getOutput(0);
+      }
+      return Status::OK();
+    }
+
+    // If the operand is a dynamic tensor, compute the shape value dynamically.
+    StatusOr<nvinfer1::IShapeLayer*> shape_layer =
+        builder->Shape(operand.tensor()->trt_tensor());
+    TRT_ENSURE_PTR_OK(shape_layer);
+    nvinfer1::ITensor* shape = (*shape_layer)->getOutput(0);
+    for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
+      int idx = permute.empty() ? i : permute.at(i);
+      StatusOr<nvinfer1::ISliceLayer*> slice_layer =
+          builder->Slice(shape, {1, {idx}}, {1, {1}}, {1, {1}});
+      TRT_ENSURE_PTR_OK(slice_layer);
+      size_tensors[i] = (*slice_layer)->getOutput(0);
+    }
+    return Status::OK();
+  }
+
+  EinsumLayout layout;
+  int b;  // number of batch dims
+  int f;  // number of free dims
+  int c;  // number of conraction dims
+  int offset_f;
+  int offset_c;
+  nvinfer1::Dims dims;
+  std::vector<int> permute;
+  std::vector<ITensorProxyPtr> size_tensors;
+  Labels permuted_labels;
+};
+
+// Reshapes operand so that the free dimensions are combined into a single dim,
+// and the contract dimensions are combined into another single dim.
+Status GetEinsumNewDynamicShape(TRTNetworkBuilder* builder,
+                                const EinsumDescriptor& desc,
+                                ITensorProxyPtr* new_shape) {
+  std::vector<nvinfer1::ITensor*> size;
+  size.reserve(desc.b + 2);
+  absl::c_transform(absl::MakeSpan(desc.size_tensors).subspan(0, desc.b + 2),
+                    std::back_inserter(size),
+                    [](const ITensorProxyPtr x) { return x->trt_tensor(); });
+
+  int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
+  int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;
+
+  std::vector<nvinfer1::ITensor*> size_tensors;
+  size_tensors.reserve(desc.size_tensors.size());
+  absl::c_transform(desc.size_tensors, std::back_inserter(size_tensors),
+                    [](const ITensorProxyPtr x) -> nvinfer1::ITensor* {
+                      return x->trt_tensor();
+                    });
+
+  StatusOr<nvinfer1::ILayer*> shape_vol = builder->CumulativeProd(
+      absl::MakeSpan(size_tensors).subspan(desc.offset_f, desc.f));
+  TRT_ENSURE_PTR_OK(shape_vol);
+  size[idx_f] = (*shape_vol)->getOutput(0);
+
+  shape_vol = builder->CumulativeProd(
+      absl::MakeSpan(size_tensors).subspan(desc.offset_c, desc.c));
+  TRT_ENSURE_PTR_OK(shape_vol);
+  size[idx_c] = (*shape_vol)->getOutput(0);
+  StatusOr<nvinfer1::IConcatenationLayer*> layer =
+      builder->Concat(size, /*axis=*/0);
+  TRT_ENSURE_PTR_OK(layer);
+  *new_shape = (*layer)->getOutput(0);
+  return Status::OK();
+}
+
+// Reshapes operand so that the free dimensions are combined into a single dim,
+// and the contract dimensions are combined into another single dim.
+Status GetEinsumNewStaticShape(const EinsumDescriptor& desc,
+                               nvinfer1::Dims* new_dims) {
+  // Copy the batch dims and append two additional dimensions.
+  DimsAdapter adap(
+      absl::MakeSpan(static_cast<const int32_t*>(desc.dims.d), desc.b));
+  adap.Append(1).Append(1);
+
+  // Combine free dims and contract dims.
+  int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
+  int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;
+
+  // Find the volume of the free dimensions.
+  int64_t vol_f =
+      DimsAdapter(
+          absl::MakeSpan(
+              static_cast<const int32_t*>(desc.dims.d) + desc.offset_f, desc.f))
+          .Volume();
+
+  // Find the volume of the contracted dimensions.
+  int64_t vol_c =
+      DimsAdapter(
+          absl::MakeSpan(
+              static_cast<const int32_t*>(desc.dims.d) + desc.offset_c, desc.c))
+          .Volume();
+
+  adap.dim(idx_f) = vol_f;
+  adap.dim(idx_c) = vol_c;
+  *new_dims = adap.AsTrtDims();
+  return Status::OK();
+}
+
+StatusOr<TRT_TensorOrWeights> ConditionEinsumWeights(
+    TRTNetworkBuilder* builder, const TRT_TensorOrWeights& operand,
+    const EinsumDescriptor& desc, const bool need_transpose) {
+  TRT_ENSURE(operand.is_weights());
+  if (!need_transpose) {
+    // If we don't need to transpose, then the operand remains as a weights
+    // constant. In this case we also don't need a reshape.
+    TRT_ShapedWeights weights(operand.weights());
+    nvinfer1::Dims new_dims;
+    TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(desc, &new_dims));
+    TF_RETURN_IF_ERROR(weights.SetShape(new_dims));
+    return TRT_TensorOrWeights(weights);
+  }
+
+  // Let TensorRT handle constant folding where possible.
+  StatusOr<nvinfer1::IConstantLayer*> tensor = builder->WeightsToConstant(
+      operand.weights().GetTrtWeights(), operand.GetTrtDims());
+  TRT_ENSURE_PTR_OK(tensor);
+  return TRT_TensorOrWeights((*tensor)->getOutput(0));
+}
+
+// Builds a TRT shuffle operation for the given operand. Replaces operand with a
+// pointer to the shuffle output.
+Status ConditionEinsumTensor(TRTNetworkBuilder* builder,
+                             std::unique_ptr<TRT_TensorOrWeights>* operand,
+                             const EinsumDescriptor& desc,
+                             const bool need_transpose,
+                             const bool need_reshape) {
+  StatusOr<ShuffleBuilder> shuffle =
+      ShuffleBuilder::Create(builder, (*operand)->tensor()->trt_tensor());
+  TRT_ENSURE_OK(shuffle);
+
+  // Set new shape.
+  if (need_reshape) {
+    if (desc.HasStaticShape()) {
+      nvinfer1::Dims new_dims;
+      TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(desc, &new_dims));
+      shuffle->SetReshape(new_dims);
+    } else {
+      ITensorProxyPtr new_shape;
+      TF_RETURN_IF_ERROR(GetEinsumNewDynamicShape(&*builder, desc, &new_shape));
+      shuffle->SetReshape(new_shape->trt_tensor());
+    }
+  }
+
+  if (need_transpose) {
+    shuffle->SetFirstTranspose(desc.GetPermutation());
+  }
+
+  StatusOr<nvinfer1::ITensor*> shuffle_out = shuffle->Output();
+  TRT_ENSURE_PTR_OK(shuffle_out);
+  *operand = std::make_unique<TRT_TensorOrWeights>(*shuffle_out);
+  return Status::OK();
+}
+
+// Handles einsum operand conditioning for both constant and non-constant
+// inputs. This is supported using the ShuffleEinsumWeights and
+// ShuffleEinsumTensor routines.
+Status ConditionEinsumOperand(TRTNetworkBuilder* builder,
+                              std::unique_ptr<TRT_TensorOrWeights>* operand,
+                              const EinsumDescriptor& desc) {
+  bool need_reshape = (desc.f != 1 || desc.c != 1);
+  bool need_transpose = !desc.permute.empty();
+  LOG(INFO) << "Condition operand. Need reshape " << need_reshape
+            << " nned transpose " << need_transpose;
+
+  if ((*operand)->is_weights()) {
+    StatusOr<TRT_TensorOrWeights> result =
+        ConditionEinsumWeights(builder, **operand, desc, need_transpose);
+    TRT_ENSURE_OK(result);
+    *operand =
+        std::make_unique<TRT_TensorOrWeights>(result.ConsumeValueOrDie());
+  }
+
+  // If we didn't convert the operand to a tensor, we can return here.
+  if ((*operand)->is_weights()) {
+    return Status::OK();
+  }
+
+  TF_RETURN_IF_ERROR(ConditionEinsumTensor(builder, operand, desc,
+                                           need_transpose, need_reshape));
+
+  return Status::OK();
+}
+
+// Combines output dims/labels by copying batch and free dims/labels from input
+// A, and concatenating free values from input B.
+template <typename InputIterator, typename OutputIterator>
+void AssembleOutput(InputIterator begin_a, InputIterator begin_b,
+                    const EinsumDescriptor& desc_a,
+                    const EinsumDescriptor& desc_b, OutputIterator out) {
+  std::copy(begin_a, begin_a + desc_a.b, out);
+  begin_a += desc_a.offset_f;
+  std::copy(begin_a, begin_a + desc_a.f, out + desc_a.b);
+  begin_b += desc_b.offset_f;
+  std::copy(begin_b, begin_b + desc_b.f, out + desc_a.b + desc_a.f);
+}
+
+// Restores free dimensions and sets final index order. Consider C = A * B,
+// batched MatMul op, where A.shape = [B, x, k] and B.shape = [B, k, y]. Then
+// C.shape = [B, x, y]. Here B can denote multiple batch indices while x, y, k
+// are single indices. The original inputs to Einsum can have multiple free
+// indices. These were combined into a singe free dimension x and y, for example
+// x = f_a1 * f_a2 * f_a3, y = f_b1 * f_b2. This routine creates a shuffle layer
+// to expand x into and y the original free dims, e.g. C is reshaped to
+// [B, f_a1, f_a2, f_a3, f_b1, f_b2]. Finally, a permutation is applied to
+// transform the shape to the shape of the original Einsum output.
+Status ShuffleEinsumOutput(OpConverterParams* params, EinsumDescriptor desc_a,
+                           EinsumDescriptor desc_b,
+                           const std::vector<int>& permutation,
+                           ITensorProxyPtr* output) {
+  if (permutation.empty() && (desc_a.f == 1 && desc_b.f == 1)) {
+    return Status::OK();
+  }
+
+  nvinfer1::IShuffleLayer* layer =
+      params->converter->network()->addShuffle(*(*output)->trt_tensor());
+  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
+  params->converter->SetLayerName(layer, params->node_def, "shuffle",
+                                  /*sub_op_instance=*/2);
+
+  int output_rank = desc_a.b + desc_a.f + desc_b.f;
+  if (desc_a.f != 1 || desc_b.f != 1) {
+    if (desc_a.HasStaticShape() && desc_b.HasStaticShape()) {
+      nvinfer1::Dims dims_out = {output_rank, {}};
+      AssembleOutput(desc_a.dims.d, desc_b.dims.d, desc_a, desc_b, dims_out.d);
+      layer->setReshapeDimensions(dims_out);
+    } else {
+      std::vector<ITensorProxyPtr> size_tensors(output_rank);
+      AssembleOutput(desc_a.size_tensors.begin(), desc_b.size_tensors.begin(),
+                     desc_a, desc_b, size_tensors.begin());
+      ITensorProxyPtr new_shape;
+      auto builder = TRTNetworkBuilder::Create(params->converter->network(),
+                                               params->weight_store);
+      TRT_ENSURE_OK(builder);
+      std::vector<nvinfer1::ITensor*> size_itensors;
+      absl::c_transform(size_tensors, std::back_inserter(size_itensors),
+                        [](auto x) { return x->trt_tensor(); });
+      StatusOr<nvinfer1::IConcatenationLayer*> concat =
+          builder->Concat(size_itensors, /*axis=*/0);
+      TRT_ENSURE_PTR_OK(concat);
+      new_shape = (*concat)->getOutput(0);
+      layer->setInput(1, *new_shape->trt_tensor());
+    }
+  }
+
+  if (!permutation.empty()) {
+    nvinfer1::Permutation p;
+    std::copy(permutation.begin(), permutation.end(), p.order);
+    layer->setSecondTranspose(p);
+  }
+  *output = layer->getOutput(0);
+  return Status::OK();
+}
+
+// Updates "final_transpose" according to the given descriptors and output
+// labels.
+StatusOr<std::vector<int>> GetOutputTranspose(
+    const EinsumDescriptor& descriptor_a, const EinsumDescriptor& descriptor_b,
+    Labels output_labels) {
+  // Get final transpose.
+  std::vector<int> final_transpose;
+  final_transpose.reserve(descriptor_a.b + descriptor_a.f + descriptor_b.f);
+  Labels matmul_output_labels(descriptor_a.b + descriptor_a.f + descriptor_b.f);
+  AssembleOutput(descriptor_a.permuted_labels.begin(),
+                 descriptor_b.permuted_labels.begin(), descriptor_a,
+                 descriptor_b, matmul_output_labels.begin());
+  TF_RETURN_IF_ERROR(
+      FindIndicesoOfAllValuesInSrc(/*values=*/
+                                   absl::MakeConstSpan(output_labels.begin(),
+                                                       output_labels.end()),
+                                   /*src=*/
+                                   absl::MakeConstSpan(
+                                       matmul_output_labels.begin(),
+                                       matmul_output_labels.end()),
+                                   /*indices=*/&final_transpose));
+  // Clear identity transpose.
+  bool identity_transpose = true;
+  for (int i = 0; i < final_transpose.size() && identity_transpose; i++) {
+    identity_transpose &= final_transpose.at(i) == i;
+  }
+  if (identity_transpose) {
+    final_transpose.clear();
+  }
+  return final_transpose;
+}
+
+// Prepares EinsumDescriptors after parsing the equation and determines the
+// final transpose.
+Status ParseEquation(const std::string& equation,
+                     std::unique_ptr<TRT_TensorOrWeights>* input_a,
+                     std::unique_ptr<TRT_TensorOrWeights>* input_b,
+                     std::unique_ptr<EinsumDescriptor>* descriptor_a,
+                     std::unique_ptr<EinsumDescriptor>* descriptor_b,
+                     std::vector<int>* final_transpose) {
+  VLOG(2) << "Einsum equation " << equation;
+  OperandLabels input_labels;
+  Labels output_labels;
+  std::vector<EinsumHelper::DimensionType> label_types;
+  OperandLabelCounts input_label_counts;
+  LabelCounts output_label_counts;
+  absl::InlinedVector<bool, 2> input_has_ellipsis;
+  bool output_has_ellipsis;
+  TF_RETURN_IF_ERROR(EinsumHelper::ParseEquation(
+      equation, &input_labels, &output_labels, &label_types,
+      &input_label_counts, &output_label_counts, &input_has_ellipsis,
+      &output_has_ellipsis));
+
+  if (input_has_ellipsis[0] || input_has_ellipsis[1] || output_has_ellipsis) {
+    // TODO(tfeher): Handle ellipsis like EinsumHelper::ProcessDimensions.
+    // Note: ProcessDimensions would introduce kBroadcasting labels, which we
+    // need to replace with kBatch before we call InitDescriptor.
+    VLOG(2) << "Ellipsis not yet supported";
+    return errors::Unimplemented("No conversion for einsum equation.");
+  }
+
+  if (absl::c_any_of(label_types, [](auto l) {
+        return l == EinsumHelper::DimensionType::kReduce ||
+               l == EinsumHelper::DimensionType::kBroadcasting;
+      })) {
+    VLOG(2) << "Einsum reductions not implemented";
+    return errors::Unimplemented("No conversion for einsum equation.");
+  }
+
+  auto no_duplicated_labels = [](const LabelCounts& label_counts) {
+    return absl::c_any_of(label_counts, [](int i) { return i > 1; });
+  };
+  if (no_duplicated_labels(input_label_counts[0]) ||
+      no_duplicated_labels(input_label_counts[1]) ||
+      no_duplicated_labels(output_label_counts)) {
+    VLOG(2) << "Einsum invalid label count";
+    return errors::Unimplemented("No conversion for einsum equation.");
+  }
+
+  if ((*input_a)->is_weights() && (*input_b)->is_tensor()) {
+    // We prefer to use FC layer, needs A as tensor and B as weight.
+    std::swap(*input_a, *input_b);
+    std::swap(input_labels[0], input_labels[1]);
+    std::swap(input_label_counts[0], input_label_counts[1]);
+  }
+
+  auto desc = EinsumDescriptor::Create(**input_a, input_labels[0], label_types,
+                                       EinsumLayout::BFC);
+  TF_RETURN_IF_ERROR(desc.status());
+  *descriptor_a = desc.ConsumeValueOrDie();
+
+  desc = EinsumDescriptor::Create(**input_b, input_labels[1], label_types,
+                                  EinsumLayout::BCF, *descriptor_a);
+  TF_RETURN_IF_ERROR(desc.status());
+  *descriptor_b = desc.ConsumeValueOrDie();
+
+  auto out_transpose =
+      GetOutputTranspose(**descriptor_a, **descriptor_b, output_labels);
+
+  TRT_ENSURE_OK(out_transpose)
+  *final_transpose = out_transpose.ConsumeValueOrDie();
+  return Status::OK();
+}
+
+class ConvertEinsum : public OpConverterBase<ConvertEinsum> {
+ public:
+  explicit ConvertEinsum(OpConverterParams* params)
+      : OpConverterBase<ConvertEinsum>(params) {}
+
+  static constexpr std::array<DataType, 3> AllowedDataTypes() {
+    return {DataType::DT_FLOAT, DataType::DT_HALF};
+  }
+
+  static constexpr std::array<InputArgSpec, 2> InputSpec() {
+    return {InputArgSpec::Create("input_a", TrtInputArg::kBoth),
+            InputArgSpec::Create("input_b", TrtInputArg::kBoth)};
+  }
+
+  Status Validate() {
+    const auto& inputs = params_->inputs;
+    if (params_->use_implicit_batch) {
+      return errors::Unimplemented(
+          "Einsum converter requires dynamic shape mode");
+    }
+
+    input_a = std::make_unique<TRT_TensorOrWeights>(inputs.at(0));
+    input_b = std::make_unique<TRT_TensorOrWeights>(inputs.at(1));
+
+    StatusOr<std::string> eq = GetAttrValue<std::string>("equation");
+    TRT_ENSURE_OK(eq);
+    TF_RETURN_IF_ERROR(ParseEquation(*eq, &input_a, &input_b, &descriptor_a,
+                                     &descriptor_b, &final_transpose));
+
+    return Status::OK();
+  }
+
+  Status Convert() {
+    auto builder = TRTNetworkBuilder::Create(params_->converter->network(),
+                                             params_->weight_store);
+    TRT_ENSURE_OK(builder);
+    TRT_ENSURE(input_a && input_b);
+    TRT_ENSURE(descriptor_a && descriptor_b);
+
+    // Populate the size_tensor vector in the descriptor.
+    TF_RETURN_IF_ERROR(descriptor_a->SetDynamicSize(&*builder, *input_a));
+    TF_RETURN_IF_ERROR(descriptor_b->SetDynamicSize(&*builder, *input_b));
+
+    // Condition the operands for lowering to matmul.
+    TF_RETURN_IF_ERROR(
+        ConditionEinsumOperand(&*builder, &input_a, *descriptor_a));
+    TF_RETURN_IF_ERROR(
+        ConditionEinsumOperand(&*builder, &input_b, *descriptor_b));
+
+    // Build the matmul implementation.
+    StatusOr<ITensorProxyPtr> result = ConvertMatMulImpl(
+        params_, *input_a, *input_b, descriptor_a->layout == EinsumLayout::BCF,
+        descriptor_b->layout == EinsumLayout::BFC);
+    TF_RETURN_IF_ERROR(result.status());
+    ITensorProxyPtr output = result.ValueOrDie();
+
+    // Reshape and permute the output.
+    TF_RETURN_IF_ERROR(ShuffleEinsumOutput(
+        params_, *descriptor_a, *descriptor_b, final_transpose, &output));
+    this->AddOutput(output);
+    return Status::OK();
+  }
+
+ private:
+  std::unique_ptr<TRT_TensorOrWeights> input_a{nullptr};
+  std::unique_ptr<TRT_TensorOrWeights> input_b{nullptr};
+  std::vector<int> final_transpose;
+  std::unique_ptr<EinsumDescriptor> descriptor_a{nullptr};
+  std::unique_ptr<EinsumDescriptor> descriptor_b{nullptr};
+};
+
+}  // namespace
+
+REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertEinsum>(),
+                                  "Einsum");
+#endif
+
+}  // namespace convert
+}  // namespace tensorrt
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h
index 0fb5572..e8b26a8 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h
@@ -220,6 +220,29 @@
     return layer;
   }
 
+  // Adds a sequence of elementwise multiplication operations to the network.
+  // The returned layer's output contains the cumulative elementwise product of
+  // all tensors in the input.
+  StatusOr<nvinfer1::ILayer*> CumulativeProd(
+      absl::Span<nvinfer1::ITensor*> inputs) noexcept {
+    TRT_ENSURE(!absl::c_any_of(
+        inputs, [](nvinfer1::ITensor* x) { return x == nullptr; }));
+    nvinfer1::ILayer* out = nullptr;
+    if (inputs.size() == 1) {
+      out = network_->addIdentity(*inputs[0]);
+      TRT_ENSURE(out != nullptr);
+      return out;
+    }
+    nvinfer1::ITensor* last = inputs[0];
+    for (int i = 1; i < inputs.size(); i++) {
+      StatusOr<nvinfer1::IElementWiseLayer*> mul = this->Mul(last, inputs[i]);
+      TRT_ENSURE_PTR_OK(mul);
+      out = *mul;
+      last = (*mul)->getOutput(0);
+    }
+    return out;
+  }
+
   // Adds a Constant layer whose output is a TensorRT shape tensor. The shape
   // tensor's size and values correspond to dim's nbDims and d[], respectively.
   StatusOr<nvinfer1::IConstantLayer*> ConstantShape(
@@ -339,6 +362,25 @@
     return layer;
   }
 
+  // Adds a TensorRT Concatenate operation to the network.
+  StatusOr<nvinfer1::IConcatenationLayer*> Concat(
+      absl::Span<nvinfer1::ITensor* const> inputs, const int axis) {
+    for (nvinfer1::ITensor* input : inputs) {
+      TRT_ENSURE(input);
+    }
+    nvinfer1::IConcatenationLayer* layer = network_->addConcatenation(
+        inputs.data(), static_cast<int32_t>(inputs.size()));
+    TRT_ENSURE(layer);
+    layer->setAxis(axis);
+    return layer;
+  }
+
+  // Adds a TensorRT Concatenate operation to the network.
+  StatusOr<nvinfer1::IConcatenationLayer*> Concat(
+      const std::vector<nvinfer1::ITensor*>& inputs, const int axis) {
+    return this->Concat(absl::MakeSpan(inputs), axis);
+  }
+
   // Adds a TensorRT Shape operation, which determines the runtime shape of the
   // input tensor, to the network.
   StatusOr<nvinfer1::IShapeLayer*> Shape(nvinfer1::ITensor* input) {
@@ -546,11 +588,59 @@
     return FindProducerOf(layer->getInput(input_idx));
   }
 
+  nvinfer1::INetworkDefinition* Network() { return network_; }
+
  private:
   nvinfer1::INetworkDefinition* const network_;
   TrtWeightStore* const weight_store_;
 };
 
+class ShuffleBuilder {
+ private:
+  explicit ShuffleBuilder(TRTNetworkBuilder* builder, nvinfer1::ITensor* input)
+      : builder_(builder) {
+    layer_ = builder->Network()->addShuffle(*input);
+  }
+
+ public:
+  static StatusOr<ShuffleBuilder> Create(TRTNetworkBuilder* builder,
+                                         nvinfer1::ITensor* input) {
+    TRT_ENSURE(builder != nullptr);
+    TRT_ENSURE(input != nullptr);
+    return ShuffleBuilder(builder, input);
+  }
+
+  ShuffleBuilder& SetReshape(const nvinfer1::Dims& dims) {
+    layer_->setReshapeDimensions(dims);
+    return *this;
+  }
+
+  ShuffleBuilder& SetReshape(nvinfer1::ITensor* shape) {
+    layer_->setInput(1, *shape);
+    return *this;
+  }
+
+  ShuffleBuilder& SetFirstTranspose(const nvinfer1::Permutation& perm) {
+    layer_->setFirstTranspose(perm);
+    return *this;
+  }
+
+  ShuffleBuilder& SetSecondTranspose(const nvinfer1::Permutation& perm) {
+    layer_->setSecondTranspose(perm);
+    return *this;
+  }
+
+  StatusOr<nvinfer1::ITensor*> Output() {
+    TRT_ENSURE(layer_ != nullptr);
+    TRT_ENSURE(layer_->getOutput(0) != nullptr);
+    return layer_->getOutput(0);
+  }
+
+ private:
+  TRTNetworkBuilder* builder_;
+  nvinfer1::IShuffleLayer* layer_;
+};
+
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow