#include "THC.h"
#include "THCReduceApplyUtils.cuh"
#include "THCTensorCopy.h"
#include "THCTensorMath.h"
#include "THCAsmUtils.cuh"
#include "THCScanUtils.cuh"
#include "THCTensorTypeUtils.cuh"
#include <algorithm> // for std::min
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif
// Converts a float to an integer representation with the same
// sorting; i.e., for floats f1, f2:
// if f1 < f2 then convert(f1) < convert(f2)
// We use this to enable radix selection of floating-point values.
// This also gives a relative order for NaNs, but that's ok, as they
// will all be adjacent
struct FloatToSortedInt {
inline __host__ __device__ FloatToSortedInt() {}
inline __device__ unsigned int convert(float v) const {
unsigned int x = __float_as_int(v);
unsigned int mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (x ^ mask);
}
inline __device__ float deconvert(unsigned int v) const {
unsigned int mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
}
};
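// Worked example of the order-preserving mapping (purely illustrative):
//   convert(-1.0f): bits 0xbf800000, sign set   -> flip all bits -> 0x407fffff
//   convert(-0.0f): bits 0x80000000, sign set   -> flip all bits -> 0x7fffffff
//   convert( 0.0f): bits 0x00000000, sign clear -> flip sign bit -> 0x80000000
//   convert( 1.0f): bits 0x3f800000, sign clear -> flip sign bit -> 0xbf800000
// The converted keys increase as the floats increase (-0.0f and +0.0f compare
// equal as floats but receive distinct, adjacent keys), and deconvert()
// inverts the mapping exactly.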
// This function counts the distribution of all input values in a
// slice we are selecting by radix digit at `radixDigitPos`, but only
// those that pass the filter `((v & desiredMask) == desired)`.
// This produces and broadcasts the seen counts for a single block only.
// `smem` must have at least `RadixSize` elements.
template <typename DataType, typename BitDataType,
typename IndexType, typename CountType,
typename RadixConverter, int RadixSize, int RadixBits>
__device__ void countRadixUsingMask(const RadixConverter& conv,
CountType counts[RadixSize],
CountType* smem,
BitDataType desired,
BitDataType desiredMask,
int radixDigitPos,
IndexType sliceSize,
IndexType withinSliceStride,
DataType* data) {
// Clear out per-thread counts from a previous round
#pragma unroll
for (int i = 0; i < RadixSize; ++i) {
counts[i] = 0;
}
if (threadIdx.x < RadixSize) {
smem[threadIdx.x] = 0;
}
__syncthreads();
// Scan over all the data. Upon a read, the warp accumulates
// counts for each digit in the radix using warp voting.
for (IndexType i = threadIdx.x; i < sliceSize; i += blockDim.x) {
BitDataType val = conv.convert(doLdg(&data[i * withinSliceStride]));
bool hasVal = ((val & desiredMask) == desired);
unsigned int digitInRadix = getBitfield(val, radixDigitPos, RadixBits);
#pragma unroll
for (unsigned int j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digitInRadix == j);
counts[j] += __popc(__ballot(vote));
}
}
// Now, for each warp, sum values
if (getLaneId() == 0) {
#pragma unroll
for (unsigned int i = 0; i < RadixSize; ++i) {
atomicAdd(&smem[i], counts[i]);
}
}
__syncthreads();
// For each thread, read in the total counts
#pragma unroll
for (unsigned int i = 0; i < RadixSize; ++i) {
counts[i] = smem[i];
}
__syncthreads();
}
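// For example (illustrative), with RADIX_BITS == 2 and radixDigitPos == 30,
// a converted value whose two most significant bits are 0b10 contributes to
// counts[2], but only if it also satisfies (val & desiredMask) == desired.
// After the atomicAdd reduction, every thread in the block holds identical
// totals in counts[0..RADIX_SIZE-1] for the whole slice.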
// Over what radix we are selecting values
#define RADIX_BITS 2 // digits are base-(2 ^ RADIX_BITS)
#define RADIX_SIZE 4 // 2 ^ RADIX_BITS
#define RADIX_MASK (RADIX_SIZE - 1)
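// With RADIX_BITS == 2, a 32-bit key is examined in 16 passes
// (digitPos = 30, 28, ..., 0), each pass classifying the surviving
// candidates into RADIX_SIZE == 4 buckets.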
// This finds the unique value `v` that matches the pattern
// ((v & desiredMask) == desired) in our sorted int format
template <typename DataType, typename IndexType, typename RadixConverter>
__device__ float findPattern(const RadixConverter& conv,
DataType* smem,
DataType* data,
IndexType sliceSize,
IndexType withinSliceStride,
unsigned int desired,
unsigned int desiredMask) {
if (threadIdx.x < 32) {
smem[threadIdx.x] = (DataType) 0;
}
__syncthreads();
// All threads participate in the loop, in order to sync on the flag
IndexType numIterations = THCRoundUp(sliceSize, (IndexType) blockDim.x);
for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
bool inRange = (i < sliceSize);
DataType v = inRange ? doLdg(&data[i * withinSliceStride]) : (DataType) 0;
if (inRange && ((conv.convert(v) & desiredMask) == desired)) {
// There should not be conflicts if we are using findPattern,
// since the result is unique
smem[0] = (DataType) 1;
smem[1] = v; // can't use val as the flag, since it could be 0
}
__syncthreads();
DataType found = smem[0];
DataType val = smem[1];
__syncthreads();
// Check to see if a thread found the value
if (found != (DataType) 0) {
// all threads return this value
return val;
}
}
// should not get here
assert(false);
return (DataType) 0;
}
// Returns the top-Kth element found in the data using radix selection
template <typename DataType, typename BitDataType, typename IndexType,
typename RadixConverter, bool Order>
__device__ void radixSelect(const RadixConverter& conv,
DataType* data,
IndexType k,
IndexType sliceSize,
IndexType withinSliceStride,
int* smem,
DataType* topK) {
// Per-thread buckets into which we accumulate digit counts in our
// radix
int counts[RADIX_SIZE];
// We only consider elements x such that (x & desiredMask) == desired
// Initially, we consider all elements of the array, so the above
// statement is true regardless of input.
unsigned int desired = 0;
unsigned int desiredMask = 0;
// We are looking for the top kToFind-th element when iterating over
// digits; this count gets reduced by elimination when counting
// successive digits
int kToFind = k;
// We start at the most significant digit in our radix, scanning
// through to the least significant digit
#pragma unroll
for (int digitPos = sizeof(BitDataType) * 8 - RADIX_BITS;
digitPos >= 0;
digitPos -= RADIX_BITS) {
// Count radix distribution for the current position and reduce
// across all threads
countRadixUsingMask<DataType, BitDataType,
IndexType, int, RadixConverter,
RADIX_SIZE, RADIX_BITS>(
conv, counts, smem,
desired, desiredMask, digitPos,
sliceSize, withinSliceStride, data);
// All threads participate in the comparisons below to know the
// final result
#define CHECK_RADIX(i) \
int count = counts[i]; \
\
/* All threads have the same value in counts here, so all */ \
/* threads will return from the function. */ \
if (count == 1 && kToFind == 1) { \
/* There is a unique answer. */ \
desired = setBitfield(desired, i, digitPos, RADIX_BITS); \
desiredMask = \
setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS); \
\
/* The answer is now the unique element v such that: */ \
/* (v & desiredMask) == desired */ \
/* However, we do not yet know what the actual element is. We */ \
/* need to perform a search through the data to find the */ \
/* element that matches this pattern. */ \
*topK = findPattern<DataType, IndexType, RadixConverter>( \
conv, (float*) smem, data, sliceSize, \
withinSliceStride, desired, desiredMask); \
return; \
} \
\
if (count >= kToFind) { \
desired = setBitfield(desired, i, digitPos, RADIX_BITS); \
desiredMask = \
setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS); \
\
/* The top-Kth element v must now be one such that: */ \
/* ((v & desiredMask) == desired) */ \
/* but we haven't narrowed it down; we must check the next */ \
/* least-significant digit */ \
break; \
} \
\
kToFind -= count;
if (Order) {
// Process in descending order
#pragma unroll
for (int i = RADIX_SIZE - 1; i >= 0; --i) {
CHECK_RADIX(i);
}
} else {
// Process in ascending order
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i) {
CHECK_RADIX(i);
}
}
#undef CHECK_RADIX
} // end digitPos for
// There is no unique result, but there is a non-unique result
// matching `desired` exactly
*topK = conv.deconvert(desired);
}
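// Illustrative trace (descending order, k == 1): on the first pass the
// highest non-empty bucket at digitPos == 30 is chosen; `desired` and
// `desiredMask` are updated so that only values with those top two bits
// survive into the next pass at digitPos == 28. Whenever the chosen bucket
// holds exactly one element and kToFind == 1, findPattern() retrieves it
// directly; otherwise the loop keeps narrowing until digitPos reaches 0,
// at which point `desired` fully identifies the answer in converted form.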
template <typename IndexType, int Dim, bool Order>
__global__ void gatherTopK(TensorInfo<float, IndexType> input,
IndexType inputSliceSize,
IndexType outputSliceSize, // aka `k`
IndexType numInputSlices,
IndexType inputWithinSliceStride,
TensorInfo<float, IndexType> topK,
IndexType numTopKSlices,
IndexType topKWithinSliceStride,
TensorInfo<long, IndexType> indices,
IndexType indicesWithinSliceStride) {
// Indices are limited to integer fp precision, so counts can fit in
// int32, regardless of IndexType
__shared__ int smem[32]; // one per warp, up to the warp limit
IndexType slice = getLinearBlockId<IndexType>();
if (slice >= numInputSlices) {
return;
}
// Find the start offset for our slice
IndexType sliceStartIndex =
IndexToOffset<float, IndexType, Dim>::get(slice, input);
IndexType topKSliceStartIndex =
IndexToOffset<float, IndexType, Dim>::get(slice, topK);
IndexType indicesSliceStartIndex =
IndexToOffset<long, IndexType, Dim>::get(slice, indices);
float* inputSliceStart = &input.data[sliceStartIndex];
float* topKSliceStart = &topK.data[topKSliceStartIndex];
long* indicesSliceStart = &indices.data[indicesSliceStartIndex];
// Find the k-th highest element in our input
float topKValue = -1.0f;
radixSelect<float, unsigned int, IndexType, FloatToSortedInt, Order>(
FloatToSortedInt(),
inputSliceStart, outputSliceSize,
inputSliceSize, inputWithinSliceStride,
smem, &topKValue);
// Every value that is strictly greater/less than `topKValue`
// (depending on the sort direction) is in the top-K.
// The top-K value itself might not be unique.
//
// Since there are a variable number of elements that we see that
// are within the top-k, we don't know at what index to write out
// the resulting values.
// In order to get this, we perform an exclusive prefix sum of
// `hasTopK`. This will return the resulting index into which we
// need to write the result, if a thread has a result.
// All threads need to participate in the loop and the prefix sum,
// but not necessarily in the load; hence loop bounds being rounded
// up to a multiple of the block dim.
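// For instance (illustrative; exclusiveBinaryPrefixSum is assumed to return
// the exclusive prefix sum of the hasTopK flags in `index` and the block-wide
// total in `carry`): if four consecutive threads see hasTopK = {1, 0, 1, 1},
// they get index = {0, 1, 1, 2} and carry == 3, so the three flagged threads
// write to consecutive output slots writeIndexStart + {0, 1, 2}.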
IndexType numIterations = THCRoundUp(inputSliceSize, (IndexType) blockDim.x);
IndexType writeIndexStart = 0;
for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
bool inRange = (i < inputSliceSize);
float v =
inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : 0.0f;
bool hasTopK;
if (Order) {
hasTopK = inRange && (v > topKValue);
} else {
hasTopK = inRange && (v < topKValue);
}
int index;
int carry;
exclusiveBinaryPrefixSum<int, true>(smem, hasTopK, &index, &carry);
if (hasTopK) {
int writeIndex = writeIndexStart + index;
assert(writeIndex < outputSliceSize);
IndexType topKOffset = writeIndex * topKWithinSliceStride;
IndexType indexOffset = writeIndex * indicesWithinSliceStride;
topKSliceStart[topKOffset] = v;
indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
}
writeIndexStart += carry;
}
// We need to fill in the rest with values that compare equal to the
// top-K value. The number we still need is outputSliceSize -
// writeIndexStart. There might be more such values available than we
// need, in which case we keep the first ones seen. We do this
// via a prefix sum to calculate indices for writing results.
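// Example (illustrative): if outputSliceSize == 5 and 3 strictly-greater (or
// strictly-less) values were written above, then topKRemaining == 2, and only
// the first two elements equal to topKValue (in prefix-sum order) are
// written; any later ties are skipped by the `index < topKRemaining` check.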
assert(outputSliceSize >= writeIndexStart);
IndexType topKRemaining = (outputSliceSize - writeIndexStart);
for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
bool inRange = (i < inputSliceSize);
float v =
inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : 0.0f;
bool hasTopK = inRange && (v == topKValue);
int index;
int carry;
exclusiveBinaryPrefixSum<int, true>(smem, hasTopK, &index, &carry);
if (hasTopK && index < topKRemaining) {
int writeIndex = writeIndexStart + index;
assert(writeIndex < outputSliceSize);
IndexType topKOffset = writeIndex * topKWithinSliceStride;
IndexType indexOffset = writeIndex * indicesWithinSliceStride;
topKSliceStart[topKOffset] = v;
indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
}
if (carry >= topKRemaining) {
break;
}
topKRemaining -= carry;
writeIndexStart += carry;
}
}
#undef RADIX_BITS
#undef RADIX_SIZE
#undef RADIX_MASK
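// Illustrative behavior of the entry point below (assuming TH_INDEX_BASE == 1,
// i.e. 1-based Lua indexing): for input = {3, 1, 4, 1, 5}, k = 2, dir = 1
// (largest first) and sorted = 1, topK becomes {5, 4} and indices becomes
// {5, 3}.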
THC_API void THCudaTensor_topk(THCState* state,
THCudaTensor *topK,
THCudaLongTensor *indices,
THCudaTensor *input,
long k, int dim, int dir, int sorted) {
THAssert(topK != NULL && indices != NULL && input != NULL);
THAssert(THCudaTensor_checkGPU(state, 3, topK, indices, input));
THCCheckTensorDims(state, topK, 2);
long dims = THCudaLongTensor_nDimension(state, indices);
THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
THCCheckTensorDims(state, input, 2);
int numDims = THCudaTensor_nDimension(state, input);
THArgCheck(dim >= 0 && dim < numDims, 3, "dim not in range");
long sliceSize = THCudaTensor_size(state, input, dim);
THArgCheck(k > 0 && k <= sliceSize, 2, "k not in range for dimension");
// Build the output size, which is the dim being selected set to
// size k
THLongStorage* topKSize = THCudaTensor_newSizeOf(state, input);
THLongStorage_set(topKSize, dim, k);
THCudaTensor_resize(state, topK, topKSize, NULL);
THCudaLongTensor_resize(state, indices, topKSize, NULL);
THLongStorage_free(topKSize);
#define RUN_K(INDEX_T, DIM, DIR) \
gatherTopK<INDEX_T, DIM, DIR> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
inputInfo, \
sliceSize, \
k, \
inputSlices, \
/* The actual dimension that the k-selection is running in */ \
/* may have changed from collapseDims() */ \
inputInfo.strides[collapseInputDim], \
topKInfo, \
topKSlices, \
topKInfo.strides[collapseTopKDim], \
indicesInfo, \
indicesInfo.strides[collapseIndicesDim])
#define RUN_DIR(INDEX_T, DIM) \
if (dir) { \
RUN_K(INDEX_T, DIM, true); \
} else { \
RUN_K(INDEX_T, DIM, false); \
}
#define RUN_DIM(INDEX_T) \
if (allDims == 1) { \
RUN_DIR(INDEX_T, 1); \
} else if (allDims == 2) { \
RUN_DIR(INDEX_T, 2); \
} else if (allDims == 3) { \
RUN_DIR(INDEX_T, 3); \
} else { \
RUN_DIR(INDEX_T, -1); \
}
#define RUN_T(INDEX_T) \
TensorInfo<float, INDEX_T> inputInfo = \
getTensorInfo<THCudaTensor, INDEX_T>(state, input); \
TensorInfo<float, INDEX_T> topKInfo = \
getTensorInfo<THCudaTensor, INDEX_T>(state, topK); \
TensorInfo<long, INDEX_T> indicesInfo = \
getTensorInfo<THCudaLongTensor, INDEX_T>(state, indices); \
\
/* We use these structures solely to find the offset to */ \
/* each slice we are operating on */ \
inputInfo.sizes[dim] = 1; \
topKInfo.sizes[dim] = 1; \
indicesInfo.sizes[dim] = 1; \
\
/* Collapse all other dims */ \
int collapseInputDim = inputInfo.collapseDims(dim); \
int collapseTopKDim = topKInfo.collapseDims(dim); \
int collapseIndicesDim = indicesInfo.collapseDims(dim); \
\
long inputSlices = 1; \
long topKSlices = 1; \
for (int i = 0; i < numDims; ++i) { \
inputSlices *= inputInfo.sizes[i]; \
topKSlices *= topKInfo.sizes[i]; \
} \
\
dim3 grid; \
if (!THC_getGridFromTiles(inputSlices, grid)) { \
THError("Slice to sort is too large"); \
} \
\
dim3 block(std::min(THCRoundUp(sliceSize, 32L), 1024L)); \
\
/* This is used as a template parameter to calculate indices. */ \
/* We only specialize it if all collapsed dim sizes are the */ \
/* same; otherwise, we use -1 which is the specialization */ \
/* parameter for arbitrary dimensions */ \
int allDims = inputInfo.dims; \
if (topKInfo.dims != allDims || indicesInfo.dims != allDims) { \
allDims = -1; \
} \
\
RUN_DIM(INDEX_T);
// Based on required index size, run the algorithm with the
// appropriate index type
if (TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input) &&
TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, topK) &&
TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, indices)) {
RUN_T(unsigned int);
} else {
RUN_T(unsigned long);
}
#undef RUN_T
#undef RUN_DIM
#undef RUN_DIR
#undef RUN_K
// Sort the results if the user wants them sorted, since our
// selection routine does not ensure sorting
if (sorted) {
// FIXME: the k/v inplace sort along slice only works for size <=
// 2048 at the moment
if (sliceSize <= 2048) {
// This avoids any memory allocations and performs all sorting
// work inplace along the slice
THCudaTensor_sortKeyValueInplace(state, topK, indices, dim, dir);
} else {
// Depend upon the backup sort that returns indices, which we
// can use in conjunction with gather to produce the original
// indices.
// This is not the most efficient implementation, especially since
// there are memory allocations performed here. If the user desires
// greater performance, they should torch.gather() the results
// themselves using the reported indices, providing previously
// allocated tensors to receive the results.
THCudaTensor* sortedTopK = THCudaTensor_new(state);
THCudaLongTensor* sortedIndices = THCudaLongTensor_new(state);
THCudaTensor_sort(state, sortedTopK, sortedIndices, topK, dim, dir);
THCudaLongTensor* sortedTopKIndices = THCudaLongTensor_new(state);
THCudaLongTensor_resizeAs(state, sortedTopKIndices, indices);
THCudaLongTensor_gather(state, sortedTopKIndices, indices, dim, sortedIndices);
THCudaTensor_freeCopyTo(state, sortedTopK, topK);
THCudaLongTensor_freeCopyTo(state, sortedTopKIndices, indices);
THCudaLongTensor_free(state, sortedIndices);
}
}
THCudaCheck(cudaGetLastError());
}