blob: c8661ba8812632e5ba42df294020abc50cf05b45 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/blas.h"
namespace xla {
namespace gpu {
namespace {
// Flips the transpose flag of `matrix_desc` in place: kNoTranspose becomes
// kTranspose, and any other value becomes kNoTranspose.
void TransposeMatrixDesc(se::blas::MatrixDescriptor& matrix_desc) {
  if (matrix_desc.transpose == se::blas::Transpose::kNoTranspose) {
    matrix_desc.transpose = se::blas::Transpose::kTranspose;
  } else {
    matrix_desc.transpose = se::blas::Transpose::kNoTranspose;
  }
}
} // namespace
// Returns, in increasing order, every dimension index of `shape` that is
// listed in neither `batch_dims` nor `contracting_dims`.
//
// Fails with a TF_RET_CHECK error if some dimension is listed as both batch
// and contracting, or if the sizes of the three groups do not add up to the
// rank of `shape`.
StatusOr<std::vector<int64_t>> GetNonContractingDims(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> contracting_dims) {
  // Linear membership test; the overall scan is O(rank**2), but we expect
  // rank to be small.
  const auto listed_in = [](absl::Span<const int64_t> dims, int64_t index) {
    return absl::c_count(dims, index) > 0;
  };

  std::vector<int64_t> non_contracting_dims;
  for (int64_t dim = 0; dim < shape.rank(); ++dim) {
    const bool is_batch = listed_in(batch_dims, dim);
    const bool is_contracting = listed_in(contracting_dims, dim);
    TF_RET_CHECK(!(is_batch && is_contracting));
    if (is_batch || is_contracting) continue;
    non_contracting_dims.push_back(dim);
  }

  TF_RET_CHECK(batch_dims.size() + contracting_dims.size() +
                   non_contracting_dims.size() ==
               shape.rank());
  return non_contracting_dims;
}
// Collapses `shape` into a rank-3 (batch, row, column) shape.
//
// `batch_dims`, `row_dims` and `col_dims` list the logical dimensions of
// `shape` that are folded into output dimensions 0, 1 and 2 respectively.
// Every dimension in the layout of `shape` must appear in exactly one of the
// three lists. Each output dimension's size is the product of the sizes of the
// input dimensions folded into it (1 for an empty list), and the output layout
// keeps the physical ordering of the three groups.
//
// Returns a TF_RET_CHECK error if `shape` has no layout, if a dimension is
// missing from or repeated across the lists, or if, after merging physically
// adjacent dimensions of the same group, the number of remaining groups is not
// 2 (when `batch_dims` is empty) or 3 (otherwise).
StatusOr<Shape> GetBatchRowColumnShape(const Shape& shape,
                                       absl::Span<const int64_t> batch_dims,
                                       absl::Span<const int64_t> row_dims,
                                       absl::Span<const int64_t> col_dims) {
  TF_RET_CHECK(shape.has_layout());

  // Start by classifying each physical dimension as batch (0), row (1), or
  // column (2). This is O(rank**2), but we expect rank to be small.
  std::vector<int64_t> minor_to_major;
  minor_to_major.reserve(shape.rank());
  for (int64_t dim : shape.layout().minor_to_major()) {
    size_t batch_matches = absl::c_count(batch_dims, dim);
    size_t row_matches = absl::c_count(row_dims, dim);
    size_t col_matches = absl::c_count(col_dims, dim);
    size_t total_matches = batch_matches + row_matches + col_matches;
    TF_RET_CHECK(total_matches == 1) << "dimensions incomplete or overlapping";
    minor_to_major.push_back(batch_matches ? 0 : row_matches ? 1 : 2);
  }

  // Remove repeated items (e.g. `[0, 0, 2, 1, 1, 1]` -> `[0, 2, 1]`).
  minor_to_major.erase(
      std::unique(minor_to_major.begin(), minor_to_major.end()),
      minor_to_major.end());

  // In order to "collapse" the shape to 3D, each of the batch, row, and column
  // dims must be in consecutive physical dimensions.
  TF_RET_CHECK(minor_to_major.size() == (batch_dims.empty() ? 2 : 3));

  // With no batch dims, the (size-1) batch dimension is placed most-major.
  if (batch_dims.empty()) minor_to_major.push_back(0);

  // Product of the sizes of `dims`. The accumulator is seeded with an explicit
  // int64_t: seeding with a plain `1` would make the accumulation run in `int`
  // and silently truncate products that exceed INT_MAX.
  auto dim_size = [&](absl::Span<const int64_t> dims) -> int64_t {
    return absl::c_accumulate(dims, int64_t{1},
                              [&](int64_t size, int64_t dim) {
                                return size * shape.dimensions(dim);
                              });
  };

  return ShapeUtil::MakeShapeWithLayout(
      shape.element_type(),
      {dim_size(batch_dims), dim_size(row_dims), dim_size(col_dims)},
      minor_to_major);
}
// Returns the matrix layout for a logical shape (batch, rows, columns).
//
// `shape` must be rank 3 and carry a layout. `MatrixLayout`, like BLAS, uses
// only two strides, so either the row or column must be contiguous in memory
// (i.e. the most minor physical dimension); a layout whose most minor physical
// dimension is the batch yields an Unimplemented error. A batch of size one is
// reported with a batch stride of 0.
/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(const Shape& shape) {
  TF_RET_CHECK(shape.rank() == 3);
  TF_RET_CHECK(shape.has_layout());

  const int64_t batch_size = shape.dimensions(0);
  const int64_t num_rows = shape.dimensions(1);
  const int64_t num_cols = shape.dimensions(2);

  // Pack the physical ordering into one number whose octal digits, read left
  // to right, are the logical dimensions from most major to most minor
  // (0 = batch, 1 = row, 2 = column).
  absl::Span<const int64_t> minor_to_major = shape.layout().minor_to_major();
  const int64_t major_to_minor_key =
      64 * minor_to_major[2] + 8 * minor_to_major[1] + minor_to_major[0];

  MatrixLayout::Order order = MatrixLayout::Order::kRowMajor;
  int64_t leading_dim_stride = 0;
  int64_t batch_stride = 0;
  if (major_to_minor_key == 012) {  // (B,R,C) (major-to-minor)
    leading_dim_stride = num_cols;
    batch_stride = num_rows * num_cols;
  } else if (major_to_minor_key == 021) {  // (B,C,R)
    order = MatrixLayout::Order::kColumnMajor;
    leading_dim_stride = num_rows;
    batch_stride = num_rows * num_cols;
  } else if (major_to_minor_key == 0102) {  // (R,B,C)
    leading_dim_stride = batch_size * num_cols;
    batch_stride = num_cols;
  } else if (major_to_minor_key == 0201) {  // (C,B,R)
    order = MatrixLayout::Order::kColumnMajor;
    leading_dim_stride = batch_size * num_rows;
    batch_stride = num_rows;
  } else {
    return Unimplemented("batch in most minor dimension");
  }

  // With a single matrix there is nothing to step over.
  if (batch_size == 1) {
    batch_stride = 0;
  }

  return MatrixLayout{
      shape.element_type(), num_rows,   num_cols,     order,
      leading_dim_stride,   batch_size, batch_stride,
  };
}
// Returns the matrix layout of `shape` once its dimensions have been grouped
// into batch, row and column dimensions.
//
// The shape is first collapsed to rank 3 with GetBatchRowColumnShape(); any
// error from that step, or from the rank-3 overload of MatrixLayout::For(), is
// returned to the caller.
/*static*/ StatusOr<MatrixLayout> MatrixLayout::For(
    const Shape& shape, absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> row_dims, absl::Span<const int64_t> col_dims) {
  TF_ASSIGN_OR_RETURN(
      Shape collapsed_shape,
      GetBatchRowColumnShape(shape, batch_dims, row_dims, col_dims));
  return MatrixLayout::For(collapsed_shape);
}
namespace {
// Returns true iff `type` is one of F16, F32, F64, C64 or C128.
bool IsBlasPlansCompatibleType(PrimitiveType type) {
  return type == F16 || type == F32 || type == F64 || type == C64 ||
         type == C128;
}
} // namespace
// Builds the GemmConfig for a dot of `lhs_shape` by `rhs_shape` producing
// `output_shape`.
//
// The LHS is laid out as (batch, non-contracting, contracting) and the RHS as
// (batch, contracting, non-contracting). The logical dimensions of the output
// are split, in order, into batch dims, then as many row dims as the LHS has
// non-contracting dims, then as many column dims as the RHS has
// non-contracting dims.
//
// Returns an error if a matrix layout cannot be derived, if the derived matrix
// or batch sizes are inconsistent, if `alpha_imag` is non-zero for an F16,
// BF16, F32, F64 or S32 output, if an S32 output has an input type other than
// S8, or if the output element type is none of F16, BF16, F32, F64, C64, C128
// and S32. `use_cublaslt` is cleared when the output element type does not
// pass IsBlasPlansCompatibleType().
/*static*/ StatusOr<GemmConfig> GemmConfig::For(
    const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims,
    absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape,
    absl::Span<const int64_t> rhs_batch_dims,
    absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape,
    double alpha_real, double alpha_imag, double beta,
    absl::optional<int64_t> algorithm, bool use_cublaslt) {
  // LHS: rows are the non-contracting dims, columns are the contracting dims.
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> lhs_row_dims,
      GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_contracting_dims));
  TF_ASSIGN_OR_RETURN(MatrixLayout lhs_layout,
                      MatrixLayout::For(lhs_shape, lhs_batch_dims,
                                        lhs_row_dims, lhs_contracting_dims));

