// Copyright 2022 The TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
#include "tfrt/jitrt/custom_call.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/dtype/dtype.h" // from @tf_runtime
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
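// Explicit dense type ids for the user data passed to the custom calls below;
// the JitRt custom call framework uses these ids to recover typed user data
// from opaque pointers.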
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
xla::gpu::JitRtKernelsCache);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
xla::gpu::JitRtGemmConfigCache);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
xla::gpu::JitRtCollectiveSupport);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
xla::gpu::JitRtAsyncCollectiveSupport);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
const xla::ServiceExecutableRunOptions);
TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
const xla::DebugOptions);
namespace xla {
namespace gpu {
using llvm::ArrayRef;
using llvm::Optional;
using mlir::failure;
using mlir::FailureOr;
using mlir::LogicalResult;
using mlir::StringRef;
using mlir::succeeded;
using mlir::success;
using tfrt::jitrt::CustomCall;
using tfrt::jitrt::DirectCustomCallLibrary;
using tfrt::jitrt::Executable;
namespace se = ::stream_executor;
namespace jitrt = ::tfrt::jitrt;
namespace lmhlo_gpu = ::mlir::lmhlo_gpu;
namespace mhlo = ::mlir::mhlo;
namespace runtime = ::tfrt::jitrt::runtime;
// Disable all CustomCall checks in optimized build.
static constexpr CustomCall::RuntimeChecks RuntimeChecks() {
#if defined(NDEBUG)
return CustomCall::RuntimeChecks::kNone;
#else
return CustomCall::RuntimeChecks::kDefault;
#endif
}
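// With `kDefault` the custom call framework verifies argument and attribute
// types on every call; `kNone` skips this verification in optimized builds.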
// -------------------------------------------------------------------------- //
void PopulateLmhloToXlaAttrEncoding(
jitrt::CustomCallAttrEncodingSet& encoding) {
encoding.Add<
jitrt::EnumAttrEncoding<lmhlo_gpu::ActivationAttr, lmhlo_gpu::Activation,
se::dnn::ActivationMode>>(
[](lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
return ConvertConvActivationMode(value).value();
});
encoding.Add<
jitrt::EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType, se::fft::Type>>(
[](mhlo::FftType value) -> se::fft::Type {
switch (value) {
case mhlo::FftType::FFT:
return se::fft::Type::kC2CForward;
case mhlo::FftType::IFFT:
return se::fft::Type::kC2CInverse;
case mhlo::FftType::RFFT:
return se::fft::Type::kR2C;
case mhlo::FftType::IRFFT:
return se::fft::Type::kC2R;
default:
return se::fft::Type::kInvalid;
}
});
}
// -------------------------------------------------------------------------- //
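// Kernels are looked up by (executor, PTX blob, kernel name). `Get` returns
// nullptr on a cache miss; `Set` is race-tolerant and returns the previously
// cached kernel if another thread inserted one first.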
se::KernelBase* JitRtKernelsCache::Get(se::StreamExecutor* executor,
const char* data, StringRef name) {
Key key(executor, data, name);
absl::MutexLock lock(&mutex_);
auto it = kernels_cache_.find(key);
if (it != kernels_cache_.end()) return it->second.get();
return nullptr;
}
se::KernelBase* JitRtKernelsCache::Set(se::StreamExecutor* executor,
const char* data, StringRef name,
std::unique_ptr<se::KernelBase> kernel) {
Key key(executor, data, name);
absl::MutexLock lock(&mutex_);
auto it = kernels_cache_.find(key);
if (it != kernels_cache_.end()) return it->second.get();
auto emplaced = kernels_cache_.try_emplace(key, std::move(kernel));
return emplaced.first->second.get();
}
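// Computes the device buffer size in bytes from the memref element type and
// dimensions, e.g. an F32 memref with sizes [2, 3] spans 2 * 3 * 4 = 24 bytes.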
template <typename MemrefArg>
static se::DeviceMemoryBase GetDeviceAddress(MemrefArg& memref) {
uint64_t size = tfrt::GetHostSize(memref.dtype);
for (auto dim : memref.sizes) size *= dim;
return se::DeviceMemoryBase(memref.data, size);
}
static se::DeviceMemoryBase GetDeviceAddress(jitrt::FlatMemrefView& memref) {
return se::DeviceMemoryBase(memref.data, memref.size_in_bytes);
}
// -------------------------------------------------------------------------- //
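// Gemm configs are keyed by the `uid` attribute of the gemm operation; like
// the kernels cache above, `Set` returns the existing config if one is
// already cached for the same uid.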
const GemmConfig* JitRtGemmConfigCache::Get(int64_t uid) {
absl::MutexLock lock(&mutex_);
auto it = configs_.find(uid);
if (it != configs_.end()) return &it->second;
return nullptr;
}
const GemmConfig* JitRtGemmConfigCache::Set(int64_t uid, GemmConfig config) {
absl::MutexLock lock(&mutex_);
auto it = configs_.find(uid);
if (it != configs_.end()) return &it->second;
auto emplaced = configs_.try_emplace(uid, std::move(config));
return &emplaced.first->second;
}
// -------------------------------------------------------------------------- //
JitRtAsyncCollectiveSupport::JitRtAsyncCollectiveSupport(
se::Stream* async_comm_stream)
: async_comm_stream_(async_comm_stream) {}
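// Block the host only on the first run for the given (uid, device ordinal)
// pair: `try_emplace` succeeds exactly once per key, so all subsequent runs
// return immediately.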
Status JitRtCollectiveSupport::MaybeBlockAfterFirstRun(int32_t uid,
int32_t device_ordinal,
se::Stream* stream) {
bool block = [&] {
absl::MutexLock lock(&mutex_);
return executed_.try_emplace(Key(uid, device_ordinal), true).second;
}();
return block ? stream->BlockHostUntilDone() : Status::OK();
}
FailureOr<se::Event> JitRtAsyncCollectiveSupport::PopEvent(
int32_t uid, int32_t device_ordinal) {
const int64_t key = EventKey(uid, device_ordinal);
absl::MutexLock lock(&mutex_);
auto it = done_events_.find(key);
if (it == done_events_.end()) return failure();
se::Event done_event = std::move(it->second);
done_events_.erase(it);
return done_event;
}
LogicalResult JitRtAsyncCollectiveSupport::PushEvent(int32_t uid,
int32_t device_ordinal,
se::Event done_event) {
const int64_t key = EventKey(uid, device_ordinal);
absl::MutexLock lock(&mutex_);
auto result = done_events_.try_emplace(key, std::move(done_event));
if (!result.second) return failure(); // previous done event was not consumed
return success();
}
// -------------------------------------------------------------------------- //
static PrimitiveType ToPrimitiveType(tfrt::DType dtype) {
switch (dtype) {
// Unsigned integer types.
case tfrt::DType::UI8:
return PrimitiveType::U8;
case tfrt::DType::UI16:
return PrimitiveType::U16;
case tfrt::DType::UI32:
return PrimitiveType::U32;
case tfrt::DType::UI64:
return PrimitiveType::U64;
// Signed integer types.
case tfrt::DType::I1:
return PrimitiveType::PRED;
case tfrt::DType::I8:
return PrimitiveType::S8;
case tfrt::DType::I16:
return PrimitiveType::S16;
case tfrt::DType::I32:
return PrimitiveType::S32;
case tfrt::DType::I64:
return PrimitiveType::S64;
// Floating point types.
case tfrt::DType::F16:
return PrimitiveType::F16;
case tfrt::DType::F32:
return PrimitiveType::F32;
case tfrt::DType::F64:
return PrimitiveType::F64;
case tfrt::DType::BF16:
return PrimitiveType::BF16;
// Complex types.
case tfrt::DType::Complex64:
return PrimitiveType::C64;
case tfrt::DType::Complex128:
return PrimitiveType::C128;
default:
LOG(FATAL) << "Unsupported data type: " << dtype;
}
}
static Shape ToShape(const jitrt::StridedMemrefView& memref) {
PrimitiveType type = ToPrimitiveType(memref.dtype);
// Recover the `minor_to_major` dimension permutation from the strides.
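// Example: a row-major memref with sizes [2, 3, 4] has strides [12, 4, 1];
// sorting (stride, index) pairs in ascending order of stride recovers
// minor_to_major = [2, 1, 0].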
auto indexed_strides_range =
llvm::map_range(llvm::enumerate(memref.strides), [](auto pair) {
return std::pair<int64_t, size_t>{pair.value(), pair.index()};
});
auto indexed_strides = llvm::to_vector(indexed_strides_range);
llvm::stable_sort(indexed_strides);
llvm::SmallVector<int64_t> minor_to_major;
minor_to_major.reserve(indexed_strides.size());
for (auto& pair : indexed_strides) minor_to_major.push_back(pair.second);
return ShapeUtil::MakeShapeWithLayout(type, memref.sizes, minor_to_major);
}
static StatusOr<GemmConfig> GetGemmConfig(
const jitrt::StridedMemrefView& lhs, const jitrt::StridedMemrefView& rhs,
const jitrt::StridedMemrefView& out, int64_t algorithm, double alpha_real,
double alpha_imag, double beta, ArrayRef<int64_t> lhs_batch,
ArrayRef<int64_t> lhs_contract, ArrayRef<int64_t> rhs_batch,
ArrayRef<int64_t> rhs_contract) {
return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs),
rhs_batch, rhs_contract, ToShape(out), alpha_real,
alpha_imag, beta, algorithm,
se::blas::kDefaultComputePrecision);
}
// -------------------------------------------------------------------------- //
#if XLA_ENABLE_XCCL
FailureOr<NcclComm::Lock> GetNcclComm(const NcclExecuteParams& params,
int64_t group_mode, int64_t op_id,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) {
// TODO(b/233930690): Pass the attribute below as a nested array.
// Pass an array of arrays using two vectors; one specifying all the values
// and another specifying the (ending) offsets of each array in the other
// vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into
// offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90].
std::vector<ReplicaGroup> replica_groups;
int i = 0;
for (int64_t replica_group_end : replica_group_offsets) {
ReplicaGroup replica_group;
while (i < replica_group_end)
replica_group.add_replica_ids(replica_group_values[i++]);
replica_groups.push_back(replica_group);
}
auto comm =
LockNcclComm(params, replica_groups,
static_cast<CollectiveOpGroupMode>(group_mode), op_id);
if (comm.ok()) return std::move(comm.value());
return failure();
}
#endif // XLA_ENABLE_XCCL
FailureOr<std::vector<DeviceBufferPair>> GetDeviceBufferPairs(
CustomCall::RemainingArgs& args) {
// Expect an even number of memref arguments: the first half are the source
// buffers and the second half are the matching destination buffers.
const int buffer_pairs = args.size() / 2;
std::vector<DeviceBufferPair> device_buffers;
device_buffers.reserve(buffer_pairs);
for (int i = 0; i < buffer_pairs; ++i) {
auto source = args.get<jitrt::StridedMemrefView>(i);
auto destination = args.get<jitrt::StridedMemrefView>(i + buffer_pairs);
if (failed(source) || failed(destination)) {
// Unsupported argument type.
return failure();
}
int element_count = 1;
for (int size : source->sizes) element_count *= size;
device_buffers.emplace_back(DeviceBufferPair{
ToPrimitiveType(source->dtype), element_count,
GetDeviceAddress(*source), GetDeviceAddress(*destination)});
}
return device_buffers;
}
// -------------------------------------------------------------------------- //
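// Handler for the `xla.gpu.func.launch` custom call: launches a device kernel
// compiled from PTX (and cached in JitRtKernelsCache) with the given grid and
// block dimensions.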
namespace {
struct LaunchFunc {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtKernelsCache* kernels_cache,
int32_t grid_size_x, int32_t grid_size_y,
int32_t grid_size_z, int32_t block_size_x,
int32_t block_size_y, int32_t block_size_z,
CustomCall::RemainingArgs args, StringRef ptx,
StringRef name) const;
static LaunchFunc Handler() { return LaunchFunc(); }
};
} // namespace
LogicalResult LaunchFunc::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtKernelsCache* kernels_cache, int32_t grid_size_x, int32_t grid_size_y,
int32_t grid_size_z, int32_t block_size_x, int32_t block_size_y,
int32_t block_size_z, CustomCall::RemainingArgs args, StringRef ptx,
StringRef name) const {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
LaunchDimensions launch_dimensions(
{grid_size_x, grid_size_y, grid_size_z},
{block_size_x, block_size_y, block_size_z});
se::KernelBase* kernel = kernels_cache->Get(executor, ptx.data(), name);
// If the kernel does not exist, create it from the PTX.
if (kernel == nullptr) {
auto created = CreateKernel(absl::string_view(name.data(), name.size()),
args.size(), ptx.data(), {}, executor);
if (!created.ok()) return failure();
kernel =
kernels_cache->Set(executor, ptx.data(), name, std::move(*created));
}
VLOG(3) << "Launching " << kernel->name();
absl::InlinedVector<se::DeviceMemoryBase, 4> buffer_args;
buffer_args.reserve(args.size());
// Add MemRef arguments as buffer arguments.
for (unsigned i = 0; i < args.size(); ++i) {
// A simple row-major memref is passed as a shapeless buffer.
auto memref = args.get<jitrt::FlatMemrefView>(i);
if (succeeded(memref)) {
buffer_args.emplace_back(GetDeviceAddress(*memref));
continue;
}
// The memref layout must be encoded in the compiled device kernel, so we
// don't have to pass strides or the minor-to-major dimension order to the
// kernel.
auto strided = args.get<jitrt::StridedMemrefView>(i);
if (succeeded(strided)) {
buffer_args.emplace_back(GetDeviceAddress(*strided));
continue;
}
// Unsupported argument type.
return failure();
}
// Execute the device kernel on the main stream.
auto executed =
ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, stream);
if (!executed.ok()) return failure();
return success();
}
static bool LaunchFunc(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.func.launch")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtKernelsCache*>()
.Arg<int32_t>() // grid_size_x
.Arg<int32_t>() // grid_size_y
.Arg<int32_t>() // grid_size_z
.Arg<int32_t>() // block_size_x
.Arg<int32_t>() // block_size_y
.Arg<int32_t>() // block_size_z
.RemainingArgs() // args
.Attr<StringRef>("ptx")
.Attr<StringRef>("kernel")
.To<RuntimeChecks()>(LaunchFunc::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
struct Gemm {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(
const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options, JitRtGemmConfigCache* configs,
jitrt::StridedMemrefView lhs, jitrt::StridedMemrefView rhs,
jitrt::StridedMemrefView out, int64_t algorithm, double alpha_real,
double alpha_imag, double beta, ArrayRef<int64_t> lhs_batch,
ArrayRef<int64_t> lhs_contract, ArrayRef<int64_t> rhs_batch,
ArrayRef<int64_t> rhs_contract, int64_t uid) const;
static Gemm Handler() { return Gemm(); }
};
} // namespace
LogicalResult Gemm::operator()(
const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options, JitRtGemmConfigCache* configs,
jitrt::StridedMemrefView lhs, jitrt::StridedMemrefView rhs,
jitrt::StridedMemrefView out, int64_t algorithm, double alpha_real,
double alpha_imag, double beta, ArrayRef<int64_t> lhs_batch,
ArrayRef<int64_t> lhs_contract, ArrayRef<int64_t> rhs_batch,
ArrayRef<int64_t> rhs_contract, int64_t uid) const {
se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs);
se::DeviceMemoryBase output_data = GetDeviceAddress(out);
VLOG(3) << "Running GEMM";
se::Stream* stream = run_options->stream();
// Find the gemm config for this instance of the operation based on its uid.
const GemmConfig* config = configs->Get(uid);
if (config == nullptr) {
auto cfg =
GetGemmConfig(lhs, rhs, out, algorithm, alpha_real, alpha_imag, beta,
lhs_batch, lhs_contract, rhs_batch, rhs_contract);
if (!cfg.ok()) return failure();
config = configs->Set(uid, std::move(*cfg));
}
Status executed = [&]() -> Status {
return RunGemm(*config, lhs_data, rhs_data, output_data, stream);
}();
if (!executed.ok()) return failure();
return success();
}
static bool Gemm(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler =
CustomCall::Bind("xla.gpu.gemm")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.UserData<JitRtGemmConfigCache*>()
.Arg<jitrt::StridedMemrefView>() // lhs
.Arg<jitrt::StridedMemrefView>() // rhs
.Arg<jitrt::StridedMemrefView>() // out
.Attr<int64_t>("algorithm")
.Attr<double>("alpha_real")
.Attr<double>("alpha_imag")
.Attr<double>("beta")
.Attr<ArrayRef<int64_t>>("lhs_batching_dimensions")
.Attr<ArrayRef<int64_t>>("lhs_contracting_dimensions")
.Attr<ArrayRef<int64_t>>("rhs_batching_dimensions")
.Attr<ArrayRef<int64_t>>("rhs_contracting_dimensions")
.Attr<int64_t>("uid")
.To<RuntimeChecks()>(Gemm::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
// TODO(ezhulenev): We need to find a better way to pass structured attributes
// to JitRt custom calls.
// TODO(ezhulenev): Add caching layer for convolution configs and runners.
namespace {
struct InputDimensions {
int64_t input_batch_dim;
int64_t input_feature_dim;
ArrayRef<int64_t> input_spatial_dims;
};
struct KernelDimensions {
int64_t kernel_in_feature_dim;
int64_t kernel_out_feature_dim;
ArrayRef<int64_t> kernel_spatial_dims;
};
struct OutputDimensions {
int64_t output_batch_dim;
int64_t output_feature_dim;
ArrayRef<int64_t> output_spatial_dims;
};
struct Window {
ArrayRef<int64_t> window_strides;
ArrayRef<int64_t> padding;
ArrayRef<int64_t> lhs_dilation;
ArrayRef<int64_t> rhs_dilation;
ArrayRef<int64_t> window_reversal;
};
struct BackendConfig {
int64_t algorithm;
bool tensor_ops_enabled;
bool is_cudnn_frontend;
ArrayRef<int64_t> knob_ids;
ArrayRef<int64_t> knob_values;
ArrayRef<int64_t> operand_0_layout;
ArrayRef<int64_t> operand_1_layout;
ArrayRef<int64_t> result_layout;
int64_t workspace_size;
};
struct ConvAttrs {
int64_t feature_group_count;
double result_scale;
};
struct FusedConvAttrs {
se::dnn::ActivationMode activation_mode;
};
struct SideInputAttrs {
double side_input_scale;
};
} // namespace
static GpuConvDescriptor GetConvDescriptor(
CudnnConvKind kind,
// Arguments
jitrt::StridedMemrefView operand0, jitrt::StridedMemrefView operand1,
jitrt::StridedMemrefView output, jitrt::FlatMemrefView scratch,
// Attributes
InputDimensions i, KernelDimensions k, OutputDimensions o, Window w,
BackendConfig b, ConvAttrs attrs,
// Conv-specific arguments and attributes
Optional<FusedConvAttrs> fused = llvm::None,
Optional<SideInputAttrs> side_input = llvm::None) {
// Build a convolution descriptor from the attributes.
GpuConvDescriptor descriptor;
descriptor.kind = kind;
// Apply backend config layout to the shape.
auto apply_layout = [](jitrt::StridedMemrefView& memref,
ArrayRef<int64_t> minor_to_major) {
Shape shape = ToShape(memref);
return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
shape.dimensions(), minor_to_major);
};
descriptor.operand0_shape = apply_layout(operand0, b.operand_0_layout);
descriptor.operand1_shape = apply_layout(operand1, b.operand_1_layout);
descriptor.result_shape = apply_layout(output, b.result_layout);
// Set up convolution dimensions numbers.
ConvolutionDimensionNumbers dns;
dns.set_input_batch_dimension(i.input_batch_dim);
dns.set_input_feature_dimension(i.input_feature_dim);
dns.set_kernel_input_feature_dimension(k.kernel_in_feature_dim);
dns.set_kernel_output_feature_dimension(k.kernel_out_feature_dim);
dns.set_output_batch_dimension(o.output_batch_dim);
dns.set_output_feature_dimension(o.output_feature_dim);
for (int64_t d : i.input_spatial_dims) dns.add_input_spatial_dimensions(d);
for (int64_t d : k.kernel_spatial_dims) dns.add_kernel_spatial_dimensions(d);
for (int64_t d : o.output_spatial_dims) dns.add_output_spatial_dimensions(d);
descriptor.dnums = std::move(dns);
// Put together convolution window config.
for (auto index : llvm::seq<int>(0, w.window_strides.size())) {
WindowDimension* dim = descriptor.window.add_dimensions();
// Window size for a convolution is the same as the kernel size. The kernel
// size of the convolution is operand1_shape, so we look up the kernel
// spatial dimensions in the convolution dimension numbers to get the
// window size.
int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
dim->set_size(descriptor.operand1_shape.dimensions(kernel_dim));
dim->set_stride(w.window_strides[index]);
dim->set_padding_low(w.padding[index]);
dim->set_padding_high(w.padding[index]);
dim->set_base_dilation(w.lhs_dilation[index]);
dim->set_window_dilation(w.rhs_dilation[index]);
dim->set_window_reversal(w.window_reversal[index]);
}
descriptor.scratch_size = scratch.size_in_bytes;
descriptor.feature_group_count = attrs.feature_group_count;
descriptor.backend_config.set_conv_result_scale(attrs.result_scale);
// Set up the convolution algorithm.
auto* algo = descriptor.backend_config.mutable_algorithm();
algo->set_algo_id(b.algorithm);
algo->set_math_type(b.tensor_ops_enabled
? se::dnn::AlgorithmProto::TENSOR_OP_MATH
: se::dnn::AlgorithmProto::DEFAULT_MATH);
algo->set_is_cudnn_frontend(b.is_cudnn_frontend);
if (b.workspace_size >= 0)
algo->mutable_workspace_size()->set_value(b.workspace_size);
for (unsigned i = 0; i < b.knob_ids.size(); ++i) {
algo->mutable_tuning_knobs()->insert({b.knob_ids[i], b.knob_values[i]});
}
// Set attributes specific for fused convolutions.
if (fused.has_value())
descriptor.backend_config.set_activation_mode(fused->activation_mode);
// Set attributes specific for convolutions with side input.
if (side_input.has_value())
descriptor.backend_config.set_side_input_scale(
side_input->side_input_scale);
return descriptor;
}
namespace {
struct Conv {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(
const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options, jitrt::StridedMemrefView operand0,
jitrt::StridedMemrefView operand1, Optional<jitrt::FlatMemrefView> bias,
Optional<jitrt::StridedMemrefView> side_input,
jitrt::StridedMemrefView output, jitrt::FlatMemrefView scratch,
// Convolution input dimensions numbers
int64_t input_batch_dim, int64_t input_feature_dim,
ArrayRef<int64_t> input_spatial_dims,
// Convolution kernel dimensions numbers
int64_t kernel_in_feature_dim, int64_t kernel_out_feature_dim,
ArrayRef<int64_t> kernel_spatial_dims,
// Output dimensions numbers
int64_t output_batch_dim, int64_t output_feature_dim,
ArrayRef<int64_t> output_spatial_dims,
// Window config
ArrayRef<int64_t> window_strides, ArrayRef<int64_t> padding,
ArrayRef<int64_t> lhs_dilation, ArrayRef<int64_t> rhs_dilation,
ArrayRef<int64_t> window_reversal,
// Backend config attributes
int64_t algorithm, bool tensor_ops_enabled, bool is_cudnn_frontend,
ArrayRef<int64_t> knob_ids, ArrayRef<int64_t> knob_values,
ArrayRef<int64_t> operand_0_layout, ArrayRef<int64_t> operand_1_layout,
ArrayRef<int64_t> result_layout, int64_t workspace_size,
// Remaining attributes
int64_t feature_group_count, double result_scale,
// Optional attributes for fused convolutions.
Optional<se::dnn::ActivationMode> activation_mode = llvm::None,
Optional<double> side_input_scale = llvm::None) const {
// Build config for optional attributes.
Optional<FusedConvAttrs> fused_attrs = llvm::None;
if (activation_mode.has_value()) fused_attrs = {*activation_mode};
Optional<SideInputAttrs> side_input_attrs = llvm::None;
if (side_input_scale.has_value()) side_input_attrs = {*side_input_scale};
// Prepare a descriptor for the XLA convolution.
GpuConvDescriptor descriptor = GetConvDescriptor(
kind, operand0, operand1, output, scratch,
{input_batch_dim, input_feature_dim, input_spatial_dims},
{kernel_in_feature_dim, kernel_out_feature_dim, kernel_spatial_dims},
{output_batch_dim, output_feature_dim, output_spatial_dims},
{window_strides, padding, lhs_dilation, rhs_dilation, window_reversal},
{algorithm, tensor_ops_enabled, is_cudnn_frontend, knob_ids,
knob_values, operand_0_layout, operand_1_layout, result_layout,
workspace_size},
{feature_group_count, result_scale}, fused_attrs, side_input_attrs);
// Convert descriptor to the Conv config.
StatusOr<GpuConvConfig> config = GetGpuConvConfig(descriptor, "");
if (!config.ok()) return failure();
// Prepare buffer arguments.
std::vector<se::DeviceMemoryBase> buffers = {GetDeviceAddress(operand0),
GetDeviceAddress(operand1)};
if (bias.has_value()) buffers.push_back(GetDeviceAddress(*bias));
if (side_input.has_value())
buffers.push_back(GetDeviceAddress(*side_input));
se::DeviceMemoryBase result_buffer = GetDeviceAddress(output);
se::DeviceMemoryBase scratch_buffer = GetDeviceAddress(scratch);
RunConvOptions opts;
// Create a runner for the given config.
MaybeFusedConvRunner runner(*config);
opts.runner_cache = &runner;
// Run the convolution.
auto st = RunGpuConv(*config, buffers, result_buffer, scratch_buffer,
run_options->stream(), opts);
if (!st.ok() || !run_options->stream()->ok()) return failure();
return success();
}
static Conv Handler(CudnnConvKind kind) { return Conv{kind}; }
CudnnConvKind kind;
};
} // namespace
// Adds custom call bindings for convolution operations.
template <typename... Ts>
static auto BindConvAttributes(jitrt::CustomCallBinding<Ts...> binding) {
return std::move(binding)
// Convolution dimensions numbers
.template Attr<int64_t>("input_batch_dim")
.template Attr<int64_t>("input_feature_dim")
.template Attr<ArrayRef<int64_t>>("input_spatial_dims")
// Convolution kernel dimensions
.template Attr<int64_t>("kernel_in_feature_dim")
.template Attr<int64_t>("kernel_out_feature_dim")
.template Attr<ArrayRef<int64_t>>("kernel_spatial_dims")
// Output dimensions
.template Attr<int64_t>("output_batch_dim")
.template Attr<int64_t>("output_feature_dim")
.template Attr<ArrayRef<int64_t>>("output_spatial_dims")
// Window config
.template Attr<ArrayRef<int64_t>>("window_strides")
.template Attr<ArrayRef<int64_t>>("padding")
.template Attr<ArrayRef<int64_t>>("lhs_dilation")
.template Attr<ArrayRef<int64_t>>("rhs_dilation")
.template Attr<ArrayRef<int64_t>>("window_reversal")
// Backend config attributes
.template Attr<int64_t>("algorithm")
.template Attr<bool>("tensor_ops_enabled")
.template Attr<bool>("is_cudnn_frontend")
.template Attr<ArrayRef<int64_t>>("knob_ids")
.template Attr<ArrayRef<int64_t>>("knob_values")
.template Attr<ArrayRef<int64_t>>("operand_0_layout")
.template Attr<ArrayRef<int64_t>>("operand_1_layout")
.template Attr<ArrayRef<int64_t>>("result_layout")
.template Attr<int64_t>("workspace_size")
// Remaining attributes.
.template Attr<int64_t>("feature_group_count")
.template Attr<double>("result_scale");
}
template <CudnnConvKind kind>
static bool ConvFn(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler =
BindConvAttributes(CustomCall::Bind("xla.gpu.conv")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.Arg<jitrt::StridedMemrefView>() // operand0
.Arg<jitrt::StridedMemrefView>() // operand1
.Value(CustomCall::None) // bias
.Value(CustomCall::None) // side_input
.Arg<jitrt::StridedMemrefView>() // output
.Arg<jitrt::FlatMemrefView>() // scratch
)
.To(Conv::Handler(kind))
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
template <CudnnConvKind kind>
static bool ConvFusedFn(runtime::KernelContext* ctx, void** args,
void** attrs) {
static auto* handler =
BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.Arg<jitrt::StridedMemrefView>() // operand0
.Arg<jitrt::StridedMemrefView>() // operand1
.Arg<jitrt::FlatMemrefView>() // bias
.Value(CustomCall::None) // side_input
.Arg<jitrt::StridedMemrefView>() // output
.Arg<jitrt::FlatMemrefView>() // scratch
)
.Attr<se::dnn::ActivationMode>("activation_mode")
.To(Conv::Handler(kind))
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
template <CudnnConvKind kind>
static bool ConvFuseSideInputdFn(runtime::KernelContext* ctx, void** args,
void** attrs) {
static auto* handler =
BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused.side_input")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.Arg<jitrt::StridedMemrefView>() // operand0
.Arg<jitrt::StridedMemrefView>() // operand1
.Arg<jitrt::FlatMemrefView>() // bias
.Arg<jitrt::StridedMemrefView>() // side_input
.Arg<jitrt::StridedMemrefView>() // output
.Arg<jitrt::FlatMemrefView>() // scratch
)
.Attr<se::dnn::ActivationMode>("activation_mode")
.Attr<double>("side_input_scale")
.To(Conv::Handler(kind))
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
struct Infeed {
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
CustomCall::RemainingArgs args,
StringRef config) const;
static Infeed Handler() { return Infeed(); }
};
} // namespace
LogicalResult Infeed::operator()(const ServiceExecutableRunOptions* run_options,
CustomCall::RemainingArgs args,
StringRef config) const {
VLOG(3) << "Infeeding to GPU";
se::Stream* stream = run_options->stream();
ShapeTree<se::ScopedDeviceMemory<uint8_t>> source_buffers =
GetOrCreateInfeedManager(stream->parent())->BlockingGetNextDestination();
// Check that we have the correct number of arguments.
if (args.size() != source_buffers.leaf_count()) return failure();
// TODO(ezhulenev): Report human-readable error messages through errors.
size_t index = 0;
for (auto& source : source_buffers.leaves()) {
// Get the destination buffer.
auto dest = args.get<jitrt::StridedMemrefView>(index);
if (failed(dest)) return failure();
// Get the source buffer shape.
const Shape& source_shape =
ShapeUtil::GetSubshape(source_buffers.shape(), source.first);
// Check that destination shape matches the source shape.
// TODO(ezhulenev): Report human-readable error similar to infeed_thunk.
Shape dest_shape = ToShape(*dest);
if (!ShapeUtil::Equal(dest_shape, source_shape)) return failure();
se::DeviceMemoryBase dest_address = GetDeviceAddress(*dest);
se::ScopedDeviceMemory<uint8_t>& buffer = source.second;
stream->ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size());
++index;
}
// TODO(ezhulenev): Make this function async?
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) return failure();
VLOG(3) << "Infeeding to GPU complete";
return success();
}
static bool Infeed(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.infeed")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<CustomCall::RemainingArgs>() // args
.Attr<StringRef>("config")
.To<RuntimeChecks()>(Infeed::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
struct Outfeed {
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
CustomCall::RemainingArgs args,
StringRef config) const;
static Outfeed Handler() { return Outfeed(); }
};
} // namespace
LogicalResult Outfeed::operator()(
const ServiceExecutableRunOptions* run_options,
CustomCall::RemainingArgs args, StringRef config) const {
VLOG(3) << "Outfeeding from GPU";
se::Stream* stream = run_options->stream();
OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(stream->parent());
ShapeTree<std::unique_ptr<OutfeedBuffer>>* dest_buffers =
outfeed_manager->BlockingGetNextDestination();
// Check that we have the correct number of arguments.
if (args.size() != dest_buffers->leaf_count()) return failure();
size_t index = 0;
for (auto& dest : dest_buffers->leaves()) {
// Get the source buffer.
auto source = args.get<jitrt::StridedMemrefView>(index);
if (failed(source)) return failure();
// Get the destination buffer shape.
const Shape& dest_shape =
ShapeUtil::GetSubshape(dest_buffers->shape(), dest.first);
// Check that destination shape matches the source shape.
// TODO(ezhulenev): Report human-readable error similar to outfeed_thunk.
Shape source_shape = ToShape(*source);
if (!ShapeUtil::Equal(dest_shape, source_shape)) return failure();
se::DeviceMemoryBase source_address = GetDeviceAddress(*source);
std::unique_ptr<OutfeedBuffer>& buffer = dest.second;
// Schedule the memory transfer.
auto* dest_address = buffer->destination()->untyped_data();
stream->ThenMemcpy(dest_address, source_address, buffer->length())
.ThenDoHostCallback([&buffer]() { buffer->Done(); });
++index;
}
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) return failure();
VLOG(3) << "Outfeeding from GPU complete";
return success();
}
static bool Outfeed(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.outfeed")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<CustomCall::RemainingArgs>() // args
.Attr<StringRef>("config")
.To<RuntimeChecks()>(Outfeed::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
enum class MemcpyDirection { kDeviceToDevice, kDeviceToHost, kHostToDevice };
template <MemcpyDirection direction>
struct Memcpy {
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
jitrt::FlatMemrefView dst,
jitrt::FlatMemrefView src) const;
static Memcpy Handler() { return Memcpy(); }
};
} // namespace
template <MemcpyDirection direction>
LogicalResult Memcpy<direction>::operator()(
const ServiceExecutableRunOptions* run_options, jitrt::FlatMemrefView dst,
jitrt::FlatMemrefView src) const {
se::Stream* stream = run_options->stream();
if (dst.size_in_bytes != src.size_in_bytes) return failure();
switch (direction) {
case MemcpyDirection::kDeviceToDevice: {
se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
se::DeviceMemoryBase src_data = GetDeviceAddress(src);
stream->ThenMemcpy(&dst_data, src_data, src.size_in_bytes);
} break;
case MemcpyDirection::kDeviceToHost: {
se::DeviceMemoryBase src_data = GetDeviceAddress(src);
stream->ThenMemcpy(dst.data, src_data, src.size_in_bytes);
} break;
case MemcpyDirection::kHostToDevice: {
se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
stream->ThenMemcpy(&dst_data, src.data, src.size_in_bytes);
} break;
}
// TODO(ezhulenev): H2D and D2H memcpy instead of blocking the execution
// thread should return an async token that will become available when
// transfer is completed.
if (direction != MemcpyDirection::kDeviceToDevice) {
auto st = stream->BlockHostUntilDone();
if (!st.ok()) return failure();
}
return success();
}
template <MemcpyDirection direction>
static bool MemcpyFn(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<jitrt::FlatMemrefView>() // dst
.Arg<jitrt::FlatMemrefView>() // src
.To<RuntimeChecks()>(Memcpy<direction>::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
struct Memset {
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
jitrt::FlatMemrefView dst,
CustomCall::VariantArg constant) const;
static Memset Handler() { return Memset(); }
};
} // namespace
LogicalResult Memset::operator()(const ServiceExecutableRunOptions* run_options,
jitrt::FlatMemrefView dst,
CustomCall::VariantArg constant) const {
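// The fill value arrives as a variant argument: accept 32-bit integers
// directly and bit-cast 32-bit floats to an equivalent uint32_t pattern.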
uint32_t pattern;
if (constant.isa<int32_t>())
pattern = *constant.get<int32_t>();
else if (constant.isa<float>())
pattern = reinterpret_cast<uint32_t&>(*constant.get<float>());
else
return failure();
se::Stream* stream = run_options->stream();
if (dst.size_in_bytes % 4 != 0) return failure();
se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
stream->ThenMemset32(&dst_data, pattern, dst.size_in_bytes);
return success();
}
static bool MemsetFn(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.memset")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<jitrt::FlatMemrefView>() // dst
.Arg<CustomCall::VariantArg>() // constant
.To<RuntimeChecks()>(Memset::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
struct Fft {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
jitrt::StridedMemrefView input,
jitrt::StridedMemrefView output,
ArrayRef<int64_t> fft_length,
se::fft::Type fft_type) const;
static Fft Handler() { return Fft(); }
};
} // namespace
LogicalResult Fft::operator()(const ServiceExecutableRunOptions* run_options,
jitrt::StridedMemrefView input,
jitrt::StridedMemrefView output,
ArrayRef<int64_t> fft_length,
se::fft::Type fft_type) const {
// TODO(ezhulenev): Cache FFT plans in the GpuExecutable.
FftPlanCache fft_plan_cache;
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
if (input.dtype == tfrt::DType::F64 ||
input.dtype == tfrt::DType::Complex128) {
// Adjust FFT type to reflect double precision.
switch (fft_type) {
case se::fft::Type::kC2CForward:
fft_type = se::fft::Type::kZ2ZForward;
break;
case se::fft::Type::kC2CInverse:
fft_type = se::fft::Type::kZ2ZInverse;
break;
case se::fft::Type::kR2C:
fft_type = se::fft::Type::kD2Z;
break;
case se::fft::Type::kC2R:
fft_type = se::fft::Type::kZ2D;
break;
default:
return failure();
}
}
auto st =
RunFft(GetDeviceAddress(input), ToShape(input), GetDeviceAddress(output),
ToShape(output), fft_type, fft_length, executor->device_ordinal(),
&fft_plan_cache, stream, run_options->allocator());
if (!st.ok()) return failure();
return success();
}
static bool Fft(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.fft")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<jitrt::StridedMemrefView>() // input
.Arg<jitrt::StridedMemrefView>() // output
.Attr<ArrayRef<int64_t>>("fft_length")
.Attr<se::fft::Type>("fft_type")
.To<RuntimeChecks()>(Fft::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
struct Cholesky {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
jitrt::MemrefView operand, jitrt::MemrefView a,
jitrt::MemrefView workspace, jitrt::MemrefView info,
int64_t batch_size, int64_t n, int64_t uplo) const;
static Cholesky Handler() { return Cholesky(); }
};
} // namespace
LogicalResult Cholesky::operator()(
const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options, jitrt::MemrefView operand,
jitrt::MemrefView a, jitrt::MemrefView workspace, jitrt::MemrefView info,
int64_t batch_size, int64_t n, int64_t uplo) const {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
se::DeviceMemoryBase operand_buffer = GetDeviceAddress(operand);
se::DeviceMemoryBase a_buffer = GetDeviceAddress(a);
se::DeviceMemoryBase workspace_buffer = GetDeviceAddress(workspace);
se::DeviceMemoryBase info_buffer = GetDeviceAddress(info);
VLOG(3) << "Running Cholesky";
se::Stream* stream = run_options->stream();
// Copy the operand to the `a` buffer if they are different buffers.
if (a.data != operand.data)
stream->ThenMemcpy(&a_buffer, operand_buffer, operand_buffer.size());
CholeskyParams params{
n, batch_size, static_cast<se::blas::UpperLower>(uplo),
a_buffer, workspace_buffer, info_buffer};
auto executed = RunCholesky(xla::gpu::PtxOptsFromDebugOptions(*debug_options),
ToPrimitiveType(operand.dtype), &params, stream);
if (!executed.ok()) return failure();
return success();
#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
return failure();
#endif
}
static bool Cholesky(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.cholesky")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.Arg<jitrt::MemrefView>() // operand
.Arg<jitrt::MemrefView>() // a
.Arg<jitrt::MemrefView>() // workspace
.Arg<jitrt::MemrefView>() // info
.Attr<int64_t>("batch_size")
.Attr<int64_t>("n")
.Attr<int64_t>("uplo") // se::blas::UpperLower
.To<RuntimeChecks()>(Cholesky::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
// TODO(ezhulenev): Today XLA represents TriangularSolve as a "classic" XLA
// custom call operation, and we provide a thin adaptor from the XLA custom
// call to the JitRt custom call. Once we have fully migrated to JitRt
// execution, the XLA compiler should directly emit a properly typed
// TriangularSolve JitRt custom call (no need to pass the config via a
// serialized string).
struct TriangularSolve {
// Adaptor from XlaCustomCall API to properly typed TriangularSolve handler.
static LogicalResult run(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
CustomCall::RemainingArgs args,
StringRef backend_config);
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
jitrt::StridedMemrefView a,
jitrt::StridedMemrefView b,
jitrt::StridedMemrefView result,
jitrt::FlatMemrefView temp, bool left_side,
bool lower, bool unit_diagonal,
TriangularSolveOptions::Transpose transpose_a) const;
static TriangularSolve Handler() { return TriangularSolve(); }
};
} // namespace
LogicalResult TriangularSolve::run(
const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options, CustomCall::RemainingArgs args,
StringRef backend_config) {
TriangularSolve handler = TriangularSolve::Handler();
// We expect 4 memref arguments.
if (args.size() != 4) return failure();
// Check if all arguments have the correct type.
auto a = args.get<jitrt::StridedMemrefView>(0);
auto b = args.get<jitrt::StridedMemrefView>(1);
auto result = args.get<jitrt::StridedMemrefView>(2);
auto temp = args.get<jitrt::FlatMemrefView>(3);
if (failed(a) || failed(b) || failed(result) || failed(temp))
return failure();
// Parse backend config string.
TriangularSolveOptions opts;
if (!tensorflow::HumanReadableJsonToProto(backend_config.str(), &opts).ok())
return failure();
return handler(run_options, debug_options, *a, *b, *result, *temp,
opts.left_side(), opts.lower(), opts.unit_diagonal(),
opts.transpose_a());
}
LogicalResult TriangularSolve::operator()(
const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options, jitrt::StridedMemrefView a,
jitrt::StridedMemrefView b, jitrt::StridedMemrefView result,
jitrt::FlatMemrefView temp, bool left_side, bool lower, bool unit_diagonal,
TriangularSolveOptions::Transpose transpose_a) const {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
se::Stream* stream = run_options->stream();
se::DeviceMemoryBase a_data = GetDeviceAddress(a);
se::DeviceMemoryBase b_data = GetDeviceAddress(b);
se::DeviceMemoryBase result_data = GetDeviceAddress(result);
se::DeviceMemoryBase temp_data = GetDeviceAddress(temp);
// Triangular solve is in-place on 'b', so copy 'b' to the output if they
// aren't the same buffer.
if (b.data != result.data)
stream->ThenMemcpy(&result_data, b_data, b_data.size());
Shape b_shape = ToShape(b);
int64_t m = b_shape.dimensions(b_shape.rank() - 2);
int64_t n = b_shape.dimensions(b_shape.rank() - 1);
int64_t batch_size = std::accumulate(
b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64_t{1},
[](int64_t a, int64_t b) { return a * b; });
PrimitiveType elem_type = ToPrimitiveType(b.dtype);
int64_t elem_size = ShapeUtil::ByteSizeOfPrimitiveType(elem_type);
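// `a` is a square matrix: m x m when solving from the left and n x n when
// solving from the right; each `b` batch is an m x n matrix.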
int64_t a_batch_stride = left_side ? m * m * elem_size : n * n * elem_size;
int64_t b_batch_stride = m * n * elem_size;
using Side = se::blas::Side;
using Diagonal = se::blas::Diagonal;
using Transpose = se::blas::Transpose;
using UpperLower = se::blas::UpperLower;
// Convert custom call attributes to se::blas enums.
UpperLower uplo = lower ? UpperLower::kLower : UpperLower::kUpper;
Side side = left_side ? Side::kLeft : Side::kRight;
Diagonal diagonal = unit_diagonal ? Diagonal::kUnit : Diagonal::kNonUnit;
auto transpose = [&]() -> mlir::FailureOr<Transpose> {
switch (transpose_a) {
case TriangularSolveOptions::NO_TRANSPOSE:
return se::blas::Transpose::kNoTranspose;
case TriangularSolveOptions::TRANSPOSE:
return se::blas::Transpose::kTranspose;
case TriangularSolveOptions::ADJOINT:
return se::blas::Transpose::kConjugateTranspose;
default:
return failure();
}
}();
if (failed(transpose)) return failure();
auto st = RunTriangulatSolve(
a_data, result_data, temp_data, PtxOptsFromDebugOptions(*debug_options),
uplo, side, diagonal, *transpose, elem_type, batch_size, m, n,
a_batch_stride, b_batch_stride, stream);
if (!st.ok()) return failure();
return success();
#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
return failure();
#endif
}
// -------------------------------------------------------------------------- //
// Implements a JitRt custom call that forwards to the XLA custom call handler.
//
// Longer term, all XLA custom calls should probably be implemented directly as
// JitRt custom calls. However, for a smooth migration from Thunks to JitRt we
// have to seamlessly support all current XLA users.
namespace {
struct XlaCustomCall {
using Stream = se::gpu::GpuStreamHandle;
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
CustomCall::RemainingArgs args,
StringRef call_target_name, int32_t api_version,
StringRef backend_config) const;
static XlaCustomCall Handler() { return XlaCustomCall(); }
};
} // namespace
LogicalResult XlaCustomCall::operator()(
const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options, CustomCall::RemainingArgs args,
StringRef call_target_name, int32_t api_version,
StringRef backend_config) const {
// Pattern match the custom call to a few special cases, otherwise find the
// custom call handler registered with the runtime.
if (call_target_name == kTriangularSolveCallTarget)
return TriangularSolve::run(run_options, debug_options, args,
backend_config);
// Find the Xla custom call handler.
auto& platform_name = run_options->stream()->parent()->platform()->Name();
void* call_target = CustomCallTargetRegistry::Global()->Lookup(
call_target_name.str(), platform_name);
if (!call_target) return failure();
// Prepare pointers to buffers to pass to the Xla custom call handler.
llvm::SmallVector<void*> buffers;
for (unsigned i = 0; i < args.size(); ++i) {
auto memref = args.get<jitrt::FlatMemrefView>(i);
if (failed(memref)) return failure();
// We use zero-sized memrefs to represent holes in custom calls with a
// target arguments mapping (see `CustomCallTargetArgMapping`).
buffers.push_back(memref->size_in_bytes == 0 ? nullptr : memref->data);
}
// Original custom call API version that doesn't support returning status.
if (api_version == CustomCallApiVersion::API_VERSION_ORIGINAL) {
using XlaCustomCallType = void (*)(Stream, void**, const char*, size_t);
auto xla_call_target = reinterpret_cast<XlaCustomCallType>(call_target);
xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()),
buffers.data(), backend_config.data(),
backend_config.size());
return success();
}
// Xla Custom call API returning status.
if (api_version == CustomCallApiVersion::API_VERSION_STATUS_RETURNING) {
using XlaCustomCallType =
void (*)(Stream, void**, const char*, size_t, XlaCustomCallStatus*);
auto xla_call_target = reinterpret_cast<XlaCustomCallType>(call_target);
XlaCustomCallStatus custom_call_status;
xla_call_target(se::gpu::AsGpuStreamValue(run_options->stream()),
buffers.data(), backend_config.data(),
backend_config.size(), &custom_call_status);
if (auto message = CustomCallStatusGetMessage(&custom_call_status)) {
return failure();
} else {
return success();
}
}
return failure();
}
static bool CustomCall(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.Arg<jitrt::CustomCall::RemainingArgs>() // args
.Attr<StringRef>("call_target_name")
.Attr<int32_t>("api_version")
.Attr<StringRef>("backend_config")
.To<RuntimeChecks()>(XlaCustomCall::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// ------------------------------------------------------------------------- //
namespace {
struct AllReduce {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives,
CustomCall::RemainingArgs args, int32_t uid,
int64_t group_mode, int64_t op_id,
int64_t reduction_kind,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const;
static AllReduce Handler() { return AllReduce(); }
};
} // namespace
LogicalResult AllReduce::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives, CustomCall::RemainingArgs args,
int32_t uid, int64_t group_mode, int64_t op_id, int64_t reduction_kind,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
VLOG(3) << "Running AllReduce";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
replica_group_values);
if (failed(comm)) return comm;
auto device_buffers = GetDeviceBufferPairs(args);
if (failed(device_buffers)) return device_buffers;
auto executed = RunAllReduce(static_cast<ReductionKind>(reduction_kind),
*device_buffers, *stream, **comm);
if (!executed.ok()) return failure();
int32_t device_ordinal = stream->parent()->device_ordinal();
if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
return failure();
return success();
#else // XLA_ENABLE_XCCL
// NCCL disabled.
return failure();
#endif // XLA_ENABLE_XCCL
}
static bool AllReduce(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler =
CustomCall::Bind("xla.gpu.all_reduce")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtCollectiveSupport*>()
.RemainingArgs() // args
.Attr<int32_t>("uid")
.Attr<int64_t>("group_mode") // CollectiveOpGroupMode
.Attr<int64_t>("op_id")
.Attr<int64_t>("reduction_kind") // ReductionKind
.Attr<ArrayRef<int64_t>>("replica_group_offsets")
.Attr<ArrayRef<int64_t>>("replica_group_values")
.To<RuntimeChecks()>(AllReduce::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// ------------------------------------------------------------------------- //
namespace {
struct AllReduceStart {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtAsyncCollectiveSupport* async_collectives,
CustomCall::RemainingArgs args, int64_t group_mode,
int64_t op_id, int64_t reduction_kind,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values,
int32_t uid) const;
static AllReduceStart Handler() { return AllReduceStart(); }
};
} // namespace
LogicalResult AllReduceStart::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtAsyncCollectiveSupport* async_collectives,
CustomCall::RemainingArgs args, int64_t group_mode, int64_t op_id,
int64_t reduction_kind, ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values, int32_t uid) const {
#if XLA_ENABLE_XCCL
VLOG(3) << "Running AllReduceStart";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
replica_group_values);
if (failed(comm)) return comm;
auto device_buffers = GetDeviceBufferPairs(args);
if (failed(device_buffers)) return device_buffers;
// Wait until compute inputs are ready.
async_collectives->async_comm_stream()->ThenWaitFor(params.stream);
auto executed =
RunAllReduce(static_cast<ReductionKind>(reduction_kind), *device_buffers,
*async_collectives->async_comm_stream(), **comm);
if (!executed.ok()) return failure();
// Create an event on the async stream for the completion of the all-reduce.
se::Event done_event(async_collectives->async_comm_stream()->parent());
if (!done_event.Init()) return failure();
async_collectives->async_comm_stream()->ThenRecordEvent(&done_event);
if (failed(async_collectives->PushEvent(
uid, stream->parent()->device_ordinal(), std::move(done_event))))
return failure();
return success();
#else // XLA_ENABLE_XCCL
return failure(); // NCCL disabled.
#endif // XLA_ENABLE_XCCL
}
static bool AllReduceStart(runtime::KernelContext* ctx, void** args,
void** attrs) {
static auto* handler =
CustomCall::Bind("xla.gpu.all_reduce_start")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtAsyncCollectiveSupport*>()
.RemainingArgs() // args
.Attr<int64_t>("group_mode") // CollectiveOpGroupMode
.Attr<int64_t>("op_id")
.Attr<int64_t>("reduction_kind") // ReductionKind
.Attr<ArrayRef<int64_t>>("replica_group_offsets")
.Attr<ArrayRef<int64_t>>("replica_group_values")
.Attr<int32_t>("uid")
.To<RuntimeChecks()>(AllReduceStart::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// ------------------------------------------------------------------------- //
namespace {
struct AllReduceDone {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives,
JitRtAsyncCollectiveSupport* async_collectives,
CustomCall::RemainingArgs args, int32_t uid) const;
static AllReduceDone Handler() { return AllReduceDone(); }
};
} // namespace
LogicalResult AllReduceDone::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives,
JitRtAsyncCollectiveSupport* async_collectives,
CustomCall::RemainingArgs args, int32_t uid) const {
#if XLA_ENABLE_XCCL
VLOG(3) << "Running AllReduceDone";
se::Stream* stream = run_options->stream();
int32_t device_ordinal = stream->parent()->device_ordinal();
auto event = async_collectives->PopEvent(uid, device_ordinal);
if (failed(event)) return failure();
stream->ThenWaitFor(&*event);
if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
return failure();
return success();
#else // XLA_ENABLE_XCCL
return failure(); // NCCL disabled.
#endif // XLA_ENABLE_XCCL
}
static bool AllReduceDone(runtime::KernelContext* ctx, void** args,
void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.all_reduce_done")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtCollectiveSupport*>()
.UserData<JitRtAsyncCollectiveSupport*>()
.RemainingArgs() // args
.Attr<int32_t>("uid")
.To<RuntimeChecks()>(AllReduceDone::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
namespace {
struct ReduceScatter {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives,
CustomCall::RemainingArgs args, int32_t uid,
int64_t group_mode, int64_t op_id,
int64_t reduction_kind,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const;
static ReduceScatter Handler() { return ReduceScatter(); }
};
} // namespace
LogicalResult ReduceScatter::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives, CustomCall::RemainingArgs args,
int32_t uid, int64_t group_mode, int64_t op_id, int64_t reduction_kind,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
VLOG(3) << "Running ReduceScatter";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
replica_group_values);
if (failed(comm)) return comm;
auto device_buffers = GetDeviceBufferPairs(args);
if (failed(device_buffers)) return device_buffers;
auto executed = RunReduceScatter(static_cast<ReductionKind>(reduction_kind),
*device_buffers, *stream, **comm);
if (!executed.ok()) return failure();
int32_t device_ordinal = stream->parent()->device_ordinal();
if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
return failure();
return success();
#else // XLA_ENABLE_XCCL
  return failure(); // NCCL disabled.
#endif // XLA_ENABLE_XCCL
}
static bool ReduceScatter(runtime::KernelContext* ctx, void** args,
void** attrs) {
static auto* handler =
CustomCall::Bind("xla.gpu.reduce_scatter")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtCollectiveSupport*>()
.RemainingArgs() // args
.Attr<int32_t>("uid")
.Attr<int64_t>("group_mode") // CollectiveOpGroupMode
.Attr<int64_t>("op_id")
.Attr<int64_t>("reduction_kind") // ReductionKind
.Attr<ArrayRef<int64_t>>("replica_group_offsets")
.Attr<ArrayRef<int64_t>>("replica_group_values")
.To<RuntimeChecks()>(ReduceScatter::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
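// Implements the `xla.gpu.all_gather` custom call on top of NCCL.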
namespace {
struct AllGather {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives,
CustomCall::RemainingArgs args, int32_t uid,
int64_t group_mode, int64_t op_id,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const;
static AllGather Handler() { return AllGather(); }
};
} // namespace
LogicalResult AllGather::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives, CustomCall::RemainingArgs args,
int32_t uid, int64_t group_mode, int64_t op_id,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
VLOG(3) << "Running AllGather";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
replica_group_values);
if (failed(comm)) return comm;
auto device_buffers = GetDeviceBufferPairs(args);
if (failed(device_buffers)) return device_buffers;
if (!RunAllGather(*device_buffers, *stream, **comm).ok()) return failure();
int32_t device_ordinal = stream->parent()->device_ordinal();
if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
return failure();
return success();
#else // XLA_ENABLE_XCCL
  return failure(); // NCCL disabled.
#endif // XLA_ENABLE_XCCL
}
static bool AllGather(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler =
CustomCall::Bind("xla.gpu.all_gather")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtCollectiveSupport*>()
.RemainingArgs() // args
.Attr<int32_t>("uid")
.Attr<int64_t>("group_mode") // CollectiveOpGroupMode
.Attr<int64_t>("op_id")
.Attr<ArrayRef<int64_t>>("replica_group_offsets")
.Attr<ArrayRef<int64_t>>("replica_group_values")
.To<RuntimeChecks()>(AllGather::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
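// Implements the `xla.gpu.all_to_all` custom call on top of NCCL. The
// `has_split_dimension` attribute is forwarded to RunAllToAll.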
namespace {
struct AllToAll {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives,
CustomCall::RemainingArgs args, int32_t uid,
int64_t group_mode, bool has_split_dimension,
int64_t op_id,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const;
static AllToAll Handler() { return AllToAll(); }
};
} // namespace
LogicalResult AllToAll::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives, CustomCall::RemainingArgs args,
int32_t uid, int64_t group_mode, bool has_split_dimension, int64_t op_id,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values) const {
#if XLA_ENABLE_XCCL
VLOG(3) << "Running AllToAll";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
replica_group_values);
if (failed(comm)) return comm;
auto device_buffers = GetDeviceBufferPairs(args);
if (failed(device_buffers)) return device_buffers;
if (!RunAllToAll(has_split_dimension, *device_buffers, *stream, **comm).ok())
return failure();
int32_t device_ordinal = stream->parent()->device_ordinal();
if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
return failure();
return success();
#else // XLA_ENABLE_XCCL
  return failure(); // NCCL disabled.
#endif // XLA_ENABLE_XCCL
}
static bool AllToAll(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler =
CustomCall::Bind("xla.gpu.all_to_all")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtCollectiveSupport*>()
.RemainingArgs() // args
.Attr<int32_t>("uid")
.Attr<int64_t>("group_mode") // CollectiveOpGroupMode
.Attr<bool>("has_split_dimension")
.Attr<int64_t>("op_id")
.Attr<ArrayRef<int64_t>>("replica_group_offsets")
.Attr<ArrayRef<int64_t>>("replica_group_values")
.To<RuntimeChecks()>(AllToAll::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
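// Implements the `xla.gpu.collective_permute` custom call: each participating
// device sends its buffer to `target_peers` and receives from `source_peers`,
// keyed by its replica or partition id depending on the group mode.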
namespace {
struct CollectivePermute {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives,
CustomCall::RemainingArgs args, int32_t uid,
int64_t group_mode, int64_t op_id,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values,
ArrayRef<int64_t> source_peers,
ArrayRef<int64_t> target_peers) const;
static CollectivePermute Handler() { return CollectivePermute(); }
};
} // namespace
LogicalResult CollectivePermute::operator()(
const ServiceExecutableRunOptions* run_options,
JitRtCollectiveSupport* collectives, CustomCall::RemainingArgs args,
int32_t uid, int64_t group_mode, int64_t op_id,
ArrayRef<int64_t> replica_group_offsets,
ArrayRef<int64_t> replica_group_values, ArrayRef<int64_t> source_peers,
ArrayRef<int64_t> target_peers) const {
#if XLA_ENABLE_XCCL
VLOG(3) << "Running CollectivePermute";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
auto comm = GetNcclComm(params, group_mode, op_id, replica_group_offsets,
replica_group_values);
if (failed(comm)) return comm;
auto device_buffers = GetDeviceBufferPairs(args);
if (failed(device_buffers)) return device_buffers;
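  // Collective permute operates on a single source/destination buffer pair.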
if (device_buffers->size() != 1) return failure();
StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
if (!global_device_id.ok()) return failure();
StatusOr<DeviceAssignment::LogicalID> current_logical_id =
params.device_assn->LogicalIdForDevice(global_device_id.value());
if (!current_logical_id.ok()) return failure();
const int64_t current_id = static_cast<CollectiveOpGroupMode>(group_mode) ==
CollectiveOpGroupMode::kCrossReplica
? current_logical_id.value().replica_id
: current_logical_id.value().computation_id;
std::string device_string = NcclCollectiveThunk::GetDeviceString(params);
NcclCollectivePermuteConfig::IdToSourceTargetMap id_to_source_target;
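  // Invert the parallel `source_peers`/`target_peers` lists into a map from
  // device id to its source and target peers.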
  for (size_t i = 0; i < source_peers.size(); ++i) {
id_to_source_target.insert({target_peers[i], {}}).first->second.source =
source_peers[i];
id_to_source_target.insert({source_peers[i], {}}).first->second.target =
target_peers[i];
}
const NcclCollectivePermuteConfig::SourceTargetMapEntry source_target =
NcclCollectivePermuteConfig::GetSourceTarget(id_to_source_target,
current_id);
auto executed =
RunCollectivePermute(source_target, (*device_buffers)[0], *stream, **comm,
device_string, current_id);
if (!executed.ok()) return failure();
int32_t device_ordinal = stream->parent()->device_ordinal();
if (!collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, stream).ok())
return failure();
return success();
#else // XLA_ENABLE_XCCL
  return failure(); // NCCL disabled.
#endif // XLA_ENABLE_XCCL
}
static bool CollectivePermute(runtime::KernelContext* ctx, void** args,
void** attrs) {
static auto* handler =
CustomCall::Bind("xla.gpu.collective_permute")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<JitRtCollectiveSupport*>()
.RemainingArgs() // args
.Attr<int32_t>("uid")
.Attr<int64_t>("group_mode") // CollectiveOpGroupMode
.Attr<int64_t>("op_id")
.Attr<ArrayRef<int64_t>>("replica_group_offsets")
.Attr<ArrayRef<int64_t>>("replica_group_values")
.Attr<ArrayRef<int64_t>>("source_peers")
.Attr<ArrayRef<int64_t>>("target_peers")
.To<RuntimeChecks()>(CollectivePermute::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
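// Implements the `xla.gpu.replica_id` custom call: writes the replica id of
// the current device into a 32-bit result buffer.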
namespace {
struct ReplicaId {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
jitrt::FlatMemrefView result) const;
static ReplicaId Handler() { return ReplicaId(); }
};
} // namespace
LogicalResult ReplicaId::operator()(
const ServiceExecutableRunOptions* run_options,
jitrt::FlatMemrefView result) const {
VLOG(3) << "Running ReplicaId";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
if (!global_device_id.ok()) return failure();
StatusOr<DeviceAssignment::LogicalID> logical_id =
params.device_assn->LogicalIdForDevice(global_device_id.value());
if (!logical_id.ok()) return failure();
se::DeviceMemoryBase result_data = GetDeviceAddress(result);
  stream->ThenMemset32(&result_data, logical_id.value().replica_id,
                       /*size=*/4);
return success();
}
static bool ReplicaId(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.replica_id")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<jitrt::FlatMemrefView>() // result
.To<RuntimeChecks()>(ReplicaId::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
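// Implements the `xla.gpu.partition_id` custom call: writes the partition
// (computation) id of the current device into a 32-bit result buffer.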
namespace {
struct PartitionId {
LLVM_ATTRIBUTE_ALWAYS_INLINE
LogicalResult operator()(const ServiceExecutableRunOptions* run_options,
jitrt::FlatMemrefView result) const;
static PartitionId Handler() { return PartitionId(); }
};
} // namespace
LogicalResult PartitionId::operator()(
const ServiceExecutableRunOptions* run_options,
jitrt::FlatMemrefView result) const {
VLOG(3) << "Running PartitionId";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
StatusOr<GlobalDeviceId> global_device_id = params.GetGlobalDeviceId();
if (!global_device_id.ok()) return failure();
StatusOr<DeviceAssignment::LogicalID> logical_id =
params.device_assn->LogicalIdForDevice(global_device_id.value());
if (!logical_id.ok()) return failure();
se::DeviceMemoryBase result_data = GetDeviceAddress(result);
  stream->ThenMemset32(&result_data, logical_id.value().computation_id,
                       /*size=*/4);
return success();
}
static bool PartitionId(runtime::KernelContext* ctx, void** args,
void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.partition_id")
.UserData<const ServiceExecutableRunOptions*>()
.Arg<jitrt::FlatMemrefView>() // result
.To<RuntimeChecks()>(PartitionId::Handler())
.release();
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
// -------------------------------------------------------------------------- //
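// Returns the library of direct custom calls backing the XLA GPU runtime
// operations defined in this file.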
DirectCustomCallLibrary JitRtGpuCustomCalls() {
DirectCustomCallLibrary lib;
lib.Insert("xla.gpu.fft", &xla::gpu::Fft);
lib.Insert("xla.gpu.cholesky", &xla::gpu::Cholesky);
lib.Insert("xla.gpu.collective_permute", &xla::gpu::CollectivePermute);
lib.Insert("xla.gpu.func.launch", &xla::gpu::LaunchFunc);
lib.Insert("xla.gpu.gemm", &xla::gpu::Gemm);
auto conv = [](StringRef name) { return ("xla.gpu.conv." + name).str(); };
lib.Insert(conv("forward"), &ConvFn<CudnnConvKind::kForward>);
lib.Insert(conv("backward.input"), &ConvFn<CudnnConvKind::kBackwardInput>);
lib.Insert(conv("backward.filter"), &ConvFn<CudnnConvKind::kBackwardFilter>);
lib.Insert(conv("forward.fused"),
&ConvFusedFn<CudnnConvKind::kForwardActivation>);
lib.Insert(conv("forward.fused.side_input"),
&ConvFuseSideInputdFn<CudnnConvKind::kForwardActivation>);
lib.Insert("xla.gpu.memcpy.d2d", &MemcpyFn<MemcpyDirection::kDeviceToDevice>);
lib.Insert("xla.gpu.memcpy.h2d", &MemcpyFn<MemcpyDirection::kHostToDevice>);
lib.Insert("xla.gpu.memcpy.d2h", &MemcpyFn<MemcpyDirection::kDeviceToHost>);
lib.Insert("xla.gpu.memset", &MemsetFn);
lib.Insert("xla.gpu.infeed", &xla::gpu::Infeed);
lib.Insert("xla.gpu.outfeed", &xla::gpu::Outfeed);
lib.Insert("xla.gpu.custom_call", &xla::gpu::CustomCall);
  // Collective operations.
  lib.Insert("xla.gpu.all_gather", &xla::gpu::AllGather);
  lib.Insert("xla.gpu.all_reduce", &xla::gpu::AllReduce);
  lib.Insert("xla.gpu.all_reduce_done", &xla::gpu::AllReduceDone);
  lib.Insert("xla.gpu.all_reduce_start", &xla::gpu::AllReduceStart);
  lib.Insert("xla.gpu.all_to_all", &xla::gpu::AllToAll);
  lib.Insert("xla.gpu.collective_permute", &xla::gpu::CollectivePermute);
  lib.Insert("xla.gpu.partition_id", &xla::gpu::PartitionId);
  lib.Insert("xla.gpu.reduce_scatter", &xla::gpu::ReduceScatter);
  lib.Insert("xla.gpu.replica_id", &xla::gpu::ReplicaId);
return lib;
}
} // namespace gpu
} // namespace xla