#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
// Returns whether `dot` is a matrix multiplication, in the form it has before
// the rewrite.
// This function should never return "true" on instructions after
// GemmRewriter pass has finished.
bool IsMatrixMultiplication(const HloInstruction& dot);
// Returns the warp size used by the declarations in this header (32).
inline constexpr int64_t WarpSize() {
  constexpr int64_t kWarpSize = 32;
  return kWarpSize;
}
// Minimum number of threads per block along X for a row reduction.
// At least 1024 threads/block are needed for reasonable tree reduction
// performance (assuming all data fits).
inline constexpr int64_t MinThreadsXRowReduction() {
  constexpr int64_t kMinThreadsX = 1024;
  return kMinThreadsX;
}
// Upper bound on the batch dimension when doing a batched row reduction.
inline constexpr int64_t BatchedReductionRaceFreeBound() {
  constexpr int64_t kMaxBatch = 8;
  return kMaxBatch;
}
// Returns true if `hlo` will be implemented as a call to a cuSolver routine.
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCusolver... constants, but returns *false* for HLOs with,
// say, a kCholesky opcode. In other words, the check is on the custom-call
// target, not on the HLO opcode.
bool IsCustomCallToCusolver(const HloInstruction& hlo);
// Call target for the Cholesky decomposition (declared here, defined
// elsewhere).
// Cholesky decomposition. Takes a (batched) matrix as input, and returns a
// tuple of (result, workspace, info), where result is the result of the
// Cholesky decomposition, workspace is scratch space for cuSolver, and info
// is a success/failure code per batch element.
extern const char* const kCusolverCholeskyCallTarget;
// Returns true if either the dimensions being reduced or the dimensions being
// kept are contiguous in the input of the reduce instruction.
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce);
// MLIR variant: the same query, with the reduction given as an
// mlir::Operation*.
bool IsReductionFromOrToContiguousDimensions(mlir::Operation* op);
// Returns whether `unnested_hlo` is an input fusion whose root is either a
// slice or a tuple of slices. If `verify_no_strides` is true, returns false
// unless all ROOT slices have no strides.
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
bool verify_no_strides);
// Describes a reduction as a row or column reduction over an input viewed as
// three contiguous components.
struct ReductionDimensions {
  // Indicates whether the reduction is a row reduction or a column reduction.
  bool is_row_reduction;

