// blob: dc4a9833ab77f0424f519a5c4036f95c528d6df0 [file] [log] [blame]
// Copyright 2022 The TensorFlow Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
#include <cstdint>
#include <utility>
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tfrt/jitrt/custom_call.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/dtype/dtype.h" // from @tf_runtime
// NOTE(review): the two lines below are not complete statements on their own;
// the beginning of each statement is not visible in this copy. They presumably
// register `const xla::ServiceExecutableRunOptions` and `const xla::DebugOptions`
// (the types later decoded as custom call user data) with the custom call
// type-id machinery. Confirm against the upstream file.
const xla::ServiceExecutableRunOptions);
const xla::DebugOptions);
namespace xla {
namespace gpu {
using llvm::ArrayRef;
using llvm::orc::MangleAndInterner;
using llvm::orc::SymbolMap;
using mlir::failure;
using mlir::LogicalResult;
using mlir::StringRef;
using mlir::succeeded;
using mlir::success;
using tfrt::jitrt::CustomCall;
using tfrt::jitrt::Executable;
namespace se = ::stream_executor;
namespace jitrt = ::tfrt::jitrt;
namespace runtime = ::tfrt::jitrt::runtime;
// Selects how much argument/attribute checking CustomCall handlers perform:
// all checks are disabled in optimized (NDEBUG) builds, and the default checks
// are kept in debug builds.
static constexpr CustomCall::RuntimeChecks RuntimeChecks() {
#if defined(NDEBUG)
  return CustomCall::RuntimeChecks::kNone;
#else
  return CustomCall::RuntimeChecks::kDefault;
#endif
}
// -------------------------------------------------------------------------- //
// Returns the kernel cached for the (executor, data) pair, or nullptr when no
// kernel has been stored for that pair yet. The lookup runs under `mutex_`.
se::KernelBase* JitRtKernelsCache::Get(se::StreamExecutor* executor,
                                       const char* data) {
  Key key = {executor, data};

  absl::MutexLock lock(&mutex_);
  auto it = kernels_cache_.find(key);
  if (it != kernels_cache_.end()) return it->second.get();

  return nullptr;
}
// Caches `kernel` for the (executor, data) pair unless a kernel is already
// stored for it, and returns whichever kernel ends up cached. When an entry
// already exists, the passed-in `kernel` is not stored and is released when
// this function returns.
se::KernelBase* JitRtKernelsCache::Set(se::StreamExecutor* executor,
                                       const char* data,
                                       std::unique_ptr<se::KernelBase> kernel) {
  Key key = {executor, data};

  absl::MutexLock lock(&mutex_);
  // try_emplace does not move from `kernel` when the key is already present,
  // so a single lookup covers both the "already cached" and "insert" paths.
  auto emplaced = kernels_cache_.try_emplace(key, std::move(kernel));
  return emplaced.first->second.get();
}
// Wraps the buffer described by `memref` into a device memory handle. The
// size in bytes is tfrt::GetHostSize(memref.dtype) multiplied by every entry
// of `memref.sizes`.
static se::DeviceMemoryBase GetDeviceAddress(const jitrt::MemrefView& memref) {
  uint64_t size = tfrt::GetHostSize(memref.dtype);
  for (auto dim : memref.sizes) size *= dim;
  return se::DeviceMemoryBase(memref.data, size);
}
// Wraps the buffer described by a flat memref into a device memory handle
// spanning `memref.size_in_bytes` bytes.
static se::DeviceMemoryBase GetDeviceAddress(
    const jitrt::FlatMemrefView& memref) {
  return se::DeviceMemoryBase(memref.data, memref.size_in_bytes);
}
// -------------------------------------------------------------------------- //
// Returns the GemmConfig cached for operation `uid`, or nullptr if none has
// been stored yet. The lookup runs under `mutex_`.
const GemmConfig* JitRtGemmConfigCache::Get(int64_t uid) {
  absl::MutexLock lock(&mutex_);
  auto it = configs_.find(uid);
  if (it != configs_.end()) return &it->second;

  return nullptr;
}
// Caches `config` for operation `uid` unless a config is already stored for
// it, and returns a pointer to the cached config.
// NOTE(review): the returned pointer is used after `mutex_` is released. That
// is only safe if `configs_` keeps element addresses stable across inserts
// (e.g. a node-based map). Confirm the container type in the header.
const GemmConfig* JitRtGemmConfigCache::Set(int64_t uid, GemmConfig config) {
  absl::MutexLock lock(&mutex_);
  // try_emplace leaves `config` untouched if `uid` is already present, so one
  // lookup handles both the hit and the insert case.
  auto emplaced = configs_.try_emplace(uid, std::move(config));
  return &emplaced.first->second;
}
// -------------------------------------------------------------------------- //
// Maps a TFRT element type to the corresponding XLA primitive type. Only F32
// and F64 are handled; any other dtype is a fatal error.
static PrimitiveType ToPrimitiveType(tfrt::DType dtype) {
  switch (dtype) {
    case tfrt::DType::F32:
      return PrimitiveType::F32;
    case tfrt::DType::F64:
      return PrimitiveType::F64;
    default:
      LOG(FATAL) << "Unsupported data type: " << dtype;
  }
}
// Builds an XLA shape from the memref's element type and dimension sizes.
// Only `dtype` and `sizes` are read.
// NOTE(review): memref strides/layout are not reflected in the returned
// shape. Confirm callers only pass buffers in the layout ShapeUtil::MakeShape
// assumes.
static Shape ToShape(const jitrt::MemrefView& memref) {
  PrimitiveType type = ToPrimitiveType(memref.dtype);
  return ShapeUtil::MakeShape(type, memref.sizes);
}
// Builds a GemmConfig from the shapes of `lhs`, `rhs` and `out` plus the dot
// dimension attributes. `beta` falls back to 0.0 when not supplied: Gemm calls
// this without `beta`, while GemmBias forwards its `beta` attribute.
// The parameters take `alpha_imag` before `alpha_real`, but the call below
// passes `alpha_real` first. Keep that ordering in mind when editing callers.
// NOTE(review): `debug_options` is not referenced in the visible lines, and
// the GemmConfig::For call is cut off after `algorithm,` (no closing
// parenthesis or function brace in this copy). The missing argument is
// presumably derived from `debug_options`; confirm against upstream.
static StatusOr<GemmConfig> GetGemmConfig(
const DebugOptions* debug_options, const jitrt::MemrefView& lhs,
const jitrt::MemrefView& rhs, const jitrt::MemrefView& out,
int64_t algorithm, double alpha_imag, double alpha_real,
ArrayRef<int64_t> lhs_batch, ArrayRef<int64_t> lhs_contract,
ArrayRef<int64_t> rhs_batch, ArrayRef<int64_t> rhs_contract,
llvm::Optional<double> beta = llvm::None) {
return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs),
rhs_batch, rhs_contract, ToShape(out), alpha_real,
alpha_imag, beta.getValueOr(0.0), algorithm,
// -------------------------------------------------------------------------- //
namespace {
// Handler for the `xla.gpu.func.launch` custom call: launches the kernel
// `name` built from `ptx` with the given grid/block dimensions, passing the
// remaining custom call arguments to it as device buffers.
struct LaunchFunc {
  LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
                           JitRtKernelsCache* kernels_cache,
                           int32_t grid_size_x, int32_t grid_size_y,
                           int32_t grid_size_z, int32_t block_size_x,
                           int32_t block_size_y, int32_t block_size_z,
                           CustomCall::RemainingArgs args, StringRef ptx,
                           StringRef name) const;

