#include "THCReduceApplyUtils.cuh"
#include "THCTensorCopy.h"
#include <thrust/device_ptr.h>
#include <curand_kernel.h> // workaround for a known CUDA 6 compilation error
#include <thrust/sort.h>
template <typename IndexType, int Power2SortSize>
__device__ __forceinline__ IndexType
getSortSliceLinearIndex() {
// linear block ID -> slice we are sorting (one per block)
return getLinearBlockId<IndexType>();
}
// Returns 2^(ceil(lg(n))), from the Stanford bit twiddling hacks
unsigned long nextHighestPowerOf2(unsigned long n) {
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
n |= n >> 32; // assumes a 64-bit unsigned long
n++;
return n;
}
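// Worked example of the smearing above (illustrative, not used by any
// kernel): for n = 300, n - 1 = 299 = 0b100101011; the or-shifts
// propagate the high bit into every lower position, giving
// 0b111111111 = 511, and 511 + 1 = 512 = 2^ceil(lg(300)).
// Exact powers of two map to themselves thanks to the initial n--.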
template <typename T>
struct LTComp {
__device__ __forceinline__ bool operator()(const T& a, const T& b) const {
return (a < b);
}
};
template <typename T>
struct GTComp {
__device__ __forceinline__ bool operator()(const T& a, const T& b) const {
return (a > b);
}
};
template <typename Comparator, typename K, typename V>
__device__ __forceinline__ void bitonicSwap(K& kA, V& vA,
K& kB, V& vB,
bool dir,
const Comparator& comp) {
// Entries with -1 indices (not real data; out of bounds) always
// sort to the end
bool swap = (comp(kA, kB) && (vA != -1)) || (vB == -1);
if (swap == dir) {
K k = kA;
kA = kB;
kB = k;
V v = vA;
vA = vB;
vB = v;
}
}
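// Minimal trace of bitonicSwap, assuming LTComp<float> and dir == false
// (ascending): for (kA, kB) = (3, 1) with both indices valid,
// comp(3, 1) is false, so swap == false == dir and the pair is
// exchanged to (1, 3); for (kA, kB) = (1, 3), comp is true,
// swap != dir, and no exchange happens.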
template <typename Comparator, typename K, typename V,
typename IndexType, int Power2SortSize>
__device__ inline void bitonicSort(K keys[Power2SortSize],
V values[Power2SortSize],
const Comparator& comp) {
#pragma unroll
for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
bool flag = ((threadIdx.x & (size / 2)) != 0);
#pragma unroll
for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
// blockDim.x == Power2SortSize / 2, so sorts of size <= 64 run in a
// single warp, which is implicitly synchronous
if (Power2SortSize > 64) {
__syncthreads();
}
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
bitonicSwap<Comparator, K, V>(
keys[pos], values[pos], keys[pos + stride], values[pos + stride],
flag, comp);
}
}
#pragma unroll
for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
// Single warp per slice is completely synchronous
if (Power2SortSize > 64) {
__syncthreads();
}
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
bitonicSwap<Comparator, K, V>(
keys[pos], values[pos], keys[pos + stride], values[pos + stride],
false, comp);
}
// Single warp per slice is completely synchronous
if (Power2SortSize > 64) {
__syncthreads();
}
}
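// Note on the pos arithmetic above (descriptive): for stride s, thread t
// handles the pair (pos, pos + s) with pos = 2*t - (t & (s - 1)); the
// 2*t term skips over the partner elements. E.g. with s = 2: t = 0 ->
// pair (0, 2), t = 1 -> (1, 3), t = 2 -> (4, 6), t = 3 -> (5, 7).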
template <typename Comparator, typename IndexType, int Dims, int Power2SortSize>
__global__ void
THCudaTensor_bitonicSortWithIndex(TensorInfo<IndexType> sorted,
TensorInfo<IndexType> indices,
TensorInfo<IndexType> input,
int dim,
IndexType totalSlices,
IndexType sliceSize,
IndexType sliceStride,
IndexType outSize,
IndexType outStride,
const Comparator comp) {
// Find the slice of the tensor that we are sorting
const IndexType linearIndex =
getSortSliceLinearIndex<IndexType, Power2SortSize>();
// Tiling the slices can put this block out of bounds when there are
// many slices to sort
if (linearIndex >= totalSlices) {
return;
}
__shared__ float keys[Power2SortSize];
__shared__ int values[Power2SortSize];
// Read unsorted values
const IndexType inputStartOffset =
IndexToOffset<IndexType, Dims>::get(linearIndex, input);
// Each thread is responsible for loading and storing 2 elements
const int elem1 = threadIdx.x;
const int elem2 = threadIdx.x + (Power2SortSize / 2);
keys[elem1] = (elem1 < sliceSize) ?
input.data[inputStartOffset + elem1 * sliceStride] :
0.0f; // value doesn't matter; the element is out of bounds
// Torch indices are 1-based (hence the +1)
values[elem1] = (elem1 < sliceSize) ? (elem1 + 1) :
-1; // out of bounds
keys[elem2] = (elem2 < sliceSize) ?
input.data[inputStartOffset + elem2 * sliceStride] :
0.0f; // value doesn't matter; the element is out of bounds
// Torch indices are 1-based (hence the +1)
values[elem2] = (elem2 < sliceSize) ? (elem2 + 1) :
-1; // out of bounds
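// Load-mapping example (illustrative): with Power2SortSize == 8 the
// block has 4 threads; thread 0 loads elements 0 and 4, thread 1 loads
// 1 and 5, and so on. For a slice of size 5, elements 5..7 are padded
// with key 0.0f and index -1 so they sort to the end.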
// Sort!
bitonicSort<Comparator, float, int, IndexType, Power2SortSize>(
keys, values, comp);
// Write sorted values; indices have same layout
const IndexType sortedStartOffset =
IndexToOffset<IndexType, -1>::get(linearIndex, sorted);
const IndexType out1 = sortedStartOffset + elem1 * outStride;
// elem1 values are always valid, since otherwise we would have
// chosen the next smallest power-of-2 for sorting
sorted.data[out1] = keys[elem1];
indices.data[out1] = values[elem1];
const IndexType out2 = sortedStartOffset + elem2 * outStride;
// elem2 values might be out-of-range, if the data size we are
// sorting is not a power-of-2
if (values[elem2] != -1) {
sorted.data[out2] = keys[elem2];
indices.data[out2] = values[elem2];
}
}
bool THCudaTensor_sortImpl(THCState* state,
THCudaTensor* sorted,
THCudaTensor* indices,
THCudaTensor* input,
int dim, bool dir) {
long inElements = THCudaTensor_nElement(state, input);
long sliceSize = THCudaTensor_size(state, input, dim);
long sliceStride = THCudaTensor_stride(state, input, dim);
long slices = inElements / sliceSize;
long outSize = THCudaTensor_size(state, sorted, dim);
long outStride = THCudaTensor_stride(state, sorted, dim);
if (THCudaTensor_nDimension(state, input) > MAX_CUTORCH_DIMS) {
// Too many dimensions
return false;
}
if (THCudaTensor_nDimension(state, input) == 0) {
// Zero-dim tensor; do nothing
return true;
}
// The amount of shared memory and the block size are based on
// 2^ceil(lg(n)); we choose the sort instantiation that matches a
// given slice size.
long ceilPowerOf2 = nextHighestPowerOf2(sliceSize);
// Only handle 1-2048 at the moment
if (ceilPowerOf2 > 2048) {
return false;
}
const dim3 block(ceilPowerOf2 / 2);
// The grid is based on the number of independent slices that we
// have to sort; one block per slice
dim3 grid;
if (!THC_getGridFromTiles(slices, grid)) {
return false;
}
#define HANDLE_CASE(TYPE, A, SIZE) \
if (dir) { \
THCudaTensor_bitonicSortWithIndex<GTComp<float>, TYPE, A, SIZE> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
sortedInfo, indicesInfo, inputInfo, \
dim, \
slices, (TYPE) sliceSize, (TYPE) sliceStride, \
(TYPE) outSize, (TYPE) outStride, \
GTComp<float>()); \
} else { \
THCudaTensor_bitonicSortWithIndex<LTComp<float>, TYPE, A, SIZE> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
sortedInfo, indicesInfo, inputInfo, \
dim, \
slices, (TYPE) sliceSize, (TYPE) sliceStride, \
(TYPE) outSize, (TYPE) outStride, \
LTComp<float>()); \
}
#define HANDLE_SORT_CASE(TYPE, A) \
{ \
switch (ceilPowerOf2) { \
case 2048: \
HANDLE_CASE(TYPE, A, 2048); \
break; \
case 1024: \
HANDLE_CASE(TYPE, A, 1024); \
break; \
case 512: \
HANDLE_CASE(TYPE, A, 512); \
break; \
case 256: \
HANDLE_CASE(TYPE, A, 256); \
break; \
case 128: \
HANDLE_CASE(TYPE, A, 128); \
break; \
case 64: \
HANDLE_CASE(TYPE, A, 64); \
break; \
case 32: \
HANDLE_CASE(TYPE, A, 32); \
break; \
case 16: \
HANDLE_CASE(TYPE, A, 16); \
break; \
case 8: \
HANDLE_CASE(TYPE, A, 8); \
break; \
case 4: \
HANDLE_CASE(TYPE, A, 4); \
break; \
case 2: \
HANDLE_CASE(TYPE, A, 2); \
break; \
case 1: \
HANDLE_CASE(TYPE, A, 1); \
break; \
default: \
assert(false); \
} \
}
#define HANDLE_A_CASE(TYPE, A) \
{ \
if (inputInfo.isContiguous()) { \
HANDLE_SORT_CASE(TYPE, -2); \
} else { \
switch (A) { \
case 1: \
HANDLE_SORT_CASE(TYPE, 1); \
break; \
case 2: \
HANDLE_SORT_CASE(TYPE, 2); \
break; \
case 3: \
HANDLE_SORT_CASE(TYPE, 3); \
break; \
default: \
HANDLE_SORT_CASE(TYPE, -1); \
break; \
} \
} \
}
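// Dispatch summary (descriptive): HANDLE_A_CASE selects the TensorInfo
// dims specialization (-2 contiguous, 1-3 unrolled, -1 generic),
// HANDLE_SORT_CASE selects the compile-time Power2SortSize, and
// HANDLE_CASE selects the comparator from the sort direction. For
// example, a contiguous, 32-bit-indexable input with sliceSize 300
// resolves to THCudaTensor_bitonicSortWithIndex<LTComp<float> or
// GTComp<float>, unsigned int, -2, 512>.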
if (THC_canUse32BitIndexMath(state, input)) {
// In order to get to the right offset for the slice we are
// sorting, set `dim` size to 1 (the `dropDim` argument)
TensorInfo<unsigned int> sortedInfo(state, sorted, dim);
TensorInfo<unsigned int> indicesInfo(state, indices, dim);
TensorInfo<unsigned int> inputInfo(state, input, dim);
HANDLE_A_CASE(unsigned int, inputInfo.dims);
} else {
// In order to get to the right offset for the slice we are
// sorting, set `dim` size to 1 (the `dropDim` argument)
TensorInfo<unsigned long> sortedInfo(state, sorted, dim);
TensorInfo<unsigned long> indicesInfo(state, indices, dim);
TensorInfo<unsigned long> inputInfo(state, input, dim);
// long case is rare, just instantiate these versions
if (inputInfo.isContiguous()) {
HANDLE_SORT_CASE(unsigned long, -2);
} else {
HANDLE_SORT_CASE(unsigned long, -1);
}
}
#undef HANDLE_CASE
#undef HANDLE_SORT_CASE
#undef HANDLE_A_CASE
return true;
}
// For each slice (a linear point of `out`, spanning offsets 0 through
// (sliceSize - 1) * sliceStride), fill the slice with the 1-based
// indices 1 through sliceSize.
__global__ void
THCudaTensor_fillSliceWithIndex(TensorInfo<unsigned long> out,
long totalSlices,
long sliceSize,
long sliceStride) {
long slice = getLinearBlockId<long>();
if (slice >= totalSlices) {
return;
}
const unsigned long offset =
IndexToOffset<unsigned long, -1>::get(slice, out);
for (long i = threadIdx.x; i < sliceSize; i += blockDim.x) {
// Torch indices are 1-based (hence the +1)
out.data[offset + i * sliceStride] = (float) (i + 1);
}
}
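// Example (illustrative): with sliceSize = 5000 and blockDim.x = 1024,
// thread 0 writes the 1-based indices 1, 1025, 2049, 3073, 4097 at
// strided offsets; the block-stride loop covers slices larger than
// the block.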
bool canSortThrust(THCState* state, THCudaTensor* input, int dim) {
long totalElements = THCudaTensor_nElement(state, input);
long sliceSize = THCudaTensor_size(state, input, dim);
long numSlices = totalElements / sliceSize;
// Only defer to Thrust if the sort slice is contiguous, the number
// of slices is small, and each slice is large
return ((THCudaTensor_stride(state, input, dim) == 1) &&
numSlices <= 16 &&
sliceSize > 4096);
}
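// Example (illustrative sizes): a contiguous 4 x 100000 tensor sorted
// along dim 1 has 4 slices of 100000 elements, so it defers to Thrust;
// a 100000 x 4 tensor sorted along dim 1 has slices of size 4 and
// falls through to the block-wide bitonic sort instead.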
void THCudaTensor_sortImplThrust(THCState* state,
THCudaTensor* sorted,
THCudaTensor* indices,
THCudaTensor* input,
int dim, bool dir) {
// Fill the indices as values that Thrust can use for key/value sorting
long totalElements = THCudaTensor_nElement(state, input);
long sliceSize = THCudaTensor_size(state, input, dim);
long sliceStride = THCudaTensor_stride(state, input, dim);
long numSlices = totalElements / sliceSize;
// Copy input to sorted, since we sort in place
if (sorted != input) {
THCudaTensor_copy(state, sorted, input);
}
TensorInfo<unsigned long> sortedInfo(state, sorted, dim);
TensorInfo<unsigned long> indicesInfo(state, indices, dim);
dim3 grid;
// numSlices <= 16 here (see canSortThrust), so this cannot fail
THC_getGridFromTiles(numSlices, grid);
THCudaTensor_fillSliceWithIndex<<<grid, min((long long)sliceSize, 1024LL),
0, THCState_getCurrentStream(state)>>>(
indicesInfo, numSlices, sliceSize, sliceStride);
THCudaCheck(cudaGetLastError());
for (long slice = 0; slice < numSlices; ++slice) {
unsigned long sortedStart =
IndexToOffset<unsigned long, -1>::get(slice, sortedInfo);
unsigned long indicesStart =
IndexToOffset<unsigned long, -1>::get(slice, indicesInfo);
thrust::device_ptr<float>
sortedSliceStart(THCudaTensor_data(state, sorted) +
sortedStart);
thrust::device_ptr<float>
sortedSliceEnd(THCudaTensor_data(state, sorted) +
sortedStart + sliceSize);
thrust::device_ptr<float>
indicesSliceStart(THCudaTensor_data(state, indices) +
indicesStart);
if (dir) {
thrust::sort_by_key(sortedSliceStart, sortedSliceEnd, indicesSliceStart,
thrust::greater<float>());
} else {
thrust::sort_by_key(sortedSliceStart, sortedSliceEnd, indicesSliceStart,
thrust::less<float>());
}
}
}
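// Design note: each slice is handed to thrust::sort_by_key one at a
// time, so launch overhead grows with the slice count; this is why
// canSortThrust only defers here for a few large, contiguous slices.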
THC_API void THCudaTensor_sort(THCState* state,
THCudaTensor *sorted,
THCudaTensor *indices,
THCudaTensor *input,
int dim, int order) {
THAssert(THCudaTensor_checkGPU(state, 3, sorted, indices, input));
// Make sure sufficient output space is allocated
THCudaTensor_resizeAs(state, sorted, input);
THCudaTensor_resizeAs(state, indices, input);
// If we think Thrust will be more efficient, use that
if (canSortThrust(state, input, dim)) {
THCudaTensor_sortImplThrust(state, sorted, indices, input,
dim, (bool) order);
return;
}
// Otherwise, use our block-wide sort kernel, one block per slice
if (!THCudaTensor_sortImpl(state, sorted, indices, input,
dim, (bool) order)) {
THArgCheck(false, 1, CUTORCH_DIM_WARNING);
}
THCudaCheck(cudaGetLastError());
}
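// Usage sketch (illustrative):
//   // sorts `input` along dimension `dim`, descending when order != 0;
//   // `sorted` and `indices` are resized to match `input`, and the
//   // returned indices are 1-based, stored as floats.
//   THCudaTensor_sort(state, sorted, indices, input, dim, order);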