/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/operators/top_k.h"
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include "caffe2/core/context.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/top_k_heap_selection.cuh"
#include "caffe2/operators/top_k_radix_selection.cuh"
namespace caffe2 {
// Converts a [outerSize, k] matrix of row-local indices into global
// (linearized) indices into the original [outerSize, innerSize] matrix.
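//
// For illustration (a sketch, not from the original source): with
// outerSize = 2, innerSize = 5, and k = 3, the row-local indices
//   [[1, 4, 0],
//    [2, 3, 0]]
// become the global indices
//   [[1, 4, 0],
//    [7, 8, 5]]
// since every entry of row r is offset by r * innerSize.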
template <typename Index>
__global__ void linearizeRowIndices(
const Index* in,
Index* out,
int outerSize,
int innerSize,
int k) {
if (blockIdx.x < outerSize) {
in += (Index)blockIdx.x * k;
out += (Index)blockIdx.x * k;
auto indexOffset = (Index)blockIdx.x * innerSize;
for (int i = threadIdx.x; i < k; i += blockDim.x) {
out[i] = in[i] + indexOffset;
}
}
}
template <>
class TopKOp<float, CUDAContext> : public Operator<CUDAContext> {
public:
TopKOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws),
OP_SINGLE_ARG(int, "k", k_, -1) {}
bool RunOnDevice() override;
private:
int k_;
};
bool TopKOp<float, CUDAContext>::RunOnDevice() {
auto& input = Input(0);
auto* values = Output(0);
auto* indices = Output(1);
auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
vector<TIndex> in_dims = input.dims();
CAFFE_ENFORCE(
in_dims.back() >= k_, "k argument should not be greater than the last dim");
vector<TIndex> out_dims = in_dims;
out_dims.back() = k_;
// Get the batch size
size_t outerSize = 1;
for (size_t i = 0; i + 1 < in_dims.size(); ++i) {
outerSize *= in_dims[i];
}
values->Resize(out_dims);
indices->Resize(out_dims);
if (flatten_indices) {
flatten_indices->Resize(outerSize * k_);
}
// Right now, the top-k operator only supports max-k
constexpr bool kDir = true;
if (k_ <= 512) {
// heap selection is possible
constexpr int kBlockSize = 256;
int numWarps = kBlockSize / kWarpSize;
auto grid = outerSize;
auto block = kBlockSize;
#define RUN_HEAP(HEAP_SIZE) \
do { \
int smem = numWarps * HEAP_SIZE * (sizeof(float) + sizeof(TIndex)); \
\
selectRowsViaHeap<float, TIndex, TIndex, kBlockSize, HEAP_SIZE, kDir> \
<<<grid, block, smem, context_.cuda_stream()>>>( \
input.data<float>(), \
values->mutable_data<float>(), \
indices->mutable_data<TIndex>(), \
kDir ? -std::numeric_limits<float>::infinity() \
: std::numeric_limits<float>::infinity(), \
kDir ? -std::numeric_limits<TIndex>::max() \
: std::numeric_limits<TIndex>::max(), \
outerSize, \
in_dims.back(), \
k_); \
} while (false)
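// A worked shared-memory sizing example (assuming 64-bit TIndex, as in
// Caffe2): kBlockSize = 256 and kWarpSize = 32 give numWarps = 8, and each
// warp keeps a HEAP_SIZE-entry (float, TIndex) heap. HEAP_SIZE = 32 thus
// needs 8 * 32 * (4 + 8) = 3072 bytes of shared memory, and the largest
// case, HEAP_SIZE = 512, needs 8 * 512 * 12 = 49152 bytes (48 KB), the
// typical per-block limit.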
if (k_ <= 32) {
RUN_HEAP(32);
} else if (k_ <= 128) {
RUN_HEAP(128);
} else {
RUN_HEAP(512);
}
#undef RUN_HEAP
} else {
// k is too large, use radix selection instead
auto grid = outerSize;
auto block = std::min(
math::roundUp((int)in_dims.back(), kWarpSize), CAFFE_CUDA_NUM_THREADS);
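// For example (assuming CAFFE_CUDA_NUM_THREADS = 128, its value in
// Caffe2): a last dimension of 70 rounds up to 96 threads, while a last
// dimension of 1000 rounds up to 1024 and is then capped at 128.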
// Radix selection required
gatherTopK<float, kDir, TIndex><<<grid, block, 0, context_.cuda_stream()>>>(
input.data<float>(),
in_dims.back(),
k_,
outerSize,
values->mutable_data<float>(),
indices->mutable_data<TIndex>());
// Unfortunately the output is not currently sorted, and there is
// no batch sorting utility available. Iterate over all of the
// slices and sort them in-place using Thrust.
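// To illustrate the call below (a hedged sketch, not from the original
// source): sorting slice values [3.0, 9.0, 5.0] with paired indices
// [0, 1, 2] under thrust::greater<float> yields values [9.0, 5.0, 3.0]
// and reorders the indices to [1, 2, 0].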
for (size_t slice = 0; slice < outerSize; ++slice) {
thrust::sort_by_key(
thrust::cuda::par.on(context_.cuda_stream()),
values->mutable_data<float>() + slice * k_,
values->mutable_data<float>() + slice * k_ + k_,
indices->mutable_data<TIndex>() + slice * k_,
thrust::greater<float>());
}
}
// Now that the indices have been written, linearize them if the
// flattened output was requested.
if (flatten_indices) {
linearizeRowIndices<TIndex>
<<<outerSize, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
indices->mutable_data<TIndex>(),
flatten_indices->mutable_data<TIndex>(),
outerSize,
in_dims.back(),
k_);
}
return true;
}
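// A minimal usage sketch (a hedged example, not from the original source;
// blob and output names are hypothetical):
//
//   Workspace ws;
//   // ... feed a float CUDA tensor "X" of shape [4, 10] into ws ...
//   OperatorDef def;
//   def.set_type("TopK");
//   def.add_input("X");
//   def.add_output("Values");
//   def.add_output("Indices");
//   def.mutable_device_option()->set_device_type(CUDA);
//   auto* arg = def.add_arg();
//   arg->set_name("k");
//   arg->set_i(3);
//   ws.RunOperatorOnce(def);  // Values and Indices now have shape [4, 3]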
REGISTER_CUDA_OPERATOR(TopK, TopKOp<float, CUDAContext>);
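// Scatters each selected value back to its position in the original
// (zero-initialized) input layout: element i belongs to row i / k and
// lands at column indices[i] of that row.
//
// For illustration (a sketch, not from the original source): with k = 2
// and original_last_dim = 5, row 0's values [9, 5] paired with indices
// [1, 4] write output[0 * 5 + 1] = 9 and output[0 * 5 + 4] = 5.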
__global__ void fillValuesWithIndicesKernel(
const float* values,
const TIndex* indices,
const TIndex k,
const TIndex original_last_dim,
const TIndex length,
float* output) {
CUDA_1D_KERNEL_LOOP(i, length) {
int first_dim = i / k;
TIndex idx = original_last_dim * first_dim + indices[i];
output[idx] = values[i];
}
}
template <>
bool TopKGradientOp<float, CUDAContext>::RunOnDevice() {
auto& values = Input(0);
auto& indices = Input(1);
auto& original_input = Input(2);
vector<TIndex> in_dims = values.dims();
// Linearize input tensor except for last dimension
// e.g. [3, 4, 5] -> [12, 5]
// [5] -> [5]
TIndex flatten_shape[] = {size_to_dim_(in_dims.size() - 1, in_dims),
in_dims[in_dims.size() - 1]};
vector<TIndex> original_dims = original_input.dims();
auto* output = Output(0);
output->Resize(original_dims);
float* output_data = output->mutable_data<float>();
math::Set<float, CUDAContext>(
output->size(), float(0), output_data, &context_);
int length = flatten_shape[0] * flatten_shape[1];
if (length == 0) { // for empty batch
return true;
}
int num_threads = std::min(CAFFE_CUDA_NUM_THREADS, length);
int blocks = math::divUp(length, num_threads);
fillValuesWithIndicesKernel<<<
blocks,
num_threads,
0,
context_.cuda_stream()>>>(
values.data<float>(),
indices.data<TIndex>(),
flatten_shape[1],
original_dims.back(),
length,
output_data);
return true;
}
REGISTER_CUDA_OPERATOR(TopKGradient, TopKGradientOp<float, CUDAContext>);
} // namespace caffe2