/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/precompiled_kernels.h"
#include <string>
#include "absl/base/call_once.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#if TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
namespace stream_executor {
namespace gpu {

extern void rocm_MakeBatchPointers(void* stream, char* base, int stride, int n,
                                   void** ptrs_out);

}  // namespace gpu
}  // namespace stream_executor
#endif
namespace xla {
namespace gpu {
namespace {
// GPU kernel to populate an array of pointers:
//
//   [base + stride * i for i in range(n)].
//
// Generated from the following CUDA code.
//
// extern "C" {
// __global__ void __xla_MakeBatchPointers(char* base, int stride,
//                                         int n, void** ptrs_out) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= n) return;
//   ptrs_out[idx] = base + idx * stride;
// }
// }
constexpr const char* kMakeBatchPointersPtx = R"(
.version 4.2
.target sm_35
.address_size 64
.visible .entry __xla_MakeBatchPointers(
        .param .u64 __xla_MakeBatchPointers_param_0,
        .param .u32 __xla_MakeBatchPointers_param_1,
        .param .u32 __xla_MakeBatchPointers_param_2,
        .param .u64 __xla_MakeBatchPointers_param_3
)
{
        .reg .pred %p<2>;
        .reg .b32 %r<8>;
        .reg .b64 %rd<8>;

        ld.param.u32 %r2, [__xla_MakeBatchPointers_param_2];
        mov.u32 %r3, %tid.x;
        mov.u32 %r4, %ctaid.x;
        mov.u32 %r5, %ntid.x;
        mad.lo.s32 %r6, %r4, %r5, %r3;
        setp.ge.s32 %p1, %r6, %r2;
        @%p1 bra LBB0_2;

        ld.param.u64 %rd3, [__xla_MakeBatchPointers_param_0];
        ld.param.u64 %rd4, [__xla_MakeBatchPointers_param_3];
        cvta.to.global.u64 %rd5, %rd4;
        ld.param.u32 %r1, [__xla_MakeBatchPointers_param_1];
        mul.wide.s32 %rd6, %r6, 8;
        add.s64 %rd1, %rd5, %rd6;
        mul.lo.s32 %r7, %r6, %r1;
        cvt.s64.s32 %rd7, %r7;
        add.s64 %rd2, %rd3, %rd7;
        st.global.u64 [%rd1], %rd2;

LBB0_2:
        ret;
}
)";
// Lazily compiles the given PTX kernel, once per StreamExecutor.
//
// Thread-safe.
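//
// Illustrative usage (a sketch; the kernel name, PTX constant, and argument
// types below are placeholders; see MakeBatchPointers further down for the
// real call site):
//
//   static auto* kernel = new LazyKernel<se::DeviceMemoryBase, int>(
//       "__xla_MyKernel", kMyKernelPtx, asm_opts);
//   TF_ASSIGN_OR_RETURN(auto* typed_kernel, kernel->Get(stream->parent()));
//   stream->ThenLaunch(se::ThreadDim(128, 1, 1), se::BlockDim(1, 1, 1),
//                      *typed_kernel, some_device_buffer, some_int_arg);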
template <typename... KernelArgs>
class LazyKernel {
 public:
  LazyKernel(absl::string_view kernel_name, const char* ptx,
             const se::GpuAsmOpts& asm_opts)
      : kernel_name_(kernel_name), ptx_(ptx), asm_opts_(asm_opts) {}

  StatusOr<se::TypedKernel<KernelArgs...>*> Get(
      se::StreamExecutor* stream_exec) {
    absl::MutexLock lock(&mu_);

    // emplace() only inserts if there is no entry for stream_exec yet;
    // result.second tells us whether we need to compile the kernel now.
    auto result = kernels_.emplace(stream_exec, nullptr);
    if (result.second) {
      absl::Span<const uint8_t> compiled_ptx;
      StatusOr<absl::Span<const uint8_t>> compiled_ptx_or =
          se::CompileGpuAsmOrGetCached(stream_exec->device_ordinal(), ptx_,
                                       asm_opts_);
      if (compiled_ptx_or.ok()) {
        compiled_ptx = compiled_ptx_or.ConsumeValueOrDie();
      } else {
        static absl::once_flag logged_once;
        absl::call_once(logged_once, [&]() {
          LOG(WARNING)
              << compiled_ptx_or.status().ToString()
              << "\nRelying on driver to perform ptx compilation. "
              << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
              << "or modifying $PATH can be used to set the location of ptxas."
              << "\nThis message will only be logged once.";
        });
      }

      auto kernel = stream_exec->CreateTypedKernel<KernelArgs...>(
          kernel_name_, ptx_, compiled_ptx);
      if (kernel.ok()) {
        result.first->second = *std::move(kernel);
      } else {
        kernels_.erase(result.first);
        return kernel.status();
      }
    }
    return result.first->second.get();
  }

 private:
  std::string kernel_name_;
  const char* ptx_;
  se::GpuAsmOpts asm_opts_;

  absl::Mutex mu_;

  // A map keyed on StreamExecutor* is ok because StreamExecutors are never
  // destroyed.
  absl::flat_hash_map<se::StreamExecutor*,
                      std::unique_ptr<se::TypedKernel<KernelArgs...>>>
      kernels_ ABSL_GUARDED_BY(mu_);
};
} // anonymous namespace
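
// Writes the n pointers base_ptr + i * stride_bytes (for i in [0, n)) into
// ptrs_out, entirely on the device: via the externally defined
// rocm_MakeBatchPointers helper on ROCm, and via the PTX kernel above
// otherwise (CUDA).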
Status MakeBatchPointers(se::Stream* stream, const se::GpuAsmOpts& asm_opts,
                         se::DeviceMemoryBase base_ptr, int stride_bytes,
                         int n, se::DeviceMemoryBase ptrs_out) {
#if TENSORFLOW_USE_ROCM
  stream_executor::gpu::rocm_MakeBatchPointers(
      se::gpu::AsGpuStreamValue(stream),
      reinterpret_cast<char*>(base_ptr.opaque()), stride_bytes, n,
      reinterpret_cast<void**>(ptrs_out.opaque()));
#else
  static auto* lazy_kernel =
      new LazyKernel<se::DeviceMemoryBase /*base_ptr*/, int /*stride_bytes*/,
                     int /*n*/, se::DeviceMemoryBase /*ptrs_out*/>(
          "__xla_MakeBatchPointers", kMakeBatchPointersPtx, asm_opts);

  TF_ASSIGN_OR_RETURN(auto kernel, lazy_kernel->Get(stream->parent()));

  // Launch one thread per output pointer.  The block count is rounded up, and
  // threads with idx >= n simply return inside the kernel.
  constexpr int kThreads = 128;
  stream->ThenLaunch(se::ThreadDim(kThreads, 1, 1),
                     se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), *kernel,
                     base_ptr, stride_bytes, n, ptrs_out);
#endif
  return Status::OK();
}
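
// Illustrative call site (a sketch; the buffer names and sizes below are
// hypothetical):
//
//   // Given a buffer holding `batch` matrices of `matrix_bytes` each, build
//   // the array of per-matrix device pointers that batched routines expect.
//   TF_RETURN_IF_ERROR(MakeBatchPointers(stream, asm_opts, matrices_buffer,
//                                        /*stride_bytes=*/matrix_bytes,
//                                        /*n=*/batch, ptrs_buffer));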
} // namespace gpu
} // namespace xla