| /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/types/span.h" |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" |
| #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h" |
| #include "tensorflow/compiler/xla/layout_util.h" |
| #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" |
| #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/shape.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/stream_executor/blas.h" |
| |
| namespace xla { |
| namespace gpu { |
| namespace { |
| |
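| // Flips the transpose flag, turning the descriptor into one that describes
| // the transpose of the same stored matrix; the storage dimensions are left
| // untouched.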
| void TransposeMatrixDesc(se::blas::MatrixDescriptor& matrix_desc) { |
| matrix_desc.transpose = |
| (matrix_desc.transpose == se::blas::Transpose::kNoTranspose) |
| ? se::blas::Transpose::kTranspose |
| : se::blas::Transpose::kNoTranspose; |
| } |
| |
| } // namespace |
| |
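| // Returns the non-contracting, non-batch dimensions of `shape` in ascending
| // order, e.g. a rank-4 operand with batch_dims = {0} and
| // contracting_dims = {3} has non-contracting dimensions {1, 2}.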
| StatusOr<std::vector<int64_t>> GetNonContractingDims( |
| const Shape& shape, absl::Span<const int64_t> batch_dims, |
| absl::Span<const int64_t> contracting_dims) { |
| std::vector<int64_t> non_contracting_dims; |
| // This is O(rank**2), but we expect rank to be small. |
| for (int64_t dim = 0; dim < shape.rank(); ++dim) { |
| bool is_batch = absl::c_count(batch_dims, dim) != 0; |
| bool is_contracting = absl::c_count(contracting_dims, dim) != 0; |
| TF_RET_CHECK(!(is_batch && is_contracting)); |
| if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim); |
| } |
| |
| TF_RET_CHECK(batch_dims.size() + contracting_dims.size() + |
| non_contracting_dims.size() == |
| shape.rank()); |
| return non_contracting_dims; |
| } |
| |
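| // Collapses `shape` into an equivalent (batch, rows, columns) shape, e.g.
| // f32[2,3,4,5]{3,2,1,0} with batch_dims = {0}, row_dims = {1, 2} and
| // col_dims = {3} collapses to f32[2,12,5]{2,1,0}.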
| StatusOr<Shape> GetBatchRowColumnShape(const Shape& shape, |
| absl::Span<const int64_t> batch_dims, |
| absl::Span<const int64_t> row_dims, |
| absl::Span<const int64_t> col_dims) { |
| TF_RET_CHECK(shape.has_layout()); |
| |
| // Start by classifying each physical dimension as batch, row, or column. |
| // This is O(rank**2), but we expect rank to be small. |
| std::vector<int64_t> minor_to_major; |
| minor_to_major.reserve(shape.rank()); |
| for (int64_t dim : shape.layout().minor_to_major()) { |
| size_t batch_matches = absl::c_count(batch_dims, dim); |
| size_t row_matches = absl::c_count(row_dims, dim); |
| size_t col_matches = absl::c_count(col_dims, dim); |
| size_t total_matches = batch_matches + row_matches + col_matches; |
| TF_RET_CHECK(total_matches == 1) << "dimensions incomplete or overlapping"; |
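| // Map the dimension to its class: batch -> 0, row -> 1, column -> 2.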
| minor_to_major.push_back(batch_matches ? 0 : row_matches ? 1 : 2); |
| } |
| |
| // Remove repeated items (e.g. `[0, 0, 2, 1, 1, 1]` -> `[0, 2, 1]`). |
| minor_to_major.erase( |
| std::unique(minor_to_major.begin(), minor_to_major.end()), |
| minor_to_major.end()); |
| |
| // In order to "collapse" the shape to 3D, each of the batch, row, and column |
| // dims must be in consecutive physical dimensions. |
| TF_RET_CHECK(minor_to_major.size() == (batch_dims.empty() ? 2 : 3)); |
| |
| if (batch_dims.empty()) minor_to_major.push_back(0); |
| |
| auto dim_size = [&](absl::Span<const int64_t> dims) {
| // Accumulate in int64_t; an `int` accumulator could overflow the product
| // for very large shapes.
| return absl::c_accumulate(dims, int64_t{1}, [&](int64_t size, int64_t dim) {
| return size * shape.dimensions(dim);
| });
| };
| |
| return ShapeUtil::MakeShapeWithLayout( |
| shape.element_type(), |
| {dim_size(batch_dims), dim_size(row_dims), dim_size(col_dims)}, |
| minor_to_major); |
| } |
| |
| // Returns the matrix layout for a logical shape (batch, rows, columns). |
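| // For example, f32[7,5,4]{2,1,0} describes a batch of seven row-major 5x4
| // matrices, with a leading dimension stride of 4 and a batch stride of 20.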
| /*static*/ StatusOr<MatrixLayout> MatrixLayout::For(const Shape& shape) { |
| TF_RET_CHECK(shape.rank() == 3); |
| TF_RET_CHECK(shape.has_layout()); |
| |
| int64_t batch_size = shape.dimensions(0); |
| int64_t num_rows = shape.dimensions(1); |
| int64_t num_cols = shape.dimensions(2); |
| |
| MatrixLayout::Order order = MatrixLayout::Order::kRowMajor; |
| int64_t leading_dim_stride = num_cols; |
| int64_t batch_stride = num_rows * num_cols; |
| |
| // `MatrixLayout`, like BLAS, uses only two strides, so either the rows or
| // the columns must be contiguous in memory (i.e. one of them must be the
| // most minor physical dimension).
| absl::Span<const int64_t> minor_to_major = shape.layout().minor_to_major(); |
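| // Pack the physical ordering into a single octal key, whose digits read from
| // most major to most minor (0 = batch, 1 = row, 2 = column); e.g. 012 means
| // the batch is the most major dimension and the columns are the most minor.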
| switch (64 * minor_to_major[2] + 8 * minor_to_major[1] + minor_to_major[0]) { |
| case 012: // (B,R,C) (major-to-minor) |
| break; |
| case 021: // (B,C,R) |
| order = MatrixLayout::Order::kColumnMajor; |
| leading_dim_stride = num_rows; |
| break; |
| case 0102: // (R,B,C) |
| leading_dim_stride = batch_size * num_cols; |
| batch_stride = num_cols; |
| break; |
| case 0201: // (C,B,R) |
| order = MatrixLayout::Order::kColumnMajor; |
| leading_dim_stride = batch_size * num_rows; |
| batch_stride = num_rows; |
| break; |
| default: |
| return Unimplemented("batch in most minor dimension"); |
| } |
| |
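| // A single batch never uses its batch stride; zeroing it gives a canonical
| // value and also allows broadcasting the matrix against larger-batch
| // operands.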
| if (batch_size == 1) batch_stride = 0; |
| return MatrixLayout{ |
| shape.element_type(), num_rows, num_cols, order, |
| leading_dim_stride, batch_size, batch_stride, |
| }; |
| } |
| |
| /*static*/ StatusOr<MatrixLayout> MatrixLayout::For( |
| const Shape& shape, absl::Span<const int64_t> batch_dims, |
| absl::Span<const int64_t> row_dims, absl::Span<const int64_t> col_dims) { |
| TF_ASSIGN_OR_RETURN( |
| Shape batch_row_col_shape, |
| GetBatchRowColumnShape(shape, batch_dims, row_dims, col_dims)); |
| return MatrixLayout::For(batch_row_col_shape); |
| } |
| |
| namespace { |
| |
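| // Output types supported by the cuBLASLt ("BLAS plans") GEMM path; anything
| // else falls back to the regular cuBLAS path (see the `use_cublaslt`
| // handling in `GemmConfig::For` below).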
| bool IsBlasPlansCompatibleType(PrimitiveType type) { |
| switch (type) { |
| case F16: |
| case F32: |
| case F64: |
| case C64: |
| case C128: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| } // namespace |
| |
| /*static*/ StatusOr<GemmConfig> GemmConfig::For( |
| const Shape& lhs_shape, absl::Span<const int64_t> lhs_batch_dims, |
| absl::Span<const int64_t> lhs_contracting_dims, const Shape& rhs_shape, |
| absl::Span<const int64_t> rhs_batch_dims, |
| absl::Span<const int64_t> rhs_contracting_dims, const Shape& output_shape, |
| double alpha_real, double alpha_imag, double beta, |
| absl::optional<int64_t> algorithm, bool use_cublaslt) { |
| absl::Span<const int64_t> lhs_col_dims = lhs_contracting_dims; |
| TF_ASSIGN_OR_RETURN( |
| std::vector<int64_t> lhs_row_dims, |
| GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_col_dims)); |
| |
| TF_ASSIGN_OR_RETURN( |
| MatrixLayout lhs_layout, |
| MatrixLayout::For(lhs_shape, lhs_batch_dims, lhs_row_dims, lhs_col_dims)); |
| |
| absl::Span<const int64_t> rhs_row_dims = rhs_contracting_dims; |
| TF_ASSIGN_OR_RETURN( |
| std::vector<int64_t> rhs_col_dims, |
| GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_row_dims)); |
| |
| TF_ASSIGN_OR_RETURN( |
| MatrixLayout rhs_layout, |
| MatrixLayout::For(rhs_shape, rhs_batch_dims, rhs_row_dims, rhs_col_dims)); |
| |
| int64_t num_batch_dims = |
| std::max(lhs_batch_dims.size(), rhs_batch_dims.size()); |
| |
| TF_RET_CHECK(output_shape.rank() == |
| num_batch_dims + lhs_row_dims.size() + rhs_col_dims.size()); |
| |
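| // Dot outputs are logically ordered as (batch dims, lhs non-contracting
| // dims, rhs non-contracting dims).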
| std::vector<int64_t> output_dims(output_shape.rank()); |
| absl::c_iota(output_dims, 0); |
| |
| auto output_batch_dims = |
| absl::Span<const int64_t>(output_dims).first(num_batch_dims); |
| auto output_row_dims = absl::Span<const int64_t>(output_dims) |
| .subspan(num_batch_dims, lhs_row_dims.size()); |
| auto output_col_dims = |
| absl::Span<const int64_t>(output_dims).last(rhs_col_dims.size()); |
| |
| TF_ASSIGN_OR_RETURN(MatrixLayout output_layout, |
| MatrixLayout::For(output_shape, output_batch_dims, |
| output_row_dims, output_col_dims)); |
| |
| // TODO(cjfj): We should also check that the batch, contracting and |
| // non-contracting dimensions match in size and relative physical location. |
| TF_RET_CHECK(lhs_layout.num_cols == rhs_layout.num_rows); |
| TF_RET_CHECK(output_layout.num_rows == lhs_layout.num_rows); |
| TF_RET_CHECK(output_layout.num_cols == rhs_layout.num_cols); |
| TF_RET_CHECK((lhs_layout.batch_size == output_layout.batch_size) || |
| (lhs_layout.batch_size == 1)); |
| TF_RET_CHECK((rhs_layout.batch_size == output_layout.batch_size) || |
| (rhs_layout.batch_size == 1)); |
| |
| use_cublaslt &= IsBlasPlansCompatibleType(output_shape.element_type()); |
| |
| switch (output_shape.element_type()) { |
| case F16: |
| case BF16: |
| case F32: |
| case F64: |
| TF_RET_CHECK(alpha_imag == 0); |
| break; |
| case C64: |
| case C128: |
| break; |
| case S32: |
| TF_RET_CHECK(alpha_imag == 0); |
| if (lhs_layout.dtype != PrimitiveType::S8 || |
| rhs_layout.dtype != PrimitiveType::S8) { |
| return InternalError( |
| "For int32 gemm output only int8 input is supported, got input: " |
| "%s, %s", |
| primitive_util::LowercasePrimitiveTypeName(lhs_layout.dtype), |
| primitive_util::LowercasePrimitiveTypeName(rhs_layout.dtype)); |
| } |
| break; |
| default: |
| return InternalError("Unexpected GEMM datatype: %s", |
| primitive_util::LowercasePrimitiveTypeName( |
| output_shape.element_type())); |
| } |
| |
| return GemmConfig{ |
| lhs_layout, rhs_layout, output_layout, {alpha_real, alpha_imag}, |
| beta, algorithm, use_cublaslt, |
| }; |
| } |
| |
| /*static*/ StatusOr<GemmConfig> GemmConfig::For(const HloInstruction* gemm) { |
| TF_ASSIGN_OR_RETURN(GemmBackendConfig config, |
| gemm->backend_config<GemmBackendConfig>()); |
| |
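| // The backend config carries an algorithm only if one has already been
| // selected, e.g. by autotuning.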
| absl::optional<int64_t> algorithm; |
| if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) { |
| algorithm = config.selected_algorithm(); |
| } |
| |
| const Shape& lhs_shape = gemm->operand(0)->shape(); |
| const Shape& rhs_shape = gemm->operand(1)->shape(); |
| const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers(); |
| bool use_cublaslt = |
| gemm->GetModule()->config().debug_options().xla_gpu_enable_cublaslt(); |
| |
| return GemmConfig::For( |
| lhs_shape, dot_dims.lhs_batch_dimensions(), |
| dot_dims.lhs_contracting_dimensions(), rhs_shape, |
| dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(), |
| /*output_shape=*/gemm->shape(), config.alpha_real(), config.alpha_imag(), |
| config.beta(), algorithm, use_cublaslt); |
| } |
| |
| /*static*/ StatusOr<GemmConfig> GemmConfig::For(mlir::Operation* op, |
| bool use_cublaslt) { |
| auto get_config = [&](auto op, llvm::APFloat beta) { |
| mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.dot_dimension_numbers(); |
| |
| absl::optional<int64_t> algorithm; |
| if (op.algorithm()) algorithm = *op.algorithm(); |
| |
| return GemmConfig::For( |
| GetShape(op.lhs()), dot_dims.getLhsBatchingDimensions(), |
| dot_dims.getLhsContractingDimensions(), GetShape(op.rhs()), |
| dot_dims.getRhsBatchingDimensions(), |
| dot_dims.getRhsContractingDimensions(), GetShape(op.output()), |
| op.alpha_real().convertToDouble(), op.alpha_imag().convertToDouble(), |
| beta.convertToDouble(), algorithm, use_cublaslt); |
| }; |
| |
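| // A GEMM op without a bias is equivalent to one with beta == 0.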
| if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(op)) |
| return get_config(gemm, llvm::APFloat(0.)); |
| |
| auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(op); |
| TF_RET_CHECK(gemm != nullptr); |
| return get_config(gemm, gemm.beta()); |
| } |
| |
| se::blas::MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout, |
| se::DeviceMemoryBase data) { |
| // TODO(cjfj): Add support for batch not in most major physical dimension. |
| CHECK((layout.batch_stride == 0) || |
| (layout.batch_stride == layout.num_rows * layout.num_cols)); |
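| // BLAS assumes column-major storage, so a row-major matrix is presented as
| // the transpose of its column-major reinterpretation.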
| bool transpose = layout.order != MatrixLayout::Order::kColumnMajor; |
| return { |
| data, |
| transpose ? se::blas::Transpose::kTranspose |
| : se::blas::Transpose::kNoTranspose, |
| transpose ? layout.num_cols : layout.num_rows, |
| transpose ? layout.num_rows : layout.num_cols, |
| layout.batch_stride, |
| }; |
| } |
| |
| void MakeBlasGemmCompatible(se::blas::MatrixDescriptor& lhs, |
| se::blas::MatrixDescriptor& rhs, |
| se::blas::MatrixDescriptor& output) { |
| // BLAS GEMM doesn't support transposed output, but we can use the identity:
| // C^T = (A @ B)^T = B^T @ A^T.
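| // Swapping the operands and flipping all three transpose flags makes BLAS
| // compute C^T directly into the output buffer, which is exactly how a
| // transposed C is stored.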
| if (output.transpose == se::blas::Transpose::kTranspose) { |
| std::swap(lhs, rhs); |
| TransposeMatrixDesc(lhs); |
| TransposeMatrixDesc(rhs); |
| TransposeMatrixDesc(output); |
| } |
| } |
| |
| } // namespace gpu |
| } // namespace xla |