  // RHS: rows are the contracting dims, columns are the non-contracting dims.
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> rhs_col_dims,
      GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_contracting_dims));
  TF_ASSIGN_OR_RETURN(MatrixLayout rhs_layout,
                      MatrixLayout::For(rhs_shape, rhs_batch_dims,
                                        rhs_contracting_dims, rhs_col_dims));

  int64_t num_batch_dims =
      std::max(lhs_batch_dims.size(), rhs_batch_dims.size());

  TF_RET_CHECK(output_shape.rank() ==
               num_batch_dims + lhs_row_dims.size() + rhs_col_dims.size());

  // Output: logical dims 0..rank-1, sliced into [batch | rows | columns].
  std::vector<int64_t> output_dims(output_shape.rank());
  absl::c_iota(output_dims, 0);
  const absl::Span<const int64_t> all_output_dims(output_dims);
  const auto output_batch_dims = all_output_dims.first(num_batch_dims);
  const auto output_row_dims =
      all_output_dims.subspan(num_batch_dims, lhs_row_dims.size());
  const auto output_col_dims = all_output_dims.last(rhs_col_dims.size());

  TF_ASSIGN_OR_RETURN(MatrixLayout output_layout,
                      MatrixLayout::For(output_shape, output_batch_dims,
                                        output_row_dims, output_col_dims));

  // TODO(cjfj): We should also check that the batch, contracting and
  // non-contracting dimensions match in size and relative physical location.
  TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows);
  TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows);
  TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols);
  TF_RET_CHECK((lhs_layout.batch_size == output_layout.batch_size) ||
               (lhs_layout.batch_size == 1));
  TF_RET_CHECK((rhs_layout.batch_size == output_layout.batch_size) ||
               (rhs_layout.batch_size == 1));

  const PrimitiveType output_type = output_shape.element_type();
  if (!IsBlasPlansCompatibleType(output_type)) {
    use_cublaslt = false;
  }

  const bool is_real_float = output_type == F16 || output_type == BF16 ||
                             output_type == F32 || output_type == F64;
  const bool is_complex = output_type == C64 || output_type == C128;

  if (is_real_float) {
    TF_RET_CHECK(alpha_imag == 0);
  } else if (output_type == S32) {
    TF_RET_CHECK(alpha_imag == 0);
    const bool inputs_are_s8 = lhs_layout.dtype == PrimitiveType::S8 &&
                               rhs_layout.dtype == PrimitiveType::S8;
    if (!inputs_are_s8) {
      return InternalError(
          "For int32 gemm output only int8 input is supported, got input: "
          "%s, %s",
          primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype),
          primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype));
    }
  } else if (!is_complex) {
    return InternalError(
        "Unexpected GEMM datatype: %s",
        primitive_util::LowercasePrimitiveTypeName(output_type));
  }

  return GemmConfig{
      lhs_layout, rhs_layout, output_layout, {alpha_real, alpha_imag},
      beta,       algorithm,  use_cublaslt,
  };
}
// Builds the GemmConfig for the HLO instruction `gemm`.
//
// Scalars, dot dimension numbers and the optional algorithm come from the
// instruction's GemmBackendConfig; the operand shapes come from operands 0 and
// 1, the output shape from the instruction itself, and the cuBLASLt switch
// from the module's `xla_gpu_enable_cublaslt` debug option. The algorithm is
// left unset when the backend config has no algorithm case set.
/*static*/ StatusOr<GemmConfig> GemmConfig::For(const HloInstruction* gemm) {
  TF_ASSIGN_OR_RETURN(GemmBackendConfig config,
                      gemm->backend_config<GemmBackendConfig>());

  absl::optional<int64_t> algorithm;
  const bool has_algorithm =
      config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET;
  if (has_algorithm) {
    algorithm = config.selected_algorithm();
  }

  const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers();
  const bool use_cublaslt =
      gemm->GetModule()->config().debug_options().xla_gpu_enable_cublaslt();

  return GemmConfig::For(
      /*lhs_shape=*/gemm->operand(0)->shape(), dot_dims.lhs_batch_dimensions(),
      dot_dims.lhs_contracting_dimensions(),
      /*rhs_shape=*/gemm->operand(1)->shape(), dot_dims.rhs_batch_dimensions(),
      dot_dims.rhs_contracting_dimensions(),
      /*output_shape=*/gemm->shape(), config.alpha_real(), config.alpha_imag(),
      config.beta(), algorithm, use_cublaslt);
}
// Builds the GemmConfig for the LMHLO GPU operation `op`.
//
// `op` must be either a GEMMOp, for which beta is taken to be 0, or a
// GEMM_BiasOp, for which beta is read from the op. Any other operation fails
// the TF_RET_CHECK below.
/*static*/ StatusOr<GemmConfig> GemmConfig::For(mlir::Operation* op,
                                                bool use_cublaslt) {
  // Shared by both op kinds; they differ only in where beta comes from.
  auto get_config = [use_cublaslt](auto gemm_op, llvm::APFloat beta) {
    mlir::mhlo::DotDimensionNumbersAttr dot_dims =
        gemm_op.dot_dimension_numbers();

    absl::optional<int64_t> algorithm;
    if (auto selected = gemm_op.algorithm()) {
      algorithm = *selected;
    }

    return GemmConfig::For(
        GetShape(gemm_op.lhs()), dot_dims.getLhsBatchingDimensions(),
        dot_dims.getLhsContractingDimensions(), GetShape(gemm_op.rhs()),
        dot_dims.getRhsBatchingDimensions(),
        dot_dims.getRhsContractingDimensions(), GetShape(gemm_op.output()),
        gemm_op.alpha_real().convertToDouble(),
        gemm_op.alpha_imag().convertToDouble(), beta.convertToDouble(),
        algorithm, use_cublaslt);
  };

  if (auto plain_gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(op)) {
    return get_config(plain_gemm, llvm::APFloat(0.));
  }

  auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(op);
  TF_RET_CHECK(gemm != nullptr);
  return get_config(gemm, gemm.beta());
}
// Builds the BLAS matrix descriptor for the buffer `data` laid out as
// `layout`.
//
// A column-major layout is passed through untransposed with its (num_rows,
// num_cols) as-is. Any other order is marked as transposed, with the two sizes
// swapped to (num_cols, num_rows). CHECK-fails unless the batch stride is
// either 0 or num_rows * num_cols.
se::blas::MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout,
                                         se::DeviceMemoryBase data) {
  // TODO(cjfj): Add support for batch not in most major physical dimension.
  CHECK((layout.batch_stride == 0) ||
        (layout.batch_stride == layout.num_rows * layout.num_cols));

  if (layout.order == MatrixLayout::Order::kColumnMajor) {
    return {
        data,
        se::blas::Transpose::kNoTranspose,
        layout.num_rows,
        layout.num_cols,
        layout.batch_stride,
    };
  }

  return {
      data,
      se::blas::Transpose::kTranspose,
      layout.num_cols,
      layout.num_rows,
      layout.batch_stride,
  };
}
// Rewrites the three descriptors in place so that `output` is not marked as
// transposed.
//
// BLAS GeMM doesn't support transposed output, but we can use the identity:
//   C^T = (A @ B)^T = B^T @ A^T.
// So when `output` is transposed, the operands are exchanged and the transpose
// flag of all three descriptors is flipped. Otherwise nothing is changed.
void MakeBlasGemmCompatible(se::blas::MatrixDescriptor& lhs,
                            se::blas::MatrixDescriptor& rhs,
                            se::blas::MatrixDescriptor& output) {
  if (output.transpose != se::blas::Transpose::kTranspose) {
    return;
  }

  std::swap(lhs, rhs);
  TransposeMatrixDesc(lhs);
  TransposeMatrixDesc(rhs);
  TransposeMatrixDesc(output);
}
} // namespace gpu
} // namespace xla