  // Contains the size of the three contiguous components for
  // the reduction [depth, height, width] (major-to-minor ordering).
  //
  // For row reduction, we do: [D, H, W] -> [D, H].
  // For column reduction, we do: [D, H, W] -> [D, W].
  std::array<int64_t, 3> dimensions;
};
// Given the input shape and dimensions to reduce for a reduction, returns
// ReductionDimensions.
//
// Prerequisite: the reduction instruction passes the check
// IsReductionFromOrToContiguousDimensions, which guarantees either the
// dimensions to reduce or the dimensions to keep are consecutive.
ReductionDimensions GetReductionKindAndContiguousComponents(
const HloInstruction& reduce);
// Overload taking the reduction as an mlir::Operation*.
// NOTE(review): presumably subject to the same prerequisite as the
// HloInstruction overload -- confirm against the definition.
ReductionDimensions GetReductionKindAndContiguousComponents(
mlir::Operation* reduce);
// Get tiling per thread for the given reduction in dimensions [D, H, W].
// Takes the reduction's ReductionDimensions and a CUDA compute capability;
// returns a three-element array.
// NOTE(review): the array is presumably ordered [D, H, W] like
// ReductionDimensions::dimensions -- confirm against the definition.
std::array<int64_t, 3> GetReductionTiling(
const ReductionDimensions& reduction_dimensions,
se::CudaComputeCapability cuda_compute_capability);
// Emits call to "vprintf" with given format and arguments.
// `fmt` is the format string, `arguments` are the values passed to it, and
// `builder` is the IRBuilder used for emission.
// NOTE(review): the returned llvm::Value* is presumably the result of the
// emitted call -- confirm against the definition.
llvm::Value* EmitPrintf(absl::string_view fmt,
absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder);
// Emits code to shuffle data between threads of a warp. This has the same
// semantics as the PTX "shfl.sync.down" instruction but works for values that
// aren't 32 bits in size. The last operand of the emitted "shfl" is
// `WarpSize() - 1`.
//
// This function emits a "full-warp" shuffle, which all threads of a warp
// participate in. *Do not use this function from a divergent context:* You
// can't correctly do so on both Volta and earlier GPUs.
//
// NOTE(review): `value` is presumably the data being shuffled and `offset`
// the shuffle distance (inferred from the parameter names) -- confirm against
// the definition.
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
// Emits code that determines whether the current thread is thread 0 within
// block 0 of the kernel. The code is emitted through `b`; the result of the
// check is returned as an llvm::Value*.
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);
// Returns whether `inst`, an output of a fusion containing a reduction, is
// consistent with `first_reduce`.
// NOTE(review): the exact consistency criteria are not visible in this
// header -- see the definition.
bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
const HloInstruction* first_reduce);
// Returns true iff every instruction in `output_instructions` is consistent
// with `first_reduce`, as judged by IsFusedReductionOutputConsistent.
// An empty span yields true.
inline bool AreFusedReductionOutputsConsistent(
    absl::Span<const HloInstruction* const> output_instructions,
    const HloInstruction* first_reduce) {
  // Capture only what the predicate needs instead of a blanket by-copy
  // capture.
  return absl::c_all_of(
      output_instructions, [first_reduce](const HloInstruction* inst) {
        return IsFusedReductionOutputConsistent(inst, first_reduce);
      });
}
// Returns the textual form of `op`, as written by mlir::Operation::print.
// `op` must not be null.
inline std::string MlirToString(mlir::Operation* op) {
  std::string s;
  {
    // Scope the stream so that it is destroyed (and therefore flushed into
    // `s`) before `s` is returned.
    llvm::raw_string_ostream os(s);
    op->print(os);
  }
  return s;
}
// Returns the textual form of `loc`, as written by mlir::Location::print.
inline std::string MlirToString(const mlir::Location& loc) {
  std::string s;
  {
    // Scope the stream so that it is destroyed (and therefore flushed into
    // `s`) before `s` is returned.
    llvm::raw_string_ostream os(s);
    loc.print(os);
  }
  return s;
}
// NOTE(review): undocumented in this header. Presumably returns the position
// in `op`'s operand list that separates inputs from outputs -- confirm
// against the definition.
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op);
// Returns the values of `op` that act as HLO operands.
// NOTE(review): inferred from the name -- confirm against the definition.
std::vector<mlir::Value> GetHloOperands(mlir::Operation* op);
// Returns the values of `op` that act as HLO outputs.
// NOTE(review): inferred from the name -- confirm against the definition.
std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op);
// NOTE(review): presumably returns whether `op` writes to the buffer behind
// `operand` -- confirm against the definition.
bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand);
// Copies the elements of `v`, in order, into a newly constructed
// std::vector<T>.
template <typename T>
std::vector<T> ToStdVector(const llvm::SmallVectorImpl<T>& v) {
  return std::vector<T>(v.begin(), v.end());
}
// Looks up the BufferAllocation::Slice for `v` among `allocations`; failures
// are reported through the StatusOr.
// `constant_name` is an optional out-parameter (defaults to nullptr).
// NOTE(review): presumably filled with the constant's name when `v` refers to
// a constant -- confirm against the definition.
StatusOr<BufferAllocation::Slice> GetAllocationSlice(
mlir::Value v, absl::Span<const BufferAllocation> allocations,
std::string* constant_name = nullptr);
// NOTE(review): undocumented in this header. Presumably returns whether the
// dynamic-update-slice in `fusion` can be emitted in place given
// `allocations` -- confirm against the definition.
bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
mlir::lmhlo::FusionOp fusion,
absl::Span<const BufferAllocation> allocations);
// Returns a Shape for the MLIR value `value`.
// NOTE(review): the conversion rules are not visible here -- see the
// definition.
Shape GetShape(mlir::Value value);
// Returns whether the given reduction can be safely generated without atomics:
// that is, at most one block will write to every output element.
// `reduction_dimensions` describes the reduction and `reduction_tiling` is a
// three-element tiling.
// NOTE(review): the tiling is presumably the one returned by
// GetReductionTiling -- confirm with callers.
bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions,
const std::array<int64_t, 3>& reduction_tiling);
// Description of how to emit a given transposition.
//
// On a group of input parameters that are 0-2-1 transpose of the outputs
// of a fusion kernel, stores the input parameters that are safe for the
// shared memory transpose implementation and the dimension permutation.
//
// When a tile based shared memory transpose is used to implement an input
// with 0-2-1 transpose, we preload a tile of the input elements [z, y..y+31,
// x..x+31] to compute the output tile elements of the same indices.
// Preloading the input tile this way is only safe when the computation of the
// output tile elements does not need any input element outside the preloaded
// tile. We inspect all the transitive users of the input parameter up to the
// fusion root instruction to see if we can find any instruction that can make
// preloading the input tile unsafe.
struct TransposeDimsAndParams {
  // Permutation of the dimensions relative to output.
  Vector3 dims;

  // Permuted inputs.
  std::vector<int64_t> params;

  // Returns a string of the form "{dims={...}, params={...}}" with the
  // elements of each field joined by ", ".
  std::string ToString() const {
    return absl::StrFormat("{dims={%s}, params={%s}}",
                           absl::StrJoin(dims, ", "),
                           absl::StrJoin(params, ", "));
  }
};
// Attempts to match 021 transpose on the given fusion and return a
// transposition description. The result is an std::optional.
// NOTE(review): presumably std::nullopt when nothing matches -- confirm
// against the definition.
//
// Precondition: input is a fused computation, with kCopy as a root.
std::optional<TransposeDimsAndParams> Match021Transpose(
const HloComputation* fused_computation);
// If one or multiple `operand_shapes` are the same 0-2-1 transpose of
// `output_shape` in 0-1-2, return the dimensions of the normalized shape.
// NOTE(review): presumably std::nullopt when no operand shape qualifies, and
// `params` presumably holds the qualifying operands -- confirm against the
// definition.
std::optional<TransposeDimsAndParams> FindTranspose021DimsAndParameters(
const absl::Span<Shape const>& operand_shapes, const Shape& output_shape);
// Returns whether fusing `producer` into `consumer` results in a fusion op
// that will be emitted as shared memory transpose.
bool FusionCanBeEmittedAsShmemTranspose(const HloInstruction& producer,
const HloInstruction& consumer);
} // namespace gpu
} // namespace xla