[XLA:GPU] Add support for folding transpose of batch dimension into GeMMs.

Modifies `MatrixDescriptor`s to use two strides (leading-dimension and batch), allowing matrices with interleaved batches. Relaxes the GPU dot layout constraints and the transpose folding rules to support anything that can be expressed using those two strides.

PiperOrigin-RevId: 447825167
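
For context, a minimal standalone sketch (not part of the patch) of the addressing scheme the description above refers to: with column-major storage and unit stride along rows, a leading-dimension stride plus a batch stride can describe both contiguous batches and a batch dimension that sits physically between the row and column dimensions. The struct and helper names below are illustrative only.

#include <cassert>
#include <cstdint>

// Illustrative two-stride descriptor (names mirror the description above).
struct TwoStrideMatrix {
  int64_t num_rows;
  int64_t num_cols;
  int64_t leading_dim_stride;  // Elements between consecutive columns.
  int64_t batch_stride;        // Elements between consecutive batch members.
};

// Offset of element (batch b, row r, col c): column-major, unit row stride.
int64_t ElementOffset(const TwoStrideMatrix& m, int64_t b, int64_t r,
                      int64_t c) {
  return b * m.batch_stride + c * m.leading_dim_stride + r;
}

int main() {
  // Batch of two 3x4 matrices, batch physically most major (contiguous).
  TwoStrideMatrix contiguous{3, 4, /*leading_dim_stride=*/3,
                             /*batch_stride=*/12};
  // Same logical shape, but the batch dimension sits physically between the
  // row and column dimensions ("interleaved batches").
  TwoStrideMatrix interleaved{3, 4, /*leading_dim_stride=*/6,
                              /*batch_stride=*/3};
  assert(ElementOffset(contiguous, /*b=*/1, /*r=*/2, /*c=*/2) == 12 + 6 + 2);
  assert(ElementOffset(interleaved, 1, 2, 2) == 3 + 12 + 2);
  return 0;
}
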
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc
index a056898..8b6e0ef 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc
@@ -57,18 +57,24 @@
 // dimensions.
 struct MatrixDescriptor {
   Value data;
-  int64_t leading_dim_stride;
-  int64_t batch_stride;
   bool transpose;  // Whether this matrix needs to be transposed.
+  int64_t num_rows;
+  int64_t num_cols;
+  int64_t stride;
 };
 
 MatrixDescriptor GetMatrixDesc(const xla::gpu::MatrixLayout& layout,
                                Value data) {
+  // TODO(cjfj): Add support for batch not in most major physical dimension.
+  CHECK((layout.batch_stride == 0) ||
+        (layout.batch_stride == layout.num_rows * layout.num_cols));
+  bool transpose = layout.order != xla::gpu::MatrixLayout::Order::kColumnMajor;
   return {
       data,
-      layout.leading_dim_stride,
+      transpose,
+      transpose ? layout.num_cols : layout.num_rows,
+      transpose ? layout.num_rows : layout.num_cols,
       layout.batch_stride,
-      /*transpose=*/layout.order != xla::gpu::MatrixLayout::Order::kColumnMajor,
   };
 }
 
@@ -76,12 +82,11 @@
   matrix_desc.transpose = !matrix_desc.transpose;
 }
 
-void MakeBlasGemmCompatible(int64_t& m, int64_t& n, MatrixDescriptor& lhs,
-                            MatrixDescriptor& rhs, MatrixDescriptor& output) {
+void MakeBlasGemmCompatible(MatrixDescriptor& lhs, MatrixDescriptor& rhs,
+                            MatrixDescriptor& output) {
   // BLAS GeMM doesn't support transposed output, but we can use the identity:
   // C^T = (A @ B)^T = B^T @ A^T.
   if (output.transpose) {
-    std::swap(m, n);
     std::swap(lhs, rhs);
     lhs.transpose = !lhs.transpose;
     rhs.transpose = !rhs.transpose;
@@ -124,39 +129,39 @@
 template <class GemmOp>
 Value CreateTfrtOps(GemmOp op, typename GemmOp::Adaptor adaptor, Value chain,
                     Value stream, mlir::Type input_type, mlir::Type output_type,
-                    int64_t batch_size, int64_t m, int64_t n, int64_t k,
-                    const MatrixDescriptor& lhs, const MatrixDescriptor& rhs,
-                    const MatrixDescriptor& output, xla::complex128 alpha,
-                    double beta, ConversionPatternRewriter& rewriter) {
+                    int64_t batch_size, const MatrixDescriptor& lhs,
+                    const MatrixDescriptor& rhs, const MatrixDescriptor& output,
+                    xla::complex128 alpha, double beta,
+                    ConversionPatternRewriter& rewriter) {
   auto loc = op.getLoc();
   if (auto bias = GetBias(adaptor)) {
     chain = rewriter.create<tfrt::gpu::MemCopyOp>(loc, adaptor.output(), bias,
                                                   stream, chain);
   }
 
+  auto k_val = lhs.transpose ? lhs.num_rows : lhs.num_cols;
+
   const Type mlir_compute_type = MlirComputationType(output_type, rewriter);
 
-  auto m_ = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, m);
-  auto n_ = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, n);
-  auto k_ = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, k);
+  auto m = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, output.num_rows);
+  auto n = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, output.num_cols);
+  auto k = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, k_val);
 
   // Scale type must match compute type, except for complex types, where
   // it must match the output type
   const Type mlir_scale_type =
       output_type.isa<mlir::ComplexType>() ? output_type : mlir_compute_type;
 
-  auto alpha_ = MakeScalingFactorConstant(rewriter, loc, mlir_scale_type,
-                                          llvm::APFloat(alpha.real()),
-                                          llvm::APFloat(alpha.imag()));
-  auto beta_ = MakeScalingFactorConstant(
+  auto const_alpha = MakeScalingFactorConstant(rewriter, loc, mlir_scale_type,
+                                               llvm::APFloat(alpha.real()),
+                                               llvm::APFloat(alpha.imag()));
+  auto const_beta = MakeScalingFactorConstant(
       rewriter, loc, mlir_scale_type, llvm::APFloat(beta), llvm::APFloat(0.));
 
-  auto lda = rewriter.create<tfrt::compiler::ConstantI32Op>(
-      loc, lhs.leading_dim_stride);
-  auto ldb = rewriter.create<tfrt::compiler::ConstantI32Op>(
-      loc, rhs.leading_dim_stride);
-  auto ldc = rewriter.create<tfrt::compiler::ConstantI32Op>(
-      loc, output.leading_dim_stride);
+  auto lda = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, lhs.num_rows);
+  auto ldb = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, rhs.num_rows);
+  auto ldc =
+      rewriter.create<tfrt::compiler::ConstantI32Op>(loc, output.num_rows);
 
   tfrt::gpu::wrapper::BlasGemmAlgo algorithm = GetBlasGemmAlgoOrDefault(op);
   auto algo = rewriter.create<tfrt::gpu::BlasGemmAlgoOp>(loc, algorithm);
@@ -172,18 +177,18 @@
   const auto compute_type = MlirTypeToBlasComputeType(mlir_compute_type);
   if (batch_size != 1) {
     auto lhs_stride =
-        rewriter.create<tfrt::compiler::ConstantI64Op>(loc, lhs.batch_stride);
+        rewriter.create<tfrt::compiler::ConstantI64Op>(loc, lhs.stride);
     auto rhs_stride =
-        rewriter.create<tfrt::compiler::ConstantI64Op>(loc, rhs.batch_stride);
-    auto output_stride = rewriter.create<tfrt::compiler::ConstantI64Op>(
-        loc, output.batch_stride);
+        rewriter.create<tfrt::compiler::ConstantI64Op>(loc, rhs.stride);
+    auto output_stride =
+        rewriter.create<tfrt::compiler::ConstantI64Op>(loc, output.stride);
     auto batch =
         rewriter.create<tfrt::compiler::ConstantI32Op>(loc, batch_size);
     return rewriter
         .create<tfrt::gpu::BlasGemmBatchExOp>(
-            loc, chain.getType(), handle, stream, lhs_op, rhs_op, m_, n_, k_,
-            alpha_, lhs.data, input_data_type, lda, lhs_stride, rhs.data,
-            input_data_type, ldb, rhs_stride, beta_, output.data,
+            loc, chain.getType(), handle, stream, lhs_op, rhs_op, m, n, k,
+            const_alpha, lhs.data, input_data_type, lda, lhs_stride, rhs.data,
+            input_data_type, ldb, rhs_stride, const_beta, output.data,
             output_data_type, ldc, output_stride, batch, compute_type, algo,
             chain)
         .getResult();
@@ -191,10 +196,10 @@
 
   return rewriter
       .create<tfrt::gpu::BlasGemmOp>(
-          loc, chain.getType(), handle, stream, lhs_op, rhs_op, m_, n_, k_,
-          alpha_, lhs.data, input_data_type, lda, rhs.data, input_data_type,
-          ldb, beta_, output.data, output_data_type, ldc, compute_type, algo,
-          chain)
+          loc, chain.getType(), handle, stream, lhs_op, rhs_op, m, n, k,
+          const_alpha, lhs.data, input_data_type, lda, rhs.data,
+          input_data_type, ldb, const_beta, output.data, output_data_type, ldc,
+          compute_type, algo, chain)
       .getResult();
 }
 
@@ -218,20 +223,17 @@
   if (!config.ok())
     return rewriter.notifyMatchFailure(op, config.status().ToString());
 
-  int64_t m = config->output_layout.num_rows;
-  int64_t n = config->output_layout.num_cols;
-  int64_t k = config->lhs_layout.num_cols;
   MatrixDescriptor lhs = GetMatrixDesc(config->lhs_layout, adaptor.lhs());
   MatrixDescriptor rhs = GetMatrixDesc(config->rhs_layout, adaptor.rhs());
   MatrixDescriptor output =
       GetMatrixDesc(config->output_layout, adaptor.output());
   int64_t batch_size = config->output_layout.batch_size;
 
-  MakeBlasGemmCompatible(m, n, lhs, rhs, output);
+  MakeBlasGemmCompatible(lhs, rhs, output);
 
   return CreateTfrtOps(op, adaptor, chain, stream, get_element_type(op.lhs()),
-                       get_element_type(op.output()), batch_size, m, n, k, lhs,
-                       rhs, output, config->alpha, config->beta, rewriter);
+                       get_element_type(op.output()), batch_size, lhs, rhs,
+                       output, config->alpha, config->beta, rewriter);
 }
 
 template <class GemmOpType>
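
A standalone model (not TFRT/MLIR code) of how the rewritten lowering above derives the GEMM problem size purely from the descriptors: m and n come from the output descriptor, k from whichever stored LHS dimension is contracted, and the leading dimensions are just the stored row counts. The struct below is a simplified stand-in for `MatrixDescriptor`.

#include <cassert>
#include <cstdint>

// Simplified stand-in for the MatrixDescriptor above (stored, column-major
// dimensions plus a transpose flag).
struct Desc {
  bool transpose;
  int64_t num_rows;
  int64_t num_cols;
};

int main() {
  // C = A^T @ B, with A stored as 16x32 (so op(A) is 32x16) and B as 16x8.
  Desc lhs{/*transpose=*/true, /*num_rows=*/16, /*num_cols=*/32};
  Desc rhs{/*transpose=*/false, /*num_rows=*/16, /*num_cols=*/8};
  Desc out{/*transpose=*/false, /*num_rows=*/32, /*num_cols=*/8};

  int64_t m = out.num_rows;                                 // 32
  int64_t n = out.num_cols;                                 // 8
  int64_t k = lhs.transpose ? lhs.num_rows : lhs.num_cols;  // 16
  assert(m == 32 && n == 8 && k == 16);

  // Leading dimensions are simply the stored row counts (column-major).
  assert(lhs.num_rows == 16 /*lda*/ && rhs.num_rows == 16 /*ldb*/ &&
         out.num_rows == 32 /*ldc*/);
  return 0;
}
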
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fe491ea..850b9d8 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -5091,10 +5091,7 @@
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/core:lib",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/types:span",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ffba9bb..2191ff9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -556,12 +556,12 @@
   }
   pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
   pipeline.AddPass<TransposeFolding>(
-      [&](const HloInstruction& dot, int64_t operand) -> StatusOr<bool> {
-        if (DotImplementationCanHandleTranspose(dot,
-                                                *target_machine_features)) {
-          return TransposeFolding::IsRowColumnTransposeDotOperand(dot, operand);
-        }
-        return false;
+      [&](const HloInstruction& dot,
+          const TransposeFolding::OperandIndices& candidate_operands) {
+        return DotImplementationCanHandleTranspose(dot,
+                                                   *target_machine_features)
+                   ? candidate_operands
+                   : TransposeFolding::OperandIndices{};
       },
       TransposeFolding::NeverFoldTranspose);
   pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
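
The `TransposeFolding` filter callback now receives the dot plus the candidate operand indices (presumably the operands fed by transposes) and returns the subset to fold, instead of answering per operand with `StatusOr<bool>`. A minimal standalone model of that contract, with `OperandIndices` modeled as `std::vector<int64_t>` (consistent with how it is used in this patch) and the `HloInstruction` parameter dropped so the sketch is self-contained.

#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

using OperandIndices = std::vector<int64_t>;
// Simplified filter: candidate operand indices in, subset to fold out.
using FoldFilter = std::function<OperandIndices(const OperandIndices&)>;

int main() {
  // "Fold every candidate": the shape of the filter above when
  // DotImplementationCanHandleTranspose() holds, and of the filters the CPU
  // fusion test installs below.
  FoldFilter fold_all = [](const OperandIndices& candidates) {
    return candidates;
  };
  // "Fold nothing": what returning an empty OperandIndices means.
  FoldFilter fold_none = [](const OperandIndices&) { return OperandIndices{}; };

  OperandIndices candidates = {0, 1};
  assert(fold_all(candidates).size() == 2);
  assert(fold_none(candidates).empty());
  return 0;
}
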
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index fab2ffe..b50571d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -183,7 +183,13 @@
                           ParseAndReturnVerifiedModule(hlo_string));
   HloComputation* computation = module->entry_computation();
 
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, TransposeFolding().Run(module.get()));
+  TransposeFolding transpose_folding(
+      [](const HloInstruction& dot,
+         const TransposeFolding::OperandIndices& candidate_operands) {
+        return candidate_operands;
+      },
+      TransposeFolding::NeverFoldTranspose);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
   ASSERT_TRUE(changed);
   ASSERT_THAT(computation->root_instruction(),
               op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
@@ -207,7 +213,13 @@
                           ParseAndReturnVerifiedModule(hlo_string));
   HloComputation* computation = module->entry_computation();
 
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, TransposeFolding().Run(module.get()));
+  TransposeFolding transpose_folding(
+      [](const HloInstruction& dot,
+         const TransposeFolding::OperandIndices& candidate_operands) {
+        return candidate_operands;
+      },
+      TransposeFolding::NeverFoldTranspose);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
   ASSERT_TRUE(changed);
   ASSERT_THAT(computation->root_instruction(),
               op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
@@ -232,7 +244,13 @@
                           ParseAndReturnVerifiedModule(hlo_string));
   HloComputation* computation = module->entry_computation();
 
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, TransposeFolding().Run(module.get()));
+  TransposeFolding transpose_folding(
+      [](const HloInstruction& dot,
+         const TransposeFolding::OperandIndices& candidate_operands) {
+        return candidate_operands;
+      },
+      TransposeFolding::NeverFoldTranspose);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
   ASSERT_TRUE(changed);
   ASSERT_THAT(computation->root_instruction(),
               op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d33c468..6203c44 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1575,7 +1575,6 @@
         ":gpu_reduce_scatter_creator",
         ":gpu_sanitize_constant_names",
         ":gpu_scatter_expander",
-        ":matmul_utils",
         "@llvm-project//mlir:FuncDialect",
         "//tensorflow/compiler/xla/service/spmd:stateful_rng_spmd_partitioner",
         ":gpu_hlo_cost_analysis",
@@ -1963,7 +1962,6 @@
     deps = [
         ":backend_configs_cc",
         ":ir_emission_utils",
-        ":matmul_utils",
         ":stream_executor_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -1974,8 +1972,6 @@
         "//tensorflow/compiler/xla/service:layout_assignment",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:stream_executor_no_cuda",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/types:span",
     ],
 )
 
@@ -1996,7 +1992,6 @@
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # build_cleaner: keep
-        "//tensorflow/core/platform:status_matchers",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
index 8fd416c..6420f5f 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
@@ -133,9 +133,6 @@
   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
                       get_initialized_buffer(instr));
 
-  int64_t m = config.output_layout.num_rows;
-  int64_t n = config.output_layout.num_cols;
-  int64_t k = config.lhs_layout.num_cols;
   se::blas::MatrixDescriptor lhs = GetMatrixDesc(config.lhs_layout, lhs_buffer);
   se::blas::MatrixDescriptor rhs = GetMatrixDesc(config.rhs_layout, rhs_buffer);
   se::blas::MatrixDescriptor output =
@@ -143,7 +140,7 @@
   int64_t batch_size = config.output_layout.batch_size;
 
   // TODO(cjfj): Support transposed output when using cuBLASLt.
-  MakeBlasGemmCompatible(m, n, lhs, rhs, output);
+  MakeBlasGemmCompatible(lhs, rhs, output);
 
   TF_ASSIGN_OR_RETURN(
       tensorflow::DataType dtype,
@@ -152,6 +149,11 @@
   int device_id = stream->parent()->device_ordinal();
   bool trans_x = lhs.transpose == se::blas::Transpose::kTranspose;
   bool trans_y = rhs.transpose == se::blas::Transpose::kTranspose;
+
+  int64_t m = output.num_rows;
+  int64_t n = output.num_cols;
+  int64_t k = lhs.reduced_dim();
+
   bool broadcast = batch_size == 1;
 
   VLOG(4) << "matmul params: trans_x " << trans_x << " trans_y " << trans_y
@@ -167,8 +169,8 @@
 
   TF_ASSIGN_OR_RETURN(
       const se::blas::PlanAndAlgorithms* plan_and_algorithms,
-      se::GetPlanAndAlgorithms(stream, matmul_parameters, batch_size, m, n, k,
-                               dtype, lhs, rhs, output));
+      se::GetPlanAndAlgorithms(stream, matmul_parameters, batch_size, dtype,
+                               lhs, rhs, output));
 
   const std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>&
       algorithms = plan_and_algorithms->algorithms;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 3235fa5..f646d58 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -108,8 +108,7 @@
 
 template <typename Input, typename Output>
 static Status DoGemmWithAlgorithm(
-    int64_t batch_size, int64_t m, int64_t n, int64_t k,
-    const se::blas::MatrixDescriptor &lhs,
+    int64_t batch_size, const se::blas::MatrixDescriptor &lhs,
     const se::blas::MatrixDescriptor &rhs,
     const se::blas::MatrixDescriptor &output, Output alpha, Output beta,
     se::Stream *stream, se::blas::AlgorithmType algorithm,
@@ -122,23 +121,29 @@
 
   if (batch_size != 1) {
     return stream->ThenBlasGemmStridedBatchedWithAlgorithm(
-        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
-        lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
-        rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
-        output.leading_dim_stride, output.batch_stride, batch_size,
+        lhs.transpose, rhs.transpose, output.num_rows, output.num_cols,
+        /*size of reduce dim=*/lhs.reduced_dim(),
+        /*alpha=*/alpha, lhs.cast<Input>(),
+        /*leading dim of LHS=*/lhs.num_rows, lhs.stride, rhs.cast<Input>(),
+        /*leading dim of RHS=*/rhs.num_rows, rhs.stride,
+        /*beta=*/beta, &output_data,
+        /*leading dim of output=*/output.num_rows, output.stride, batch_size,
         computation_type, algorithm, output_profile_result);
   } else {
     return stream->ThenBlasGemmWithAlgorithm(
-        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
-        lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
-        &output_data, output.leading_dim_stride, computation_type, algorithm,
+        lhs.transpose, rhs.transpose, output.num_rows, output.num_cols,
+        /*size of reduce dim=*/lhs.reduced_dim(),
+        /*alpha=*/alpha, lhs.cast<Input>(),
+        /*lda=*/lhs.num_rows, rhs.cast<Input>(),
+        /*ldb=*/rhs.num_rows,
+        /*beta=*/beta, &output_data,
+        /*ldc=*/output.num_rows, computation_type, algorithm,
         output_profile_result);
   }
 }
 
 template <typename Input>
-static Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
-                     const se::blas::MatrixDescriptor &lhs,
+static Status DoGemm(int64_t batch_size, const se::blas::MatrixDescriptor &lhs,
                      const se::blas::MatrixDescriptor &rhs,
                      const se::blas::MatrixDescriptor &output, Input alpha,
                      Input beta, se::Stream *stream,
@@ -148,29 +153,35 @@
   se::DeviceMemory<Input> output_data(output.data);
 
   if (algorithm) {
-    return DoGemmWithAlgorithm<Input, Input>(batch_size, m, n, k, lhs, rhs,
-                                             output, alpha, beta, stream,
-                                             *algorithm, output_profile_result);
+    return DoGemmWithAlgorithm<Input, Input>(batch_size, lhs, rhs, output,
+                                             alpha, beta, stream, *algorithm,
+                                             output_profile_result);
   }
 
   if (batch_size != 1) {
     return stream->ThenBlasGemmStridedBatched(
-        lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
-        lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
-        rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
-        output.leading_dim_stride, output.batch_stride, batch_size);
+        lhs.transpose, rhs.transpose, output.num_rows, output.num_cols,
+        /*size of reduce dim=*/lhs.reduced_dim(),
+        /*alpha=*/alpha, lhs.cast<Input>(),
+        /*leading dim of LHS=*/lhs.num_rows, lhs.stride, rhs.cast<Input>(),
+        /*leading dim of RHS=*/rhs.num_rows, rhs.stride,
+        /*beta=*/beta, &output_data,
+        /*leading dim of output=*/output.num_rows, output.stride, batch_size);
   }
 
-  return stream->ThenBlasGemm(lhs.transpose, rhs.transpose, m, n, k, alpha,
-                              lhs.cast<Input>(), lhs.leading_dim_stride,
-                              rhs.cast<Input>(), rhs.leading_dim_stride, beta,
-                              &output_data, output.leading_dim_stride);
+  return stream->ThenBlasGemm(
+      lhs.transpose, rhs.transpose, output.num_rows, output.num_cols,
+      /*size of reduce dim=*/lhs.reduced_dim(),
+      /*alpha=*/alpha, lhs.cast<Input>(),
+      /*leading dim of LHS=*/lhs.num_rows, rhs.cast<Input>(),
+      /*leading dim of RHS=*/rhs.num_rows,
+      /*beta=*/beta, &output_data,
+      /*leading dim of output=*/output.num_rows);
 }
 
 template <typename Input>
 static Status DoGemmLt(
-    int64_t batch_size, int64_t m, int64_t n, int64_t k,
-    const se::blas::MatrixDescriptor &lhs,
+    int64_t batch_size, const se::blas::MatrixDescriptor &lhs,
     const se::blas::MatrixDescriptor &rhs,
     const se::blas::MatrixDescriptor &output, se::Stream *stream, Input alpha,
     Input beta, se::ScratchAllocator *scratch_allocator,
@@ -183,6 +194,10 @@
 
   bool trans_x = lhs.transpose == se::blas::Transpose::kTranspose;
   bool trans_y = rhs.transpose == se::blas::Transpose::kTranspose;
+
+  int64_t m = output.num_rows;
+  int64_t n = output.num_cols;
+  auto k = lhs.reduced_dim();
   bool broadcast = batch_size == 1;
   VLOG(2) << "matmul params: trans_x " << trans_x << " trans_y " << trans_y
           << " adj_x " << false << " adj_y " << false << " m " << m << " n "
@@ -195,8 +210,8 @@
 
   TF_ASSIGN_OR_RETURN(
       const se::blas::PlanAndAlgorithms *plan_and_algorithms,
-      GetPlanAndAlgorithms(stream, matmul_parameters, batch_size, m, n, k,
-                           dtype, lhs, rhs, output));
+      GetPlanAndAlgorithms(stream, matmul_parameters, batch_size, dtype, lhs,
+                           rhs, output));
 
   const std::unique_ptr<se::blas::IBlasLtMatmulPlan> &plan =
       plan_and_algorithms->plan;
@@ -254,9 +269,6 @@
                se::blas::ProfileResult *profile_result,
                absl::optional<se::blas::AlgorithmType> algorithm) {
   VLOG(2) << "Executing a GemmThunk";
-  int64_t m = config.output_layout.num_rows;
-  int64_t n = config.output_layout.num_cols;
-  int64_t k = config.lhs_layout.num_cols;
   se::blas::MatrixDescriptor lhs = GetMatrixDesc(config.lhs_layout, lhs_buffer);
   se::blas::MatrixDescriptor rhs = GetMatrixDesc(config.rhs_layout, rhs_buffer);
   se::blas::MatrixDescriptor output =
@@ -264,7 +276,7 @@
   int64_t batch_size = config.output_layout.batch_size;
 
   // TODO(cjfj): Support transposed output when using cuBLASLt.
-  MakeBlasGemmCompatible(m, n, lhs, rhs, output);
+  MakeBlasGemmCompatible(lhs, rhs, output);
 
   // The BlasLtMatmul routines are only supported from CUDA 11.0 onward.
   if (config.use_cublaslt && stream->parent()->SupportsBlasPlans()) {
@@ -275,30 +287,30 @@
     switch (config.output_layout.dtype) {
       case F16:
         return DoGemmLt<Eigen::half>(
-            batch_size, m, n, k, lhs, rhs, output, stream,
+            batch_size, lhs, rhs, output, stream,
             static_cast<Eigen::half>(config.alpha.real()),
             static_cast<Eigen::half>(config.beta), scratch_allocator, algorithm,
             /*output_profile_result=*/profile_result);
       case F32:
-        return DoGemmLt<float>(batch_size, m, n, k, lhs, rhs, output, stream,
+        return DoGemmLt<float>(batch_size, lhs, rhs, output, stream,
                                static_cast<float>(config.alpha.real()),
                                static_cast<float>(config.beta),
                                scratch_allocator, algorithm,
                                /*output_profile_result=*/profile_result);
       case F64:
-        return DoGemmLt<double>(batch_size, m, n, k, lhs, rhs, output, stream,
+        return DoGemmLt<double>(batch_size, lhs, rhs, output, stream,
                                 static_cast<double>(config.alpha.real()),
                                 config.beta, scratch_allocator, algorithm,
                                 /*output_profile_result=*/profile_result);
       case C64:
-        return DoGemmLt<complex64>(batch_size, m, n, k, lhs, rhs, output,
-                                   stream, static_cast<complex64>(config.alpha),
+        return DoGemmLt<complex64>(batch_size, lhs, rhs, output, stream,
+                                   static_cast<complex64>(config.alpha),
                                    static_cast<complex64>(config.beta),
                                    scratch_allocator, algorithm,
                                    /*output_profile_result=*/profile_result);
       case C128:
         return DoGemmLt<complex128>(
-            batch_size, m, n, k, lhs, rhs, output, stream, config.alpha,
+            batch_size, lhs, rhs, output, stream, config.alpha,
             static_cast<complex64>(config.beta), scratch_allocator, algorithm,
             /*output_profile_result=*/profile_result);
       default:
@@ -313,43 +325,40 @@
       case S32:
         if (!algorithm) algorithm = se::blas::kDefaultGemmAlgo;
         return DoGemmWithAlgorithm<int8_t, int32_t>(
-            batch_size, m, n, k, lhs, rhs, output,
+            batch_size, lhs, rhs, output,
             static_cast<int32_t>(config.alpha.real()),
             static_cast<int32_t>(config.beta), stream, *algorithm,
             /*output_profile_result=*/profile_result);
       case F16:
         return DoGemm<Eigen::half>(
-            batch_size, m, n, k, lhs, rhs, output,
+            batch_size, lhs, rhs, output,
             static_cast<Eigen::half>(config.alpha.real()),
             static_cast<Eigen::half>(config.beta), stream, algorithm,
             /*output_profile_result=*/profile_result);
       case BF16:
         return DoGemm<Eigen::bfloat16>(
-            batch_size, m, n, k, lhs, rhs, output,
+            batch_size, lhs, rhs, output,
             static_cast<Eigen::bfloat16>(config.alpha.real()),
             static_cast<Eigen::bfloat16>(config.beta), stream, algorithm,
             /*output_profile_result=*/profile_result);
       case F32:
-        return DoGemm<float>(batch_size, m, n, k, lhs, rhs, output,
-                             config.alpha.real(), config.beta, stream,
-                             algorithm,
+        return DoGemm<float>(batch_size, lhs, rhs, output, config.alpha.real(),
+                             config.beta, stream, algorithm,
                              /*output_profile_result=*/profile_result);
       case F64:
-        return DoGemm<double>(batch_size, m, n, k, lhs, rhs, output,
-                              config.alpha.real(), config.beta, stream,
-                              algorithm,
+        return DoGemm<double>(batch_size, lhs, rhs, output, config.alpha.real(),
+                              config.beta, stream, algorithm,
                               /*output_profile_result=*/profile_result);
       case C64:
-        return DoGemm<complex64>(batch_size, m, n, k, lhs, rhs, output,
-                                 static_cast<complex64>(config.alpha),
-                                 static_cast<complex64>(config.beta), stream,
-                                 algorithm,
-                                 /*output_profile_result=*/profile_result);
-      case C128:
-        return DoGemm<complex128>(
-            batch_size, m, n, k, lhs, rhs, output, config.alpha,
-            static_cast<complex128>(config.beta), stream, algorithm,
+        return DoGemm<complex64>(
+            batch_size, lhs, rhs, output, static_cast<complex64>(config.alpha),
+            static_cast<complex64>(config.beta), stream, algorithm,
             /*output_profile_result=*/profile_result);
+      case C128:
+        return DoGemm<complex128>(batch_size, lhs, rhs, output, config.alpha,
+                                  static_cast<complex128>(config.beta), stream,
+                                  algorithm,
+                                  /*output_profile_result=*/profile_result);
       default:
         return InternalError("Unexpected GEMM dtype: %s",
                              primitive_util::LowercasePrimitiveTypeName(
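
A standalone, BLAS-free check of the identity that MakeBlasGemmCompatible (called above; defined in matmul_utils.cc further down) relies on: C^T = (A @ B)^T = B^T @ A^T, so a column-major routine can produce a transposed output by swapping and transposing the inputs. The tiny reference matmul below is purely illustrative.

#include <array>
#include <cassert>

// Reference column-major matmul: C[M x N] = A[M x K] * B[K x N].
template <int M, int N, int K>
std::array<double, M * N> MatMul(const std::array<double, M * K>& a,
                                 const std::array<double, K * N>& b) {
  std::array<double, M * N> c{};
  for (int j = 0; j < N; ++j)
    for (int i = 0; i < M; ++i)
      for (int p = 0; p < K; ++p) c[j * M + i] += a[p * M + i] * b[j * K + p];
  return c;
}

int main() {
  // A = [[1,2,3],[4,5,6]] (2x3), B = [[7,8],[9,10],[11,12]] (3x2), col-major.
  std::array<double, 6> a = {1, 4, 2, 5, 3, 6};
  std::array<double, 6> b = {7, 9, 11, 8, 10, 12};
  auto c = MatMul<2, 2, 3>(a, b);  // C = A @ B, 2x2.

  // B^T (2x3) and A^T (3x2), column-major; their product is C^T (2x2).
  std::array<double, 6> bt = {7, 8, 9, 10, 11, 12};
  std::array<double, 6> at = {1, 2, 3, 4, 5, 6};
  auto ct = MatMul<2, 2, 3>(bt, at);

  // Element (i, j) of C matches element (j, i) of C^T.
  assert(c[0] == ct[0] && c[1] == ct[2] && c[2] == ct[1] && c[3] == ct[3]);
  return 0;
}
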
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index e74aae2..a7ffef2 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -22,7 +22,6 @@
 #include <iterator>
 #include <string>
 #include <utility>
-#include <vector>
 
 #include "absl/memory/memory.h"
 #include "absl/strings/numbers.h"
@@ -105,7 +104,6 @@
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
-#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/metrics.h"
 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
@@ -497,7 +495,7 @@
       pipeline.AddPass<HloConstantFolding>();
       pipeline.AddPass<ConditionalSimplifier>();
       pipeline.AddPass<RealImagExpander>();
-      pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot);
+      pipeline.AddPass<TransposeFolding>();
       pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
       pipeline.AddPass<HloDCE>();
     }();
@@ -716,8 +714,13 @@
   // GemmRewriter assumes that all transposes are folded into gemms, but,
   // since commit 7d529df, this is not always true at this point.
   // Therefore, rerun transpose folding.
-  pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot,
-                                     TransposeFolding::NeverFoldTranspose);
+  pipeline.AddPass<TransposeFolding>(
+      [](const HloInstruction& dot,
+         const TransposeFolding::OperandIndices& candidate_operands) {
+        return IsMatrixMultiplication(dot) ? candidate_operands
+                                           : TransposeFolding::OperandIndices{};
+      },
+      TransposeFolding::NeverFoldTranspose);
   // Rewrite GEMMs into custom calls.
   pipeline.AddPass<GemmRewriter>();
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index fd9dd7e..a522326 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -17,12 +17,9 @@
 
 #include <memory>
 
-#include "absl/algorithm/container.h"
-#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -240,56 +237,51 @@
     CHECK(!IsCublasGemm(*instruction))
         << "Gemm rewriting should run after layout assignment";
 
+    // For unbatched S8xS8->S32 matrix multiplication enforce a TN layout, which
+    // will allow the NVidia GPUs to use TensorCores.
     if (IsMatrixMultiplication(*instruction)) {
       Shape output_shape = instruction->shape();
-      const Shape& lhs_shape = instruction->operand(0)->shape();
-      const Shape& rhs_shape = instruction->operand(1)->shape();
-      const DotDimensionNumbers& dot_dims =
-          instruction->dot_dimension_numbers();
+      Shape p1_shape = instruction->operand(0)->shape();
+      Shape p2_shape = instruction->operand(1)->shape();
+      if (output_shape.element_type() == PrimitiveType::S32 &&
+          p1_shape.element_type() == PrimitiveType::S8 &&
+          p2_shape.element_type() == PrimitiveType::S8 &&
+          output_shape.dimensions_size() == 2 &&
+          p1_shape.dimensions_size() == 2 && p2_shape.dimensions_size() == 2) {
+        LayoutUtil::SetToDefaultLayout(&p1_shape);
+        SetFortranLayout(&p2_shape);
+        LayoutUtil::SetToDefaultLayout(&output_shape);
+        TF_RETURN_IF_ERROR(SetOperandLayout(p1_shape, instruction, 0));
+        TF_RETURN_IF_ERROR(SetOperandLayout(p2_shape, instruction, 1));
+        TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
+        continue;
+      }
+    }
 
-      // Matmuls require the batch dimensions to be in consecutive physical
-      // dimensions and likewise for the contracting and non-contracting
-      // dimensions. Additionally, no batch dimension can be in the most
-      // minor physical dimension for inputs or the output.
-      absl::Span<const int64_t> lhs_batch_dims =
-          dot_dims.lhs_batch_dimensions();
-      absl::Span<const int64_t> lhs_col_dims =
-          dot_dims.lhs_contracting_dimensions();
-      TF_ASSIGN_OR_RETURN(
-          std::vector<int64_t> lhs_row_dims,
-          GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_col_dims));
-
-      absl::Span<const int64_t> rhs_batch_dims =
-          dot_dims.rhs_batch_dimensions();
-      absl::Span<const int64_t> rhs_row_dims =
-          dot_dims.rhs_contracting_dimensions();
-      TF_ASSIGN_OR_RETURN(
-          std::vector<int64_t> rhs_col_dims,
-          GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_row_dims));
-
-      // For unbatched S8xS8->S32 matrix multiplication enforce a TN layout,
-      // which will allow the NVidia GPUs to use TensorCores.
-      bool is_s8_to_s32 = (output_shape.element_type() == PrimitiveType::S32 &&
-                           lhs_shape.element_type() == PrimitiveType::S8 &&
-                           rhs_shape.element_type() == PrimitiveType::S8 &&
-                           output_shape.dimensions_size() == 2 &&
-                           lhs_shape.dimensions_size() == 2 &&
-                           rhs_shape.dimensions_size() == 2);
-
-      if (is_s8_to_s32) {
-        TF_RETURN_IF_ERROR(SetOperandBatchRowsColsLayout(
-            instruction, 0, lhs_batch_dims, lhs_row_dims, lhs_col_dims));
-        TF_RETURN_IF_ERROR(SetOperandBatchRowsColsLayout(
-            instruction, 1, rhs_batch_dims, rhs_col_dims, rhs_row_dims));
-      } else {
-        TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 0, lhs_batch_dims,
-                                               lhs_row_dims, lhs_col_dims));
-        TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 1, rhs_batch_dims,
-                                               rhs_row_dims, rhs_col_dims));
+    // For batched dot we require the default layout.
+    // TODO(b/112111608): This is overly conservative, the only real restriction
+    // is that batch dimensions must be major.
+    if (IsMatrixMultiplication(*instruction) &&
+        instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
+      // Verify that the batch dims come before the row and col dims.
+      DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers();
+      CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
+               dim_nums.rhs_batch_dimensions_size());
+      CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
+               instruction->shape().rank());
+      for (int64_t batch_dim : dim_nums.lhs_batch_dimensions()) {
+        CHECK_LT(batch_dim, instruction->shape().rank() - 2);
       }
 
-      // Dot output is implicitly ordered (batch dims, row dims, col dims).
+      // Set both inputs and the output to default layout.
+      Shape op0_shape = instruction->operand(0)->shape();
+      LayoutUtil::SetToDefaultLayout(&op0_shape);
+      Shape op1_shape = instruction->operand(1)->shape();
+      LayoutUtil::SetToDefaultLayout(&op1_shape);
+      Shape output_shape = instruction->shape();
       LayoutUtil::SetToDefaultLayout(&output_shape);
+      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
+      TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1));
       TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
     } else if (instruction->opcode() == HloOpcode::kFft) {
       // cuFFT requires a dim0 major layout.
@@ -365,44 +357,5 @@
   return Status::OK();
 }
 
-Status GpuLayoutAssignment::SetDotOperandLayout(
-    HloInstruction* instruction, int64_t operand,
-    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
-    absl::Span<const int64_t> col_dims) {
-  Shape shape = instruction->operand(operand)->shape();
-
-  // First, try to use the existing layout, if present.
-  if (shape.has_layout() &&
-      MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
-    // Re-set the operand layout, so it becomes mandatory.
-    return SetOperandLayout(shape, instruction, operand);
-
-  // Next, try the default layout (for the sake of everybody's sanity).
-  LayoutUtil::SetToDefaultLayout(&shape);
-  if (MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
-    return SetOperandLayout(shape, instruction, operand);
-
-  // Otherwise, fallback to forcing a (batch, rows, cols) layout.
-  return SetOperandBatchRowsColsLayout(instruction, operand, batch_dims,
-                                       row_dims, col_dims);
-}
-
-Status GpuLayoutAssignment::SetOperandBatchRowsColsLayout(
-    HloInstruction* instruction, int64_t operand,
-    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
-    absl::Span<const int64_t> col_dims) {
-  std::vector<int64_t> major_to_minor;
-  major_to_minor.reserve(batch_dims.size() + row_dims.size() + col_dims.size());
-  major_to_minor.insert(major_to_minor.end(), batch_dims.begin(),
-                        batch_dims.end());
-  major_to_minor.insert(major_to_minor.end(), row_dims.begin(), row_dims.end());
-  major_to_minor.insert(major_to_minor.end(), col_dims.begin(), col_dims.end());
-
-  Shape shape = instruction->operand(operand)->shape();
-  *shape.mutable_layout() =
-      LayoutUtil::MakeLayoutFromMajorToMinor(major_to_minor);
-  return SetOperandLayout(shape, instruction, operand);
-}
-
 }  // namespace gpu
 }  // namespace xla
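
One note on the restored batched-dot handling above: SetToDefaultLayout gives a rank-r shape the minor-to-major order {r-1, ..., 1, 0}, so once the CHECKs have established that the batch dimensions are logical dims 0..r-3, those batch dimensions become the most major physical dimensions, which is the real requirement the TODO mentions. A standalone sketch:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> DefaultMinorToMajor(int64_t rank) {
  // XLA's default layout lists dimensions minor-to-major as {rank-1, ..., 0}.
  std::vector<int64_t> minor_to_major(rank);
  std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  return minor_to_major;
}

int main() {
  // Rank-4 dot output, e.g. f32[8,8,256,256]: batch dims are {0, 1}.
  auto m2m = DefaultMinorToMajor(4);
  assert((m2m == std::vector<int64_t>{3, 2, 1, 0}));
  // The two most major entries (the back of minor_to_major) are the batch
  // dimensions, so batch dims are physically major under the default layout.
  assert(m2m[3] == 0 && m2m[2] == 1);
  return 0;
}
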
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 882a2b5..39df595 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -45,17 +45,6 @@
   Status AddBackendConstraintsToDnnConvCustomCall(
       HloCustomCallInstruction* instr, LayoutConstraints* constraints);
 
-  Status SetOperandBatchRowsColsLayout(HloInstruction* instruction,
-                                       int64_t operand,
-                                       absl::Span<const int64_t> batch_dims,
-                                       absl::Span<const int64_t> row_dims,
-                                       absl::Span<const int64_t> col_dims);
-
-  Status SetDotOperandLayout(HloInstruction* instruction, int64_t operand,
-                             absl::Span<const int64_t> batch_dims,
-                             absl::Span<const int64_t> row_dims,
-                             absl::Span<const int64_t> col_dims);
-
   se::StreamExecutor* stream_executor_;
 };
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index a9916e4..5998949 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -30,7 +30,6 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/status_matchers.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
@@ -38,8 +37,6 @@
 namespace {
 
 namespace op = xla::testing::opcode_matchers;
-using ::tensorflow::testing::IsOkAndHolds;
-using ::testing::AllOf;
 
 using LayoutAssignmentTest = HloTestBase;
 
@@ -81,7 +78,7 @@
 
         GpuLayoutAssignment layout_assignment(
             &computation_layout, backend().default_stream_executor());
-        EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+        EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         for (const HloInstruction* operand : add->operands()) {
           EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
@@ -92,94 +89,33 @@
   }
 }
 
-TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[5,2,3]{1,2,0} parameter(0)
-    p1 = f32[5,3,4]{1,2,0} parameter(1)
-    ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
-      lhs_batch_dims={0}, lhs_contracting_dims={2},
-      rhs_batch_dims={0}, rhs_contracting_dims={1}
-  })";
+// Returns a list of shapes with all the possible layouts of this shape,
+// including a shape with no layout.
+std::vector<Shape> AllLayoutsOf(const Shape& s) {
+  std::vector<int64_t> layout_vec(s.dimensions_size());
+  std::iota(layout_vec.begin(), layout_vec.end(), 0);
 
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
+  std::vector<Shape> shapes;
+  shapes.push_back(s);
+  shapes.back().clear_layout();
 
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(&computation_layout,
-                                        backend().default_stream_executor());
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              AllOf(op::Dot(op::ShapeWithLayout("f32[5,2,3]{1,2,0}"),
-                            op::ShapeWithLayout("f32[5,3,4]{1,2,0}")),
-                    op::ShapeWithLayout("f32[5,2,4]{2,1,0}")));
+  do {
+    shapes.push_back(s);
+    *shapes.back().mutable_layout() = LayoutUtil::MakeLayout(layout_vec);
+  } while (std::next_permutation(layout_vec.begin(), layout_vec.end()));
+
+  return shapes;
 }
 
-TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) {
+TEST_F(LayoutAssignmentTest, DotLayout) {
   const char* hlo_text = R"(
   HloModule DotLayout
   ENTRY dot {
-    p0 = f32[5,3,2] parameter(0)
-    p1 = f32[5,4,3]{0,1,2} parameter(1)
-    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
-      lhs_batch_dims={0}, lhs_contracting_dims={1},
-      rhs_batch_dims={0}, rhs_contracting_dims={2}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(&computation_layout,
-                                        backend().default_stream_executor());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              AllOf(op::Dot(op::ShapeWithLayout("f32[5,3,2]{2,1,0}"),
-                            op::ShapeWithLayout("f32[5,4,3]{2,1,0}")),
-                    op::ShapeWithLayout("f32[5,2,4]{2,1,0}")));
-}
-
-TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[2,3,5]{2,1,0} parameter(0)
-    p1 = f32[3,4,5] parameter(1)
-    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
-      lhs_batch_dims={2}, lhs_contracting_dims={1},
-      rhs_batch_dims={2}, rhs_contracting_dims={0}
-  })";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(hlo_text));
-
-  ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape(),
-      /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(&computation_layout,
-                                        backend().default_stream_executor());
-
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::Dot(op::ShapeWithLayout("f32[2,3,5]{1,0,2}"),
-                      op::ShapeWithLayout("f32[3,4,5]{1,0,2}")));
-}
-
-TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) {
-  const char* hlo_text = R"(
-  HloModule DotLayout
-  ENTRY dot {
-    p0 = f32[5,6,2,3] parameter(0)
-    p1 = f32[6,5,3,4] parameter(1)
-    ROOT dot.1330.10585 = f32[5,6,2,4] dot(p0, p1),
+    p0 = f32[8,8,256,64]{3,1,2,0} parameter(0)
+    p1 = f32[8,8,256,64]{3,1,2,0} parameter(1)
+    ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1),
       lhs_batch_dims={0,1}, lhs_contracting_dims={3},
-      rhs_batch_dims={1,0}, rhs_contracting_dims={2}
+      rhs_batch_dims={0,1}, rhs_contracting_dims={3}
   })";
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -190,20 +126,22 @@
       /*ignore_layouts=*/false);
   GpuLayoutAssignment layout_assignment(&computation_layout,
                                         backend().default_stream_executor());
+  EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  Shape expected_shape =
+      ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::Dot(op::ShapeWithLayout("f32[5,6,2,3]{3,2,1,0}"),
-                      op::ShapeWithLayout("f32[6,5,3,4]{3,2,0,1}")));
+              op::Dot(op::ShapeWithLayout(expected_shape),
+                      op::ShapeWithLayout(expected_shape)));
 }
 
 TEST_F(LayoutAssignmentTest, DotLayoutS8) {
   const char* hlo_text = R"(
   HloModule DotLayout
   ENTRY int8_t {
-    p0 = s8[32,64] parameter(0)
-    p1 = s8[64,96] parameter(1)
-    ROOT out = s32[32,96] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    p0 = s8[1024,65536] parameter(0)
+    p1 = s8[65536,65536] parameter(1)
+    ROOT out = s32[1024,65536] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
   })";
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -214,11 +152,15 @@
       /*ignore_layouts=*/false);
   GpuLayoutAssignment layout_assignment(&computation_layout,
                                         backend().default_stream_executor());
+  EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  Shape expected_shape_p0 =
+      ShapeUtil::MakeShapeWithLayout(S8, {1024, 65536}, {1, 0});
+  Shape expected_shape_p1 =
+      ShapeUtil::MakeShapeWithLayout(S8, {65536, 65536}, {0, 1});
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::Dot(op::ShapeWithLayout("s8[32,64]{1,0}"),
-                      op::ShapeWithLayout("s8[64,96]{0,1}")));
+              op::Dot(op::ShapeWithLayout(expected_shape_p0),
+                      op::ShapeWithLayout(expected_shape_p1)));
 }
 
 TEST_F(LayoutAssignmentTest, SortLayout) {
@@ -249,11 +191,12 @@
       /*ignore_layouts=*/false);
   GpuLayoutAssignment layout_assignment(&computation_layout,
                                         backend().default_stream_executor());
+  EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  Shape expected_shape = ShapeUtil::MakeShapeWithLayout(F32, {3, 2}, {1, 0});
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::Sort(op::ShapeWithLayout("f32[3,2]{1,0}"),
-                       op::ShapeWithLayout("f32[3,2]{1,0}")));
+              op::Sort(op::ShapeWithLayout(expected_shape),
+                       op::ShapeWithLayout(expected_shape)));
 }
 
 TEST_F(LayoutAssignmentTest, FftLayout) {
@@ -274,12 +217,14 @@
       /*ignore_layouts=*/false);
   GpuLayoutAssignment layout_assignment(&computation_layout,
                                         backend().default_stream_executor());
+  EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
-  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+  Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0});
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::Copy(op::Transpose(
-                  AllOf(op::Fft(op::ShapeWithLayout("c64[8,32]{1,0}")),
-                        op::ShapeWithLayout("c64[8,32]{1,0}")))));
+              op::Copy(op::Transpose(op::ShapeWithLayout(expected_shape))));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Copy(op::Transpose(op::Fft(op::ShapeWithLayout(expected_shape)))));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc
index 36a58c3..c8661ba 100644
--- a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc
@@ -26,6 +26,7 @@
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h"
+#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -74,38 +75,29 @@
                                        absl::Span<const int64_t> row_dims,
                                        absl::Span<const int64_t> col_dims) {
   TF_RET_CHECK(shape.has_layout());
-  TF_RET_CHECK(!row_dims.empty());
-  TF_RET_CHECK(!col_dims.empty());
 
+  // Start by classifying each physical dimension as batch, row, or column.
+  // This is O(rank**2), but we expect rank to be small.
   std::vector<int64_t> minor_to_major;
-  for (size_t i = 0; i < shape.rank();) {
-    // The GeMM output always has its layout set such that the batch, row, and
-    // col dim groups are each laid out physically sequentially. GeMM operands
-    // must, therefore, be laid out similarly.
-    auto check_physically_sequential = [&](absl::Span<const int64_t> dims) {
-      for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
-        // NOTE: `i` is incremented as we check the dimensions.
-        if (*it != shape.layout().minor_to_major()[i++])
-          return InvalidArgument("dims not physically sequential");
-      }
-      return Status::OK();
-    };
-
-    int64_t dim = shape.layout().minor_to_major()[i];
-    if (dim == row_dims.back()) {
-      minor_to_major.push_back(1);
-      TF_RETURN_IF_ERROR(check_physically_sequential(row_dims));
-    } else if (dim == col_dims.back()) {
-      minor_to_major.push_back(2);
-      TF_RETURN_IF_ERROR(check_physically_sequential(col_dims));
-    } else if (!batch_dims.empty() && (dim == batch_dims.back())) {
-      minor_to_major.push_back(0);
-      TF_RETURN_IF_ERROR(check_physically_sequential(batch_dims));
-    } else {
-      return InvalidArgument("dims not physically sequential");
-    }
+  minor_to_major.reserve(shape.rank());
+  for (int64_t dim : shape.layout().minor_to_major()) {
+    size_t batch_matches = absl::c_count(batch_dims, dim);
+    size_t row_matches = absl::c_count(row_dims, dim);
+    size_t col_matches = absl::c_count(col_dims, dim);
+    size_t total_matches = batch_matches + row_matches + col_matches;
+    TF_RET_CHECK(total_matches == 1) << "dimensions incomplete or overlapping";
+    minor_to_major.push_back(batch_matches ? 0 : row_matches ? 1 : 2);
   }
 
+  // Remove repeated items (e.g. `[0, 0, 2, 1, 1, 1]` -> `[0, 2, 1]`).
+  minor_to_major.erase(
+      std::unique(minor_to_major.begin(), minor_to_major.end()),
+      minor_to_major.end());
+
+  // In order to "collapse" the shape to 3D, each of the batch, row, and column
+  // dims must be in consecutive physical dimensions.
+  TF_RET_CHECK(minor_to_major.size() == (batch_dims.empty() ? 2 : 3));
+
   if (batch_dims.empty()) minor_to_major.push_back(0);
 
   auto dim_size = [&](absl::Span<const int64_t> dims) {
@@ -172,42 +164,6 @@
   return MatrixLayout::For(batch_row_col_shape);
 }
 
-StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
-                                              int64_t operand_idx) {
-  TF_RET_CHECK(dot.opcode() == HloOpcode::kDot);
-  TF_RET_CHECK(dot.operand_count() > operand_idx);
-
-  const HloInstruction& transpose = *dot.operand(operand_idx);
-  TF_RET_CHECK(transpose.opcode() == HloOpcode::kTranspose);
-
-  const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers();
-
-  auto transposed = [&](const auto& dims) {
-    std::vector<int64_t> transposed_dims;
-    transposed_dims.reserve(dims.size());
-    for (int64_t dim : dims) {
-      transposed_dims.push_back(transpose.dimensions(dim));
-    }
-    return transposed_dims;
-  };
-
-  auto batch_dims = (operand_idx == 0) ? dot_dims.lhs_batch_dimensions()
-                                       : dot_dims.rhs_batch_dimensions();
-  auto contracting_dims = (operand_idx == 0)
-                              ? dot_dims.lhs_contracting_dimensions()
-                              : dot_dims.rhs_contracting_dimensions();
-  TF_ASSIGN_OR_RETURN(
-      std::vector<int64_t> non_contracting_dims,
-      GetNonContractingDims(transpose.shape(), batch_dims, contracting_dims));
-
-  // If we're able to construct a valid `MatrixLayout` for the transposed
-  // dimensions, then GeMM can support folding the transpose.
-  return MatrixLayout::For(transpose.operand(0)->shape(),
-                           transposed(batch_dims), transposed(contracting_dims),
-                           transposed(non_contracting_dims))
-      .ok();
-}
-
 namespace {
 
 bool IsBlasPlansCompatibleType(PrimitiveType type) {
@@ -365,24 +321,26 @@
 
 se::blas::MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout,
                                          se::DeviceMemoryBase data) {
+  // TODO(cjfj): Add support for batch not in most major physical dimension.
+  CHECK((layout.batch_stride == 0) ||
+        (layout.batch_stride == layout.num_rows * layout.num_cols));
   bool transpose = layout.order != MatrixLayout::Order::kColumnMajor;
   return {
       data,
-      layout.leading_dim_stride,
-      layout.batch_stride,
       transpose ? se::blas::Transpose::kTranspose
                 : se::blas::Transpose::kNoTranspose,
+      transpose ? layout.num_cols : layout.num_rows,
+      transpose ? layout.num_rows : layout.num_cols,
+      layout.batch_stride,
   };
 }
 
-void MakeBlasGemmCompatible(int64_t& m, int64_t& n,
-                            se::blas::MatrixDescriptor& lhs,
+void MakeBlasGemmCompatible(se::blas::MatrixDescriptor& lhs,
                             se::blas::MatrixDescriptor& rhs,
                             se::blas::MatrixDescriptor& output) {
   // BLAS GeMM doesn't support transposed output, but we can use the identity:
   // C^T = (A @ B)^T = B^T @ A^T.
   if (output.transpose == se::blas::Transpose::kTranspose) {
-    std::swap(m, n);
     std::swap(lhs, rhs);
     TransposeMatrixDesc(lhs);
     TransposeMatrixDesc(rhs);
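
A standalone walk-through of the classification logic restored in GetBatchRowColumnShape above, using the updated test case from matmul_utils_test.cc below (f32[3,4,5,6,7,8]{2,5,1,4,0,3} with batch={0,3}, rows={1,4}, cols={2,5}, which collapses to f32[18,28,40]):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> minor_to_major = {2, 5, 1, 4, 0, 3};
  std::vector<int64_t> batch_dims = {0, 3}, row_dims = {1, 4}, col_dims = {2, 5};

  auto contains = [](const std::vector<int64_t>& dims, int64_t dim) {
    return std::find(dims.begin(), dims.end(), dim) != dims.end();
  };

  // Classify each physical dimension as batch (0), row (1), or column (2).
  std::vector<int64_t> classes;
  for (int64_t dim : minor_to_major) {
    classes.push_back(contains(batch_dims, dim)
                          ? 0
                          : contains(row_dims, dim) ? 1 : 2);
  }
  assert((classes == std::vector<int64_t>{2, 2, 1, 1, 0, 0}));

  // Collapse consecutive duplicates; exactly three groups means the shape
  // collapses to a valid (batch, row, col) 3D shape, here f32[18,28,40]{2,1,0}.
  classes.erase(std::unique(classes.begin(), classes.end()), classes.end());
  assert((classes == std::vector<int64_t>{2, 1, 0}));
  return 0;
}
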
diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h
index b70045d..caf2d46 100644
--- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h
@@ -64,10 +64,6 @@
   int64_t batch_stride;  // `batch_stride` is set to `0` when `batch_size == 1`.
 };
 
-// GPU folding rule for the `TransposeFolding` pass.
-StatusOr<bool> CanFoldTransposeOperandIntoDot(const HloInstruction& dot,
-                                              int64_t operand_idx);
-
 struct GemmConfig {
   static StatusOr<GemmConfig> For(const HloInstruction* gemm);
   static StatusOr<GemmConfig> For(mlir::Operation* op, bool use_cublaslt);
@@ -92,8 +88,7 @@
 se::blas::MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout,
                                          se::DeviceMemoryBase data);
 
-void MakeBlasGemmCompatible(int64_t& m, int64_t& n,
-                            se::blas::MatrixDescriptor& lhs,
+void MakeBlasGemmCompatible(se::blas::MatrixDescriptor& lhs,
                             se::blas::MatrixDescriptor& rhs,
                             se::blas::MatrixDescriptor& output);
 
diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils_test.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils_test.cc
index cb7f1d9..8917be3 100644
--- a/tensorflow/compiler/xla/service/gpu/matmul_utils_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/matmul_utils_test.cc
@@ -65,14 +65,14 @@
         {"f32[3,4]{1,0}", {}, {1}, {0}, "f32[1,4,3]{1,2,0}"},
         {"f32[3,4,5]{2,1,0}", {0}, {1}, {2}, "f32[3,4,5]{2,1,0}"},
         {"f32[3,4,5]{2,1,0}", {2}, {1}, {0}, "f32[5,4,3]{0,1,2}"},
-        {"f32[3,4,5,6,7,8]{5,2,4,1,3,0}",
+        {"f32[3,4,5,6,7,8]{2,5,1,4,0,3}",
          {0, 3},
          {1, 4},
          {2, 5},
          "f32[18,28,40]{2,1,0}"},
     }));
 
-TEST(GetBatchRowColumnShapeTest, BatchRowsColsInterleaved) {
+TEST(GetBatchRowColumnShapeTest, InvalidPhysicalLayout) {
   Shape shape = ParseShape("f32[3,4,5,6,7,8]{5,4,3,2,1,0}").ValueOrDie();
   auto result =
       GetBatchRowColumnShape(shape, /*batch_dims=*/{0, 3},
@@ -80,13 +80,6 @@
   EXPECT_FALSE(result.ok());
 }
 
-TEST(GetBatchRowColumnShapeTest, WrongPhysicalOrder) {
-  Shape shape = ParseShape("f32[3,4,5,6]{3,2,0,1}").ValueOrDie();
-  auto result = GetBatchRowColumnShape(shape, /*batch_dims=*/{0, 1},
-                                       /*row_dims=*/{2}, /*col_dims=*/{3});
-  EXPECT_FALSE(result.ok());
-}
-
 using Order = MatrixLayout::Order;
 
 struct GetMatrixLayoutTestParams {
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index 21ed63d..f147855 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -196,55 +196,6 @@
       )");
 }
 
-TEST_F(GemmRewriteTest, BatchRowTransposeFoldCheck) {
-  const char* hlo_text = R"(
-HloModule BatchRowTransposeFoldCheck
-
-ENTRY AddDotsFunc {
-  x = f32[2,5,3] parameter(0)
-  y = f32[5,3,4] parameter(1)
-  x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2}
-  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0)
-; CHECK-NEXT:    [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"1\"],\"rhs_batch_dimensions\":[\"0\"]},\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
-      )");
-}
-
-TEST_F(GemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) {
-  const char* hlo_text = R"(
-HloModule BatchFromMinorDimTransposeDoesntFold
-
-ENTRY AddDotsFunc {
-  x = f32[3,2,5] parameter(0)
-  y = f32[5,3,4] parameter(1)
-  x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0}
-  ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] {
-; CHECK-NEXT:    [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0)
-; CHECK-DAG:     [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1)
-; CHECK-DAG:     [[COPY:%[^ ]+]] = f32[3,2,5]{0,1,2} copy([[P0]])
-; CHECK-NEXT:    ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[COPY]], [[P1]]), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"]},\"selected_algorithm\":\"{{-?[0-9]+}}\"}"
-      )");
-}
-
 TEST_F(GemmRewriteTest, InstrTransposeFoldCheck) {
   const char* hlo_text = R"(
 HloModule InstrTransposeFoldGemm
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index fdd4c99..8abf544 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -15,29 +15,73 @@
 
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 
-#include <algorithm>
-#include <utility>
 #include <vector>
 
-#include "absl/algorithm/container.h"
-#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_instructions.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/status.h"
 
 namespace xla {
+
 namespace {
 
+TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
+    const HloInstruction& dot,
+    const TransposeFolding::TransposableGemmOperandsFn&
+        transposable_gemm_operands) {
+  if (HloOpcode::kDot != dot.opcode()) {
+    return {};
+  }
+
+  if (!absl::c_equal(dot.dot_dimension_numbers().lhs_batch_dimensions(),
+                     dot.dot_dimension_numbers().rhs_batch_dimensions())) {
+    return {};
+  }
+
+  int64_t num_batch_dims =
+      dot.dot_dimension_numbers().lhs_batch_dimensions_size();
+  int64_t expected_rank = 2 + num_batch_dims;
+  auto is_r2_transpose = [&](const HloInstruction& transpose) {
+    if (transpose.opcode() != HloOpcode::kTranspose) {
+      return false;
+    }
+    const auto& transpose_dims = transpose.dimensions();
+    if (transpose_dims.size() != expected_rank) {
+      return false;
+    }
+
+    // Check that the transpose doesn't touch any batch dimensions, but does
+    // transpose the non-batch ones.
+    for (int64_t i = 0; i != expected_rank; ++i) {
+      bool is_batch = absl::c_linear_search(
+          dot.dot_dimension_numbers().lhs_batch_dimensions(),
+          transpose_dims[i]);
+      if ((transpose_dims[i] == i) != is_batch) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  TransposeFolding::OperandIndices operand_set;
+  for (int64_t i = 0; i < dot.operand_count(); ++i) {
+    auto& operand = *dot.operand(i);
+    if (is_r2_transpose(operand)) {
+      operand_set.push_back(i);
+    } else if (operand.shape().rank() != expected_rank) {
+      return {};
+    }
+  }
+
+  return transposable_gemm_operands(dot, operand_set);
+}
+
 TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
     const HloInstruction& convolution,
     const TransposeFolding::TransposableConvOperandsFn&
@@ -57,62 +101,51 @@
   return transposable_conv_operands(convolution, operand_set);
 }
 
-bool IsNonIdentityTranspose(const HloInstruction* instruction) {
-  if (instruction->opcode() == HloOpcode::kTranspose) {
-    for (int dim = 0; dim < instruction->dimensions().size(); ++dim) {
-      if (dim != instruction->dimensions(dim)) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-void TransposeDims(tensorflow::protobuf::RepeatedField<int64_t>& dims,
-                   absl::Span<const int64_t> transpose_dims) {
-  for (auto& dim : dims) {
-    dim = transpose_dims[dim];
-  }
-}
-
 using InstructionOperandsPair =
     std::pair<HloInstruction*, TransposeFolding::OperandIndices>;
 
-// Folds the operands of `dot` that are foldable transposes.
-Status FoldTransposeIntoDot(InstructionOperandsPair& pair) {
+// Folds the operands of `dot` that are foldable transposes, replacing `dot` in
+// its parent computation with an equivalent dot on the transposes' operands.
+Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
   HloInstruction* dot = pair.first;
 
-  DotDimensionNumbers new_dot_dims = dot->dot_dimension_numbers();
-  HloInstruction* lhs = dot->mutable_operand(0);
-  HloInstruction* rhs = dot->mutable_operand(1);
+  DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers();
+  HloInstruction* new_lhs = dot->mutable_operand(0);
+  HloInstruction* new_rhs = dot->mutable_operand(1);
+
+  CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1);
+  CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1);
 
   for (int64_t operand_index : pair.second) {
+    // We checked that the batch dimensions are not touched by the transpose,
+    // and shape inference guarantees that there is exactly one contracting
+    // dimension.
     if (operand_index == 0) {
-      TransposeDims(*new_dot_dims.mutable_lhs_contracting_dimensions(),
-                    lhs->dimensions());
-      TransposeDims(*new_dot_dims.mutable_lhs_batch_dimensions(),
-                    lhs->dimensions());
-      lhs = lhs->mutable_operand(0);
+      CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose);
+      new_dim_numbers.set_lhs_contracting_dimensions(
+          0,
+          new_lhs->dimensions(new_dim_numbers.lhs_contracting_dimensions(0)));
+      new_lhs = new_lhs->mutable_operand(0);
     } else {
       CHECK_EQ(operand_index, 1);
-      TransposeDims(*new_dot_dims.mutable_rhs_contracting_dimensions(),
-                    rhs->dimensions());
-      TransposeDims(*new_dot_dims.mutable_rhs_batch_dimensions(),
-                    rhs->dimensions());
-      rhs = rhs->mutable_operand(0);
+      CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose);
+      new_dim_numbers.set_rhs_contracting_dimensions(
+          0,
+          new_rhs->dimensions(new_dim_numbers.rhs_contracting_dimensions(0)));
+      new_rhs = new_rhs->mutable_operand(0);
     }
   }
 
-  return dot->parent()->ReplaceWithNewInstruction(
-      dot, HloInstruction::CreateDot(dot->shape(), lhs, rhs, new_dot_dims,
-                                     dot->precision_config()));
+  std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
+      dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
+  return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
 }
 
 // Folds the operands of `convolution` that are foldable transposes.
 // `computation` is the parent HLO computation of `convolution`.
 //
 // Returns whether the module is changed.
-bool FoldTransposeIntoConvolution(InstructionOperandsPair& pair) {
+bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
   auto& convolution = *pair.first;
   auto& operand_indices = pair.second;
 
@@ -183,51 +216,31 @@
 }  // namespace
 
 TransposeFolding::TransposeFolding(
-    CanFoldTransposeOperand dot_can_fold_transpose_operand,
+    TransposableGemmOperandsFn transposable_gemm_operands,
     TransposableConvOperandsFn transposable_conv_operands)
-    : dot_can_fold_transpose_operand_(
-          std::move(dot_can_fold_transpose_operand)),
+    : transposable_gemm_operands_(std::move(transposable_gemm_operands)),
       transposable_conv_operands_(std::move(transposable_conv_operands)) {}
 
 StatusOr<bool> TransposeFolding::Run(HloModule* module) {
   // Modifying the graph while traversing is dangerous, so we find all folding
   // opportunities before actually folding them.
-  std::vector<InstructionOperandsPair> foldable_dots;
-  std::vector<InstructionOperandsPair> foldable_convolutions;
-
+  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots;
+  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions;
   FunctionVisitor visit_fn([this, &foldable_dots, &foldable_convolutions](
                                HloInstruction* instruction) {
-    if (instruction->opcode() == HloOpcode::kDot) {
-      // Don't fold dots with a 1D operand.
-      if ((instruction->operand(0)->shape().rank() < 2) ||
-          (instruction->operand(1)->shape().rank() < 2)) {
-        return Status::OK();
-      }
-
-      OperandIndices operand_indices;
-      for (int64_t i = 0; i < 2; ++i) {
-        if (!IsNonIdentityTranspose(instruction->operand(i))) {
-          continue;
-        }
-
-        TF_ASSIGN_OR_RETURN(bool can_fold_operand,
-                            dot_can_fold_transpose_operand_(*instruction, i));
-
-        if (can_fold_operand) {
-          operand_indices.push_back(i);
-        }
-      }
-
+    {
+      OperandIndices operand_indices =
+          CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_);
       if (!operand_indices.empty()) {
         foldable_dots.emplace_back(instruction, operand_indices);
       }
     }
-
     {
       OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
           *instruction, transposable_conv_operands_);
       if (!operand_indices.empty()) {
-        foldable_convolutions.emplace_back(instruction, operand_indices);
+        foldable_convolutions.emplace_back(
+            std::make_pair(instruction, operand_indices));
       }
     }
     return Status::OK();
@@ -248,28 +261,5 @@
   return changed;
 }
 
-/*static*/ StatusOr<bool> TransposeFolding::IsRowColumnTransposeDotOperand(
-    const HloInstruction& dot, int64_t operand_idx) {
-  TF_RET_CHECK(dot.opcode() == HloOpcode::kDot);
-  TF_RET_CHECK(dot.operand_count() > operand_idx);
-
-  const HloInstruction& transpose = *dot.operand(operand_idx);
-  TF_RET_CHECK(transpose.opcode() == HloOpcode::kTranspose);
-
-  const DotDimensionNumbers& dot_dims = dot.dot_dimension_numbers();
-
-  auto batch_dims = (operand_idx == 0) ? dot_dims.lhs_batch_dimensions()
-                                       : dot_dims.rhs_batch_dimensions();
-
-  auto contracting_dims = (operand_idx == 0)
-                              ? dot_dims.lhs_contracting_dimensions()
-                              : dot_dims.rhs_contracting_dimensions();
-
-  return (batch_dims.size() == transpose.shape().rank() - 2) &&
-         (contracting_dims.size() == 1) &&
-         absl::c_all_of(batch_dims, [&](int64_t dim) {
-           return transpose.dimensions(dim) == dim;
-         });
-}
 
 }  // namespace xla
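
Note: the restored `CanFoldOperandsIntoDot` only folds a transpose that maps every
batch dimension to itself and permutes the non-batch dimensions, which matches the
removal of the batch-dimension folding tests below. A standalone C++ sketch (not
part of the patch) of that permutation test, with a hypothetical helper name:

  // Standalone sketch of the check in `is_r2_transpose`: batch dimensions must
  // be fixed points of the transpose; non-batch dimensions must not be.
  #include <algorithm>
  #include <cstdint>
  #include <cstdio>
  #include <vector>

  bool FixesBatchDimsAndSwapsRest(const std::vector<int64_t>& transpose_dims,
                                  const std::vector<int64_t>& batch_dims) {
    for (int64_t i = 0; i < static_cast<int64_t>(transpose_dims.size()); ++i) {
      const bool is_batch = std::find(batch_dims.begin(), batch_dims.end(),
                                      transpose_dims[i]) != batch_dims.end();
      // Batch dims must map to themselves; non-batch dims must move.
      if ((transpose_dims[i] == i) != is_batch) return false;
    }
    return true;
  }

  int main() {
    // Batch dim 0 fixed, dims 1 and 2 swapped: foldable.
    std::printf("%d\n", FixesBatchDimsAndSwapsRest({0, 2, 1}, {0}));  // 1
    // Batch dim moved to the minor position: rejected by this rule.
    std::printf("%d\n", FixesBatchDimsAndSwapsRest({2, 1, 0}, {0}));  // 0
    return 0;
  }
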
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 1a514a1..d94fcff 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -16,8 +16,6 @@
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
 
-#include <functional>
-
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
@@ -31,11 +29,10 @@
 
   // Returns the set of foldable operands for a given HLO and some candidate
   // operands.
-  using TransposableConvOperandsFn = std::function<OperandIndices(
-      const HloInstruction&, const OperandIndices&)>;
-
-  using CanFoldTransposeOperand = std::function<StatusOr<bool>(
-      const HloInstruction&, int64_t /*operand_idx*/)>;
+  using FoldableOperands = std::function<OperandIndices(const HloInstruction&,
+                                                        const OperandIndices&)>;
+  using TransposableGemmOperandsFn = FoldableOperands;
+  using TransposableConvOperandsFn = FoldableOperands;
 
   // Helper function to explicitly not fold transposes.
   static OperandIndices NeverFoldTranspose(const HloInstruction&,
@@ -49,26 +46,24 @@
     return ids;
   }
 
-  // `dot_can_fold_transpose_operand` returns whether the dot operation can fold
-  // in the given transpose operand.
+  // transposable_gemm_operands returns the set of operands it wants to fold if
+  // the instruction argument is implemented as a GEMM kernel that supports
+  // transposing its arguments.
   //
   // transposable_conv_operands returns the set of operands it wants to fold if
   // the instruction argument is implemented as a convolution that supports
   // transposing its arguments.
   explicit TransposeFolding(
-      CanFoldTransposeOperand dot_can_fold_transpose_operand =
-          IsRowColumnTransposeDotOperand,
+      TransposableGemmOperandsFn transposable_gemm_operands =
+          AlwaysFoldTranspose,
       TransposableConvOperandsFn transposable_conv_operands =
           AlwaysFoldTranspose);
   absl::string_view name() const override { return "transpose-folding"; }
 
   StatusOr<bool> Run(HloModule* module) override;
 
-  static StatusOr<bool> IsRowColumnTransposeDotOperand(
-      const HloInstruction& dot, int64_t operand_idx);
-
  private:
-  CanFoldTransposeOperand dot_can_fold_transpose_operand_;
+  TransposableGemmOperandsFn transposable_gemm_operands_;
   TransposableConvOperandsFn transposable_conv_operands_;
 };
 
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index c54aebe..0cd706c 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -65,7 +65,7 @@
                       /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
 }
 
-TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDimByDefault) {
+TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) {
   constexpr absl::string_view kHloString = R"(
 HloModule FoldDotTranspose
 
@@ -82,31 +82,6 @@
   EXPECT_THAT(TransposeFolding().Run(module.get()), IsOkAndHolds(false));
 }
 
-TEST_F(TransposeFoldingTest, FoldTransposeOfBatchWhenPermitted) {
-  constexpr absl::string_view kHloString = R"(
-HloModule FoldDotTranspose
-
-ENTRY entry_computation {
-  x = f32[5,2,3] parameter(0)
-  y = f32[3,5,4] parameter(1)
-  transpose = f32[5,3,4] transpose(y), dimensions={1,0,2}
-  ROOT dot = f32[5,2,4] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(kHloString));
-
-  TransposeFolding transpose_folding(
-      /*dot_can_fold_transpose_operand=*/[](const HloInstruction&, int64_t) {
-        return true;
-      });
-  EXPECT_THAT(transpose_folding.Run(module.get()), IsOkAndHolds(true));
-  EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::Dot(op::Parameter(0), op::Parameter(1),
-                      /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/0));
-}
-
 TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) {
   constexpr absl::string_view kHloString = R"(
 HloModule FoldDotTranspose
diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index 7dc59ae..05c2604 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -326,20 +326,14 @@
     // The cublasLt views the matrix as column major. Considering A*B=C is
     // equivalent to B.t*A.t=C.t (.t=transpose), we swap the A and B and view
     // them in the column major dimensions.
-    se::blas::MatrixDescriptor lhs_matrix = {
-        b_ptr,
-        /*leading_dim_stride=*/trans_b ? k : n,
-        /*batch_stride=*/k * n, trans[trans_b ? 1 : 0]};
-    se::blas::MatrixDescriptor rhs_matrix = {
-        a_ptr,
-        /*leading_dim_stride=*/trans_a ? m : k,
-        /*batch_stride=*/m * k, trans[trans_a ? 1 : 0]};
+    se::blas::MatrixDescriptor lhs_matrix = {b_ptr, trans[trans_b ? 1 : 0], n,
+                                             k, n * k};
+    se::blas::MatrixDescriptor rhs_matrix = {a_ptr, trans[trans_a ? 1 : 0], k,
+                                             m, k * m};
     se::blas::MatrixDescriptor output_matrix = {
-        c_ptr, /*leading_dim_stride=*/n, /*batch_stride=*/m * n,
-        se::blas::Transpose::kNoTranspose};
-    auto plan_and_algorithms_or =
-        se::GetPlanAndAlgorithms(stream, matmul_params, 1, n, m, k, dtype,
-                                 lhs_matrix, rhs_matrix, output_matrix);
+        c_ptr, se::blas::Transpose::kNoTranspose, n, m, n * m};
+    auto plan_and_algorithms_or = se::GetPlanAndAlgorithms(
+        stream, matmul_params, 1, dtype, lhs_matrix, rhs_matrix, output_matrix);
     OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
     const auto* plan_and_algorithms =
         plan_and_algorithms_or.ConsumeValueOrDie();
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index ecfdf2c..5827605 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -228,9 +228,14 @@
 // dimensions.
 struct MatrixDescriptor {
   DeviceMemoryBase data;
-  int64_t leading_dim_stride;
-  int64_t batch_stride;
   Transpose transpose;
+  int64_t num_rows;
+  int64_t num_cols;
+  int64_t stride;
+
+  int64_t reduced_dim() const {
+    return transpose == Transpose::kTranspose ? num_rows : num_cols;
+  }
 
   template <typename T>
   DeviceMemory<T> cast() const {
diff --git a/tensorflow/stream_executor/matmul_util.cc b/tensorflow/stream_executor/matmul_util.cc
index 9343b8d..b3e29e7 100644
--- a/tensorflow/stream_executor/matmul_util.cc
+++ b/tensorflow/stream_executor/matmul_util.cc
@@ -94,14 +94,44 @@
   }
 }
 
-namespace {
+port::StatusOr<const blas::PlanAndAlgorithms*> GetPlanAndAlgorithms(
+    Stream* stream, BatchMatmulParameters matmul_parameters, int64_t batch_size,
+    tensorflow::DataType dtype, blas::MatrixDescriptor lhs_matrix,
+    blas::MatrixDescriptor rhs_matrix, blas::MatrixDescriptor output_matrix) {
+  static const int64_t max_scratch_size =
+      GetWorkspaceLimit(1LL << 32);  // 4GB by default
+  static const int64_t max_autotune_algorithm_count =
+      MatmulMaxAutotuneAlgorithmCount();
+  const blas::PlanAndAlgorithms* plan_and_algorithms =
+      BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters);
+  if (!plan_and_algorithms) {
+    TF_ASSIGN_OR_RETURN(
+        blas::BlasLtMatmulPlanParams plan_params,
+        CreatePlanParams(batch_size, dtype, matmul_parameters.GetEpilogOp(),
+                         lhs_matrix, rhs_matrix, output_matrix));
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<blas::IBlasLtMatmulPlan> plan,
+                        stream->parent()->CreateBlasLtMatmulPlan(plan_params));
+    TF_ASSIGN_OR_RETURN(
+        std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> algorithms,
+        stream->parent()->GetBlasLtMatmulAlgorithms(
+            plan.get(), max_scratch_size,
+            /* max_algorithm_count */ max_autotune_algorithm_count));
+
+    plan_and_algorithms = BatchMatmulPlanMapSingleton::GetInstance()->Insert(
+        matmul_parameters, {std::move(plan), std::move(algorithms)});
+  }
+  return plan_and_algorithms;
+}
 
 port::StatusOr<blas::BlasLtMatmulPlanParams> CreatePlanParams(
-    int64_t batch_size, int64_t m, int64_t n, int64_t k,
-    tensorflow::DataType dtype, blas::Epilogue epilog_op,
+    int64_t batch_size, tensorflow::DataType dtype, blas::Epilogue epilog_op,
     blas::MatrixDescriptor lhs_matrix, blas::MatrixDescriptor rhs_matrix,
     blas::MatrixDescriptor output_matrix) {
   blas::BlasLtMatmulPlanParams plan_params;
+  int64_t m = output_matrix.num_rows;
+  int64_t n = output_matrix.num_cols;
+  int64_t k = lhs_matrix.reduced_dim();
+
   TF_ASSIGN_OR_RETURN(blas::DataType blas_dtype, GetBlasDataType(dtype));
   plan_params.ab_type = blas_dtype;
   plan_params.c_type = blas_dtype;
@@ -120,13 +150,17 @@
   plan_params.m = m;
   plan_params.n = n;
   plan_params.k = k;
-  plan_params.lda = lhs_matrix.leading_dim_stride;
-  plan_params.ldb = rhs_matrix.leading_dim_stride;
-  plan_params.ldc = output_matrix.leading_dim_stride;
+  plan_params.lda = lhs_matrix.num_rows;
+  plan_params.ldb = rhs_matrix.num_rows;
+  plan_params.ldc = output_matrix.num_rows;
   plan_params.batch_count = batch_size;
-  plan_params.stride_a = lhs_matrix.batch_stride;
-  plan_params.stride_b = rhs_matrix.batch_stride;
-  plan_params.stride_c = output_matrix.batch_stride;
+
+  bool broadcast = batch_size == 1;
+  int64_t lhs_stride = broadcast ? 0 : lhs_matrix.stride;
+  int64_t rhs_stride = broadcast ? 0 : rhs_matrix.stride;
+  plan_params.stride_a = lhs_stride;
+  plan_params.stride_b = rhs_stride;
+  plan_params.stride_c = output_matrix.stride;
 
   if (VLOG_IS_ON(4)) {
     bool trans_x = lhs_matrix.transpose == blas::Transpose::kTranspose;
@@ -148,37 +182,4 @@
   return plan_params;
 }
 
-}  // namespace
-
-port::StatusOr<const blas::PlanAndAlgorithms*> GetPlanAndAlgorithms(
-    Stream* stream, BatchMatmulParameters matmul_parameters, int64_t batch_size,
-    int64_t m, int64_t n, int64_t k, tensorflow::DataType dtype,
-    blas::MatrixDescriptor lhs_matrix, blas::MatrixDescriptor rhs_matrix,
-    blas::MatrixDescriptor output_matrix) {
-  static const int64_t max_scratch_size =
-      GetWorkspaceLimit(1LL << 32);  // 4GB by default
-  static const int64_t max_autotune_algorithm_count =
-      MatmulMaxAutotuneAlgorithmCount();
-  const blas::PlanAndAlgorithms* plan_and_algorithms =
-      BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters);
-  if (!plan_and_algorithms) {
-    TF_ASSIGN_OR_RETURN(
-        blas::BlasLtMatmulPlanParams plan_params,
-        CreatePlanParams(batch_size, m, n, k, dtype,
-                         matmul_parameters.GetEpilogOp(), lhs_matrix,
-                         rhs_matrix, output_matrix));
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<blas::IBlasLtMatmulPlan> plan,
-                        stream->parent()->CreateBlasLtMatmulPlan(plan_params));
-    TF_ASSIGN_OR_RETURN(
-        std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> algorithms,
-        stream->parent()->GetBlasLtMatmulAlgorithms(
-            plan.get(), max_scratch_size,
-            /* max_algorithm_count */ max_autotune_algorithm_count));
-
-    plan_and_algorithms = BatchMatmulPlanMapSingleton::GetInstance()->Insert(
-        matmul_parameters, {std::move(plan), std::move(algorithms)});
-  }
-  return plan_and_algorithms;
-}
-
 }  // namespace stream_executor
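
Note: with the explicit m/n/k arguments gone, `CreatePlanParams` derives the problem
sizes and leading dimensions from the three descriptors themselves. A standalone C++
sketch (not part of the patch) showing that derivation; `Desc` here is a simplified
stand-in for `se::blas::MatrixDescriptor`:

  // Standalone sketch of how the GEMM problem size is derived from descriptors.
  #include <cstdint>
  #include <cstdio>

  struct Desc {
    bool transpose;
    int64_t num_rows;  // Rows of the stored (column-major) matrix.
    int64_t num_cols;
    int64_t stride;    // Elements between consecutive matrices in a batch.
    int64_t reduced_dim() const { return transpose ? num_rows : num_cols; }
  };

  int main() {
    // Column-major GEMM computing a 64x128 output from 64x32 and 32x128 inputs.
    const Desc lhs = {false, 64, 32, 64 * 32};
    const Desc rhs = {false, 32, 128, 32 * 128};
    const Desc out = {false, 64, 128, 64 * 128};

    const int64_t m = out.num_rows;
    const int64_t n = out.num_cols;
    const int64_t k = lhs.reduced_dim();
    // Leading dimensions are the physical row counts; batch strides come from
    // the `stride` fields (zeroed for lhs/rhs when batch_size == 1).
    std::printf("m=%lld n=%lld k=%lld lda=%lld ldb=%lld ldc=%lld\n",
                static_cast<long long>(m), static_cast<long long>(n),
                static_cast<long long>(k), static_cast<long long>(lhs.num_rows),
                static_cast<long long>(rhs.num_rows),
                static_cast<long long>(out.num_rows));
    return 0;
  }
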
diff --git a/tensorflow/stream_executor/matmul_util.h b/tensorflow/stream_executor/matmul_util.h
index a3038aa..447b272 100644
--- a/tensorflow/stream_executor/matmul_util.h
+++ b/tensorflow/stream_executor/matmul_util.h
@@ -163,7 +163,11 @@
 
 port::StatusOr<const blas::PlanAndAlgorithms*> GetPlanAndAlgorithms(
     Stream* stream, BatchMatmulParameters matmul_parameters, int64_t batch_size,
-    int64_t m, int64_t n, int64_t k, tensorflow::DataType dtype,
+    tensorflow::DataType dtype, blas::MatrixDescriptor lhs_matrix,
+    blas::MatrixDescriptor rhs_matrix, blas::MatrixDescriptor output_matrix);
+
+port::StatusOr<blas::BlasLtMatmulPlanParams> CreatePlanParams(
+    int64_t batch_size, tensorflow::DataType dtype, blas::Epilogue epilog,
     blas::MatrixDescriptor lhs_matrix, blas::MatrixDescriptor rhs_matrix,
     blas::MatrixDescriptor output_matrix);