  static LaunchFunc Handler() { return LaunchFunc(); }
};
}  // namespace
// Launches kernel `name` compiled from `ptx` on the run's main stream.
// Kernels are looked up in `kernels_cache` by (executor, PTX data pointer) and
// created on first use. Every remaining custom call argument must decode as a
// FlatMemrefView and is passed to the kernel as a device buffer. Returns
// failure if kernel creation, argument decoding, or the launch fails.
LogicalResult LaunchFunc::operator()(
    const ServiceExecutableRunOptions* run_options,
    JitRtKernelsCache* kernels_cache, int32_t grid_size_x, int32_t grid_size_y,
    int32_t grid_size_z, int32_t block_size_x, int32_t block_size_y,
    int32_t block_size_z, CustomCall::RemainingArgs args, StringRef ptx,
    StringRef name) const {
  se::Stream* stream = run_options->stream();
  se::StreamExecutor* executor = stream->parent();

  LaunchDimensions launch_dimensions(
      {grid_size_x, grid_size_y, grid_size_z},
      {block_size_x, block_size_y, block_size_z});

  se::KernelBase* kernel = kernels_cache->Get(executor, ptx.data());

  // If the kernel does not exist yet, create it from the PTX.
  if (kernel == nullptr) {
    // `ptx` is passed with its explicit length, the same way `name` is.
    auto created = CreateKernel(name, args.size(), ptx, {}, executor);
    if (!created.ok()) return failure();

    kernel = kernels_cache->Set(executor, ptx.data(), std::move(*created));
  }

  VLOG(3) << "Launching " << kernel->name();

  // Add MemRef arguments as buffer arguments.
  absl::InlinedVector<se::DeviceMemoryBase, 4> buffer_args;
  buffer_args.reserve(args.size());
  for (unsigned i = 0; i < args.size(); ++i) {
    auto memref = args.get<jitrt::FlatMemrefView>(i);
    if (failed(memref)) return failure();
    buffer_args.emplace_back(GetDeviceAddress(*memref));
  }

  // Execute device kernel on a main stream.
  auto executed =
      ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, stream);
  if (!executed.ok()) return failure();

  return success();
}
// Entry point exported under `xla.gpu.func.launch` by
// JitRtCustomCallsSymbolMap. On first call it builds a CustomCall binding
// (held in a function-local static), then decodes `args`/`attrs` with the
// executable's user data. Returns true iff the handler call succeeds.
// NOTE(review): this binding chain is incomplete as shown. LaunchFunc's
// operator() also expects a JitRtKernelsCache* user data and `ptx`/`name`
// string attributes. The chain is never finished with a `.To<...>()` call or
// a terminating `;`, and the function's closing brace is absent. Lines appear
// to be missing from this copy; confirm against the upstream file.
static bool LaunchFunc(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.func.launch")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<int32_t>() // grid_size_x
.Arg<int32_t>() // grid_size_y
.Arg<int32_t>() // grid_size_z
.Arg<int32_t>() // block_size_x
.Arg<int32_t>() // block_size_y
.Arg<int32_t>() // block_size_z
.RemainingArgs() // args
return succeeded(handler->call(args, attrs, Executable::GetUserData(ctx)));
// -------------------------------------------------------------------------- //
namespace {
// Handler for the `xla.gpu.gemm` custom call: runs a GEMM of `lhs` and `rhs`
// into `out`. `uid` identifies the operation instance whose GemmConfig is
// cached in `configs`.
struct Gemm {
  LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
                           const DebugOptions* debug_options,
                           JitRtGemmConfigCache* configs, jitrt::MemrefView lhs,
                           jitrt::MemrefView rhs, jitrt::MemrefView out,
                           int64_t algorithm, double alpha_imag,
                           double alpha_real, ArrayRef<int64_t> lhs_batch,
                           ArrayRef<int64_t> lhs_contract,
                           ArrayRef<int64_t> rhs_batch,
                           ArrayRef<int64_t> rhs_contract, int64_t uid) const;

  static Gemm Handler() { return Gemm(); }
};
}  // namespace
// Runs `out = GEMM(lhs, rhs)` on the run's stream. The GemmConfig for `uid`
// is taken from `configs`, or built from the operands and attributes and
// cached on first use. Returns failure if the config cannot be built or
// RunGemm reports an error.
LogicalResult Gemm::operator()(
    const ServiceExecutableRunOptions* run_options,
    const DebugOptions* debug_options, JitRtGemmConfigCache* configs,
    jitrt::MemrefView lhs, jitrt::MemrefView rhs, jitrt::MemrefView out,
    int64_t algorithm, double alpha_imag, double alpha_real,
    ArrayRef<int64_t> lhs_batch, ArrayRef<int64_t> lhs_contract,
    ArrayRef<int64_t> rhs_batch, ArrayRef<int64_t> rhs_contract,
    int64_t uid) const {
  se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
  se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs);
  se::DeviceMemoryBase output_data = GetDeviceAddress(out);

  // Scratch memory for RunGemm comes from the run's device allocator.
  se::OwningScratchAllocator<> scratch_allocator(run_options->device_ordinal(),
                                                 run_options->allocator());

  VLOG(3) << "Running GEMM";
  se::Stream* stream = run_options->stream();

  // Find the gemm config for this instance of operation based on uid.
  const GemmConfig* config = configs->Get(uid);
  if (config == nullptr) {
    auto cfg = GetGemmConfig(debug_options, lhs, rhs, out, algorithm,
                             alpha_imag, alpha_real, lhs_batch, lhs_contract,
                             rhs_batch, rhs_contract);
    if (!cfg.ok()) return failure();
    config = configs->Set(uid, std::move(*cfg));
  }

  auto executed = RunGemm(*config, lhs_data, rhs_data, output_data, stream,
                          &scratch_allocator, nullptr);
  if (!executed.ok()) return failure();

  return success();
}
// Entry point exported under `xla.gpu.gemm` by JitRtCustomCallsSymbolMap.
// Decodes `args`/`attrs` with the executable's user data through a binding
// held in a function-local static. Returns true iff the handler call succeeds.
// NOTE(review): this binding chain is incomplete as shown. The expression
// after `handler =` does not start with a `CustomCall::Bind(...)` call. The
// JitRtGemmConfigCache* user data and the algorithm/alpha/dimension/uid
// attributes that Gemm::operator() expects are not bound. There is no
// `.To<...>()` call, terminating `;`, or closing brace. Lines appear to be
// missing from this copy; confirm against the upstream file.
static bool Gemm(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler =
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.Arg<jitrt::MemrefView>() // lhs
.Arg<jitrt::MemrefView>() // rhs
.Arg<jitrt::MemrefView>() // out
return succeeded(handler->call(args, attrs, Executable::GetUserData(ctx)));
// -------------------------------------------------------------------------- //
namespace {
// Handler for the `xla.gpu.gemm.bias` custom call: copies `bias` into `out`
// when they are distinct buffers, then runs a GEMM of `lhs` and `rhs` into
// `out` using `beta`. `uid` identifies the operation instance whose GemmConfig
// is cached in `configs`.
struct GemmBias {
  LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
                           const DebugOptions* debug_options,
                           JitRtGemmConfigCache* configs, jitrt::MemrefView lhs,
                           jitrt::MemrefView rhs, jitrt::MemrefView out,
                           jitrt::MemrefView bias, int64_t algorithm,
                           double alpha_imag, double alpha_real, double beta,
                           ArrayRef<int64_t> lhs_batch,
                           ArrayRef<int64_t> lhs_contract,
                           ArrayRef<int64_t> rhs_batch,
                           ArrayRef<int64_t> rhs_contract, int64_t uid) const;

  static GemmBias Handler() { return GemmBias(); }
};
}  // namespace
// Runs a GEMM with bias on the run's stream. `bias` is first copied into `out`
// (skipped when both refer to the same device address). The GEMM is then
// executed with the given `beta`, using the GemmConfig cached for `uid` or one
// built and cached on first use. Returns failure if the config cannot be built
// or RunGemm reports an error.
LogicalResult GemmBias::operator()(
    const ServiceExecutableRunOptions* run_options,
    const DebugOptions* debug_options, JitRtGemmConfigCache* configs,
    jitrt::MemrefView lhs, jitrt::MemrefView rhs, jitrt::MemrefView out,
    jitrt::MemrefView bias, int64_t algorithm, double alpha_imag,
    double alpha_real, double beta, ArrayRef<int64_t> lhs_batch,
    ArrayRef<int64_t> lhs_contract, ArrayRef<int64_t> rhs_batch,
    ArrayRef<int64_t> rhs_contract, int64_t uid) const {
  se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
  se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs);
  se::DeviceMemoryBase output_data = GetDeviceAddress(out);
  se::DeviceMemoryBase bias_data = GetDeviceAddress(bias);

  // Scratch memory for RunGemm comes from the run's device allocator.
  se::OwningScratchAllocator<> scratch_allocator(run_options->device_ordinal(),
                                                 run_options->allocator());

  VLOG(3) << "Running GEMM + Bias [beta=" << beta << "]";
  se::Stream* stream = run_options->stream();

  // Find the gemm config for this instance of operation based on uid.
  const GemmConfig* config = configs->Get(uid);
  if (config == nullptr) {
    auto cfg = GetGemmConfig(debug_options, lhs, rhs, out, algorithm,
                             alpha_imag, alpha_real, lhs_batch, lhs_contract,
                             rhs_batch, rhs_contract, beta);
    if (!cfg.ok()) return failure();
    config = configs->Set(uid, std::move(*cfg));
  }

  // Copy bias to the output buffer if they are different.
  if (output_data.opaque() != bias_data.opaque()) {
    stream->ThenMemcpy(&output_data, bias_data, bias_data.size());
  }

  auto executed = RunGemm(*config, lhs_data, rhs_data, output_data, stream,
                          &scratch_allocator, nullptr);
  if (!executed.ok()) return failure();

  return success();
}
// Entry point exported under `xla.gpu.gemm.bias` by JitRtCustomCallsSymbolMap.
// Decodes `args`/`attrs` with the executable's user data through a binding
// held in a function-local static. Returns true iff the handler call succeeds.
// NOTE(review): this binding chain is incomplete as shown. The expression
// after `handler =` does not start with a `CustomCall::Bind(...)` call. The
// JitRtGemmConfigCache* user data and the algorithm/alpha/beta/dimension/uid
// attributes that GemmBias::operator() expects are not bound. There is no
// `.To<...>()` call, terminating `;`, or closing brace. Lines appear to be
// missing from this copy; confirm against the upstream file.
static bool GemmBias(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler =
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.Arg<jitrt::MemrefView>() // lhs
.Arg<jitrt::MemrefView>() // rhs
.Arg<jitrt::MemrefView>() // out
.Arg<jitrt::MemrefView>() // bias
return succeeded(handler->call(args, attrs, Executable::GetUserData(ctx)));
// -------------------------------------------------------------------------- //
// Returns a symbol map binding each XLA:GPU custom call name
// (`xla.gpu.func.launch`, `xla.gpu.gemm`, `xla.gpu.gemm.bias`), mangled with
// `mangle`, to the address of its entry-point function.
SymbolMap JitRtCustomCallsSymbolMap(MangleAndInterner mangle) {
  SymbolMap symbol_map;

  auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
    symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
        llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
  };

  bind("xla.gpu.func.launch", &xla::gpu::LaunchFunc);
  bind("xla.gpu.gemm", &xla::gpu::Gemm);
  bind("xla.gpu.gemm.bias", &xla::gpu::GemmBias);

  return symbol_map;
}
} // namespace gpu
} // namespace xla