Generic TopK implementation (#744)
* move TopK to generic
* partial genericization of kernel code
* introduce TopKTypeConfig, specialize radix type and conversion for floats
* implement topk for byte tensor
* implement for char tensor
* implement for int tensor, extend test to check indices as well
* works for longs too
* make bitfield set/get a struct, add support for 64-bit types
* extend to double tensor
* implement for half tensor
* asserts; test fix
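
The core of the genericization is TopKTypeConfig<T>: each scalar type maps to an unsigned "radix type" whose integer ordering matches the source type's ordering, so a single radix-selection kernel can serve all types. A minimal host-side sketch of the float mapping (illustrative only; memcpy stands in for the device-only __float_as_int/__int_as_float intrinsics, and the names are not part of the patch):

#include <cassert>
#include <cstdint>
#include <cstring>

uint32_t convertFloat(float v) {
  uint32_t x;
  std::memcpy(&x, &v, sizeof(x));
  // Negative floats: flip all bits; positives: flip just the sign bit.
  uint32_t mask = (x & 0x80000000u) ? 0xffffffffu : 0x80000000u;
  return x ^ mask;
}

float deconvertFloat(uint32_t v) {
  uint32_t mask = (v & 0x80000000u) ? 0x80000000u : 0xffffffffu;
  v ^= mask;
  float f;
  std::memcpy(&f, &v, sizeof(f));
  return f;
}

int main() {
  float xs[] = {-10.0f, -0.5f, -0.0f, 0.0f, 0.5f, 10.0f};
  for (int i = 0; i + 1 < 6; ++i) {
    assert(convertFloat(xs[i]) <= convertFloat(xs[i + 1]));  // order preserved
    assert(deconvertFloat(convertFloat(xs[i])) == xs[i]);    // round-trips
  }
  return 0;
}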
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 935098a..1ea6039 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -258,7 +258,6 @@
THCTensorRandom.h
THCTensorMath.h
THCTensorConv.h
- THCTensorTopK.h
THCApply.cuh
THCReduce.cuh
THCReduceAll.cuh
@@ -295,6 +294,7 @@
THCTensorMathMagma.cuh
THCThrustAllocator.cuh
THCTensorMode.cuh
+ THCTensorTopK.cuh
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC")
INSTALL(FILES
@@ -341,4 +341,6 @@
generic/THCTensorRandom.cu
generic/THCTensorMode.h
generic/THCTensorMode.cu
+ generic/THCTensorTopK.h
+ generic/THCTensorTopK.cu
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
diff --git a/THC.h b/THC.h
index e3840dc..90a3a53 100644
--- a/THC.h
+++ b/THC.h
@@ -15,6 +15,5 @@
#include "THCTensorRandom.h"
#include "THCTensorMath.h"
#include "THCTensorConv.h"
-#include "THCTensorTopK.h"
#endif
diff --git a/THCAsmUtils.cuh b/THCAsmUtils.cuh
index 7015d20..f0dc90b 100644
--- a/THCAsmUtils.cuh
+++ b/THCAsmUtils.cuh
@@ -3,20 +3,44 @@
// Collection of direct PTX functions
-__device__ __forceinline__
-unsigned int getBitfield(unsigned int val, int pos, int len) {
- unsigned int ret;
- asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
- return ret;
-}
+template <typename T>
+struct Bitfield {};
-__device__ __forceinline__
-unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
- unsigned int ret;
- asm("bfi.b32 %0, %1, %2, %3, %4;" :
- "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
- return ret;
-}
+template <>
+struct Bitfield<unsigned int> {
+ static __device__ __forceinline__
+ unsigned int getBitfield(unsigned int val, int pos, int len) {
+ unsigned int ret;
+ asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
+ return ret;
+ }
+
+ static __device__ __forceinline__
+ unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
+ unsigned int ret;
+ asm("bfi.b32 %0, %1, %2, %3, %4;" :
+ "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
+ return ret;
+ }
+};
+
+template <>
+struct Bitfield<unsigned long long int> {
+ static __device__ __forceinline__
+ unsigned long long int getBitfield(unsigned long long int val, int pos, int len) {
+ unsigned long long int ret;
+ asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
+ return ret;
+ }
+
+ static __device__ __forceinline__
+ unsigned long long int setBitfield(unsigned long long int val, unsigned long long int toInsert, int pos, int len) {
+ unsigned long long int ret;
+ asm("bfi.b64 %0, %1, %2, %3, %4;" :
+ "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
+ return ret;
+ }
+};
__device__ __forceinline__ int getLaneId() {
int laneId;
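
For readers unfamiliar with the bfe.u32/bfi.b32 PTX instructions wrapped above, here is a portable C++ rendering of what the two operations compute (a reference sketch only, assuming 0 <= pos, 0 < len, and pos + len <= 32, which is how the radix-select caller uses them; the 64-bit specialization is analogous):

#include <cstdint>

// Extract `len` bits of `val` starting at bit `pos` (what bfe.u32 returns).
uint32_t getBitfieldRef(uint32_t val, int pos, int len) {
  uint32_t mask = (len == 32) ? 0xffffffffu : ((1u << len) - 1u);
  return (val >> pos) & mask;
}

// Return `val` with `len` bits starting at `pos` replaced by the low bits
// of `toInsert` (what bfi.b32 computes).
uint32_t setBitfieldRef(uint32_t val, uint32_t toInsert, int pos, int len) {
  uint32_t mask = ((len == 32) ? 0xffffffffu : ((1u << len) - 1u)) << pos;
  return (val & ~mask) | ((toInsert << pos) & mask);
}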
diff --git a/THCTensorMath.h b/THCTensorMath.h
index 8189f4e..b888672 100644
--- a/THCTensorMath.h
+++ b/THCTensorMath.h
@@ -46,6 +46,9 @@
#include "generic/THCTensorMode.h"
#include "THCGenerateAllTypes.h"
+#include "generic/THCTensorTopK.h"
+#include "THCGenerateAllTypes.h"
+
THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
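
The include pair added above is the standard TH/THC expansion idiom: THCGenerateAllTypes.h re-includes the file named by THC_GENERIC_FILE once per scalar type, redefining `real` and the name-mangling macros each time, so one generic body yields one function per tensor type. A self-contained imitation of the pattern (collapsed into a plain macro so the sketch fits in one file; all names here are illustrative):

#include <cstdio>

// Stamp out one sum function per scalar type, mimicking how
// generic/THCTensorTopK.h is instantiated for each tensor type.
#define DEFINE_SUM(real, Tensor_sum)                 \
  static real Tensor_sum(const real* data, int n) {  \
    real acc = 0;                                    \
    for (int i = 0; i < n; ++i) acc += data[i];      \
    return acc;                                      \
  }

DEFINE_SUM(float, FloatTensor_sum)
DEFINE_SUM(double, DoubleTensor_sum)
DEFINE_SUM(int, IntTensor_sum)

int main() {
  float f[] = {1.0f, 2.0f, 3.0f};
  std::printf("%g\n", FloatTensor_sum(f, 3));  // prints 6
  return 0;
}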
diff --git a/THCTensorTopK.cu b/THCTensorTopK.cu
index e2b817b..325d560 100644
--- a/THCTensorTopK.cu
+++ b/THCTensorTopK.cu
@@ -12,525 +12,8 @@
#include <thrust/system/cuda/execution_policy.h>
#endif
-// Converts a float to an integer representation with the same
-// sorting; i.e., for floats f1, f2:
-// if f1 < f2 then convert(f1) < convert(f2)
-// We use this to enable radix selection of floating-point values.
-// This also gives a relative order for NaNs, but that's ok, as they
-// will all be adjacent
-struct FloatToSortedInt {
- inline __host__ __device__ FloatToSortedInt() {}
+#include "THCTensorTopK.cuh"
- inline __device__ unsigned int convert(float v) const {
- unsigned int x = __float_as_int(v);
- unsigned int mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
-
- return (x ^ mask);
- }
-
- inline __device__ float deconvert(unsigned int v) const {
- unsigned int mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
-
- return __int_as_float(v ^ mask);
- }
-};
-
-// This function counts the distribution of all input values in a
-// slice we are selecting by radix digit at `radixDigitPos`, but only
-// those that pass the filter `((v & desiredMask) == desired)`.
-// This produces and broadcasts the seen counts for a single block only.
-// `smem` must have at least `RadixSize` elements.
-template <typename DataType, typename BitDataType,
- typename IndexType, typename CountType,
- typename RadixConverter, int RadixSize, int RadixBits>
-__device__ void countRadixUsingMask(const RadixConverter& conv,
- CountType counts[RadixSize],
- CountType* smem,
- BitDataType desired,
- BitDataType desiredMask,
- int radixDigitPos,
- IndexType sliceSize,
- IndexType withinSliceStride,
- DataType* data) {
- // Clear out per-thread counts from a previous round
-#pragma unroll
- for (int i = 0; i < RadixSize; ++i) {
- counts[i] = 0;
- }
-
- if (threadIdx.x < RadixSize) {
- smem[threadIdx.x] = 0;
- }
- __syncthreads();
-
- // Scan over all the data. Upon a read, the warp will accumulate
- // counts per each digit in the radix using warp voting.
- for (IndexType i = threadIdx.x; i < sliceSize; i += blockDim.x) {
- BitDataType val = conv.convert(doLdg(&data[i * withinSliceStride]));
-
- bool hasVal = ((val & desiredMask) == desired);
- unsigned int digitInRadix = getBitfield(val, radixDigitPos, RadixBits);
-
-#pragma unroll
- for (unsigned int j = 0; j < RadixSize; ++j) {
- bool vote = hasVal && (digitInRadix == j);
- counts[j] += __popc(__ballot(vote));
- }
- }
-
- // Now, for each warp, sum values
- if (getLaneId() == 0) {
-#pragma unroll
- for (unsigned int i = 0; i < RadixSize; ++i) {
- atomicAdd(&smem[i], counts[i]);
- }
- }
-
- __syncthreads();
-
- // For each thread, read in the total counts
-#pragma unroll
- for (unsigned int i = 0; i < RadixSize; ++i) {
- counts[i] = smem[i];
- }
-
- __syncthreads();
-}
-
-// Over what radix we are selecting values
-#define RADIX_BITS 2 // digits are base-(2 ^ RADIX_BITS)
-#define RADIX_SIZE 4 // 2 ^ RADIX_BITS
-#define RADIX_MASK (RADIX_SIZE - 1)
-
-// This finds the unique value `v` that matches the pattern
-// ((v & desired) == desiredMask) in our sorted int format
-template <typename DataType, typename IndexType, typename RadixConverter>
-__device__ float findPattern(const RadixConverter& conv,
- DataType* smem,
- DataType* data,
- IndexType sliceSize,
- IndexType withinSliceStride,
- unsigned int desired,
- unsigned int desiredMask) {
- if (threadIdx.x < 32) {
- smem[threadIdx.x] = (DataType) 0;
- }
- __syncthreads();
-
- // All threads participate in the loop, in order to sync on the flag
- IndexType numIterations = THCRoundUp(sliceSize, (IndexType) blockDim.x);
- for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
- bool inRange = (i < sliceSize);
- DataType v = inRange ? doLdg(&data[i * withinSliceStride]) : (DataType) 0;
-
- if (inRange && ((conv.convert(v) & desiredMask) == desired)) {
- // There should not be conflicts if we are using findPattern,
- // since the result is unique
- smem[0] = (DataType) 1;
- smem[1] = v; // can't use val as the flag, since it could be 0
- }
-
- __syncthreads();
-
- DataType found = smem[0];
- DataType val = smem[1];
-
- __syncthreads();
-
- // Check to see if a thread found the value
- if (found != (DataType) 0) {
- // all threads return this value
- return val;
- }
- }
-
- // should not get here
- assert(false);
- return (DataType) 0;
-}
-
-// Returns the top-Kth element found in the data using radix selection
-template <typename DataType, typename BitDataType, typename IndexType,
- typename RadixConverter, bool Order>
-__device__ void radixSelect(const RadixConverter& conv,
- DataType* data,
- IndexType k,
- IndexType sliceSize,
- IndexType withinSliceStride,
- int* smem,
- DataType* topK) {
- // Per-thread buckets into which we accumulate digit counts in our
- // radix
- int counts[RADIX_SIZE];
-
- // We only consider elements x such that (x & desiredMask) == desired
- // Initially, we consider all elements of the array, so the above
- // statement is true regardless of input.
- unsigned int desired = 0;
- unsigned int desiredMask = 0;
-
- // We are looking for the top kToFind-th element when iterating over
- // digits; this count gets reduced by elimination when counting
- // successive digits
- int kToFind = k;
-
- // We start at the most significant digit in our radix, scanning
- // through to the least significant digit
-#pragma unroll
- for (int digitPos = sizeof(BitDataType) * 8 - RADIX_BITS;
- digitPos >= 0;
- digitPos -= RADIX_BITS) {
-
- // Count radix distribution for the current position and reduce
- // across all threads
- countRadixUsingMask<DataType, BitDataType,
- IndexType, int, RadixConverter,
- RADIX_SIZE, RADIX_BITS>(
- conv, counts, smem,
- desired, desiredMask, digitPos,
- sliceSize, withinSliceStride, data);
-
- // All threads participate in the comparisons below to know the
- // final result
-
-#define CHECK_RADIX(i) \
- int count = counts[i]; \
- \
- /* All threads have the same value in counts here, so all */ \
- /* threads will return from the function. */ \
- if (count == 1 && kToFind == 1) { \
- /* There is a unique answer. */ \
- desired = setBitfield(desired, i, digitPos, RADIX_BITS); \
- desiredMask = \
- setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS); \
- \
- /* The answer is now the unique element v such that: */ \
- /* (v & desiredMask) == desired */ \
- /* However, we do not yet know what the actual element is. We */ \
- /* need to perform a search through the data to find the */ \
- /* element that matches this pattern. */ \
- *topK = findPattern<DataType, IndexType, RadixConverter>( \
- conv, (float*) smem, data, sliceSize, \
- withinSliceStride, desired, desiredMask); \
- return; \
- } \
- \
- if (count >= kToFind) { \
- desired = setBitfield(desired, i, digitPos, RADIX_BITS); \
- desiredMask = \
- setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS); \
- \
- /* The top-Kth element v must now be one such that: */ \
- /* (v & desiredMask == desired) */ \
- /* but we haven't narrowed it down; we must check the next */ \
- /* least-significant digit */ \
- break; \
- } \
- \
- kToFind -= count \
-
- if (Order) {
- // Process in descending order
-#pragma unroll
- for (int i = RADIX_SIZE - 1; i >= 0; --i) {
- CHECK_RADIX(i);
- }
- } else {
- // Process in ascending order
-#pragma unroll
- for (int i = 0; i < RADIX_SIZE; ++i) {
- CHECK_RADIX(i);
- }
- }
-#undef CHECK_RADIX
- } // end digitPos for
-
- // There is no unique result, but there is a non-unique result
- // matching `desired` exactly
- *topK = conv.deconvert(desired);
-}
-
-template <typename IndexType, int Dim, bool Order>
-__global__ void gatherTopK(TensorInfo<float, IndexType> input,
- IndexType inputSliceSize,
- IndexType outputSliceSize, // aka `k`
-
- IndexType numInputSlices,
- IndexType inputWithinSliceStride,
-
- TensorInfo<float, IndexType> topK,
- IndexType numTopKSlices,
- IndexType topKWithinSliceStride,
-
- TensorInfo<long, IndexType> indices,
- IndexType indicesWithinSliceStride) {
- // Indices are limited to integer fp precision, so counts can fit in
- // int32, regardless of IndexType
- __shared__ int smem[32]; // one per each warp, up to warp limit
-
- IndexType slice = getLinearBlockId<IndexType>();
- if (slice >= numInputSlices) {
- return;
- }
-
- // Find the start offset for our slice
- IndexType sliceStartIndex =
- IndexToOffset<float, IndexType, Dim>::get(slice, input);
- IndexType topKSliceStartIndex =
- IndexToOffset<float, IndexType, Dim>::get(slice, topK);
- IndexType indicesSliceStartIndex =
- IndexToOffset<long, IndexType, Dim>::get(slice, indices);
-
- float* inputSliceStart = &input.data[sliceStartIndex];
- float* topKSliceStart = &topK.data[topKSliceStartIndex];
- long* indicesSliceStart = &indices.data[indicesSliceStartIndex];
-
- // Find the k-th highest element in our input
- float topKValue = -1.0f;
- radixSelect<float, unsigned int, IndexType, FloatToSortedInt, Order>(
- FloatToSortedInt(),
- inputSliceStart, outputSliceSize,
- inputSliceSize, inputWithinSliceStride,
- smem, &topKValue);
-
- // Every value that is strictly less/greater than `pattern`
- // (depending on sort dir) in sorted int format is in the top-K.
- // The top-K value itself might not be unique.
- //
- // Since there are a variable number of elements that we see that
- // are within the top-k, we don't know at what index to write out
- // the resulting values.
- // In order to get this, we perform an exclusive prefix sum of
- // `hasTopK`. This will return the resulting index into which we
- // need to write the result, if a thread has a result.
-
- // All threads need to participate in the loop and the prefix sum,
- // but not necessarily in the load; hence loop bounds being rounded
- // up to a multiple of the block dim.
- IndexType numIterations = THCRoundUp(inputSliceSize, (IndexType) blockDim.x);
- IndexType writeIndexStart = 0;
-
- for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
- bool inRange = (i < inputSliceSize);
- float v =
- inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : 0.0f;
- bool hasTopK;
- if (Order) {
- hasTopK = inRange && (v > topKValue);
- } else {
- hasTopK = inRange && (v < topKValue);
- }
-
- int index;
- int carry;
- exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
-
- if (hasTopK) {
- int writeIndex = writeIndexStart + index;
- assert(writeIndex < outputSliceSize);
-
- IndexType topKOffset = writeIndex * topKWithinSliceStride;
- IndexType indexOffset = writeIndex * indicesWithinSliceStride;
-
- topKSliceStart[topKOffset] = v;
- indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
- }
-
- writeIndexStart += carry;
- }
-
- // We need to fill in the rest with actual == top-K values.
- // The number that we need is outputSliceSize -
- // writeIndexStart. There might be more than that number available,
- // in which case we have to choose the first seen set. We do this
- // via a prefix sum to calculate indices for writing results.
- assert(outputSliceSize >= writeIndexStart);
- IndexType topKRemaining = (outputSliceSize - writeIndexStart);
-
- for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
- bool inRange = (i < inputSliceSize);
- float v =
- inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : 0.0f;
- bool hasTopK = inRange && (v == topKValue);
-
- int index;
- int carry;
- exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
-
- if (hasTopK && index < topKRemaining) {
- int writeIndex = writeIndexStart + index;
- assert(writeIndex < outputSliceSize);
-
- IndexType topKOffset = writeIndex * topKWithinSliceStride;
- IndexType indexOffset = writeIndex * indicesWithinSliceStride;
-
- topKSliceStart[topKOffset] = v;
- indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
- }
-
- if (carry >= topKRemaining) {
- break;
- }
-
- topKRemaining -= carry;
- writeIndexStart += carry;
- }
-}
-
-#undef RADIX_BITS
-#undef RADIX_SIZE
-#undef RADIX_MASK
-
-THC_API void THCudaTensor_topk(THCState* state,
- THCudaTensor *topK,
- THCudaLongTensor *indices,
- THCudaTensor *input,
- long k, int dim, int dir, int sorted) {
- THAssert(topK != NULL && indices != NULL && input != NULL);
- THCAssertSameGPU(THCudaTensor_checkGPU(state, 3, topK, indices, input));
- THCCheckTensorDims(state, topK, 2);
- long dims = THCudaLongTensor_nDimension(state, indices);
- THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
- THCCheckTensorDims(state, input, 2);
-
- int numDims = THCudaTensor_nDimension(state, input);
- THArgCheck(dim >= 0 && dim < numDims, 3, "dim not in range");
-
- long sliceSize = THCudaTensor_size(state, input, dim);
- THArgCheck(k > 0 && k <= sliceSize, 2, "k not in range for dimension");
-
- // Build the output size, which is the dim being selected set to
- // size k
- THLongStorage* topKSize = THCudaTensor_newSizeOf(state, input);
- THLongStorage_set(topKSize, dim, k);
- THCudaTensor_resize(state, topK, topKSize, NULL);
- THCudaLongTensor_resize(state, indices, topKSize, NULL);
- THLongStorage_free(topKSize);
-
-#define RUN_K(INDEX_T, DIM, DIR) \
- gatherTopK<INDEX_T, DIM, DIR> \
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
- inputInfo, \
- sliceSize, \
- k, \
- inputSlices, \
- /* The actual dimension that the k-selection is running in */ \
- /* may have changed from collapseDims() */ \
- inputInfo.strides[collapseInputDim], \
- topKInfo, \
- topKSlices, \
- topKInfo.strides[collapseTopKDim], \
- indicesInfo, \
- indicesInfo.strides[collapseIndicesDim])
-
-#define RUN_DIR(INDEX_T, DIM) \
- if (dir) { \
- RUN_K(INDEX_T, DIM, true); \
- } else { \
- RUN_K(INDEX_T, DIM, false); \
- }
-
-#define RUN_DIM(INDEX_T) \
- if (allDims == 1) { \
- RUN_DIR(INDEX_T, 1); \
- } else if (allDims == 2) { \
- RUN_DIR(INDEX_T, 2); \
- } else if (allDims == 3) { \
- RUN_DIR(INDEX_T, 3); \
- } else { \
- RUN_DIR(INDEX_T, -1); \
- }
-
-#define RUN_T(INDEX_T) \
- TensorInfo<float, INDEX_T> inputInfo = \
- getTensorInfo<THCudaTensor, INDEX_T>(state, input); \
- TensorInfo<float, INDEX_T> topKInfo = \
- getTensorInfo<THCudaTensor, INDEX_T>(state, topK); \
- TensorInfo<long, INDEX_T> indicesInfo = \
- getTensorInfo<THCudaLongTensor, INDEX_T>(state, indices); \
- \
- /* We use these structures solely to find the offset to */ \
- /* each slice we are operating on */ \
- inputInfo.sizes[dim] = 1; \
- topKInfo.sizes[dim] = 1; \
- indicesInfo.sizes[dim] = 1; \
- \
- /* Collapse all other dims */ \
- int collapseInputDim = inputInfo.collapseDims(dim); \
- int collapseTopKDim = topKInfo.collapseDims(dim); \
- int collapseIndicesDim = indicesInfo.collapseDims(dim); \
- \
- long inputSlices = 1; \
- long topKSlices = 1; \
- for (int i = 0; i < numDims; ++i) { \
- inputSlices *= inputInfo.sizes[i]; \
- topKSlices *= topKInfo.sizes[i]; \
- } \
- \
- dim3 grid; \
- if (!THC_getGridFromTiles(inputSlices, grid)) { \
- THError("Slice to sort is too large"); \
- } \
- \
- dim3 block(std::min(THCRoundUp(sliceSize, 32L), 1024L)); \
- \
- /* This is used as a template parameter to calculate indices. */ \
- /* We only specialize it if all collapsed dim sizes are the */ \
- /* same; otherwise, we use -1 which is the specialization */ \
- /* parameter for arbitrary dimensions */ \
- int allDims = inputInfo.dims; \
- if (topKInfo.dims != allDims || indicesInfo.dims != allDims) { \
- allDims = -1; \
- } \
- \
- RUN_DIM(INDEX_T);
-
- // Based on required index size, run the algorithm with the
- // appropriate index type
- if (TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input) &&
- TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, topK) &&
- TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, indices)) {
- RUN_T(unsigned int);
- } else {
- RUN_T(unsigned long);
- }
-#undef RUN_T
-#undef RUN_DIM
-#undef RUN_DIR
-#undef RUN_K
-
- // Sort the results if the user wants them sorted, since our
- // selection routine does not ensure sorting
- if (sorted) {
- // FIXME: the k/v inplace sort along slice only works for size <=
- // 2048 at the moment
- if (sliceSize <= 2048) {
- // This avoids any memory allocations and performs all sorting
- // work inplace along the slice
- THCudaTensor_sortKeyValueInplace(state, topK, indices, dim, dir);
- } else {
- // Depend upon the backup sort that returns indices, which we
- // can use in conjunction with gather to produce the original
- // indices.
- // This is not the most efficient implementation, especially since
- // there are memory allocations performed here. If the user desires
- // greater performance, they should torch.gather() the results
- // themselves using the reported indices, providing previously
- // allocated tensors to receive the results.
- THCudaTensor* sortedTopK = THCudaTensor_new(state);
- THCudaLongTensor* sortedIndices = THCudaLongTensor_new(state);
- THCudaTensor_sort(state, sortedTopK, sortedIndices, topK, dim, dir);
-
- THCudaLongTensor* sortedTopKIndices = THCudaLongTensor_new(state);
-
- THCudaLongTensor_resizeAs(state, sortedTopKIndices, indices);
- THCudaLongTensor_gather(state, sortedTopKIndices, indices, dim, sortedIndices);
-
- THCudaTensor_freeCopyTo(state, sortedTopK, topK);
- THCudaLongTensor_freeCopyTo(state, sortedTopKIndices, indices);
- THCudaLongTensor_free(state, sortedIndices);
- }
- }
+#include "generic/THCTensorTopK.cu"
+#include "THCGenerateAllTypes.h"
- THCudaCheck(cudaGetLastError());
-}
diff --git a/THCTensorTopK.cuh b/THCTensorTopK.cuh
new file mode 100644
index 0000000..32041e3
--- /dev/null
+++ b/THCTensorTopK.cuh
@@ -0,0 +1,473 @@
+#ifndef THC_TENSOR_TOPK_CUH
+#define THC_TENSOR_TOPK_CUH
+
+template <typename T>
+struct TopKTypeConfig {};
+
+template <>
+struct TopKTypeConfig<float> {
+ typedef unsigned int RadixType;
+
+ // Converts a float to an integer representation with the same
+ // sorting; i.e., for floats f1, f2:
+ // if f1 < f2 then convert(f1) < convert(f2)
+ // We use this to enable radix selection of floating-point values.
+ // This also gives a relative order for NaNs, but that's ok, as they
+ // will all be adjacent
+ static inline __device__ RadixType convert(float v) {
+ RadixType x = __float_as_int(v);
+ RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
+
+ return (x ^ mask);
+ }
+
+ static inline __device__ float deconvert(RadixType v) {
+ RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
+
+ return __int_as_float(v ^ mask);
+ }
+};
+
+template <>
+struct TopKTypeConfig<unsigned char> {
+ typedef unsigned int RadixType;
+
+ static inline __device__ RadixType convert(unsigned char v) {
+ return v;
+ }
+
+ static inline __device__ unsigned char deconvert(RadixType v) {
+ return v;
+ }
+};
+
+template <>
+struct TopKTypeConfig<char> {
+ typedef unsigned int RadixType;
+
+ static inline __device__ RadixType convert(char v) {
+ return 128u + v;
+ }
+
+ static inline __device__ char deconvert(RadixType v) {
+ return v - 128;
+ }
+};
+
+template <>
+struct TopKTypeConfig<short> {
+ typedef unsigned int RadixType;
+
+ static inline __device__ RadixType convert(short v) {
+ assert(sizeof(short) == 2);
+ return 32768u + v;
+ }
+
+ static inline __device__ short deconvert(RadixType v) {
+ return v - 32768;
+ }
+};
+
+template <>
+struct TopKTypeConfig<int> {
+ typedef unsigned int RadixType;
+
+ static inline __device__ RadixType convert(int v) {
+ assert(sizeof(int) == 4);
+ return 2147483648u + v;
+ }
+
+ static inline __device__ int deconvert(RadixType v) {
+ return v - 2147483648u;
+ }
+};
+
+template <>
+struct TopKTypeConfig<long> {
+ typedef unsigned long long int RadixType;
+
+ static inline __device__ RadixType convert(long v) {
+ assert(sizeof(long) == 8);
+ return 9223372036854775808ull + v;
+ }
+
+ static inline __device__ long deconvert(RadixType v) {
+ return v - 9223372036854775808ull;
+ }
+};
+
+template <>
+struct TopKTypeConfig<double> {
+ typedef unsigned long long int RadixType;
+
+ static inline __device__ RadixType convert(double v) {
+ RadixType x = __double_as_longlong(v);
+ RadixType mask = -((x >> 63)) | 0x8000000000000000;
+ return (x ^ mask);
+ }
+
+ static inline __device__ double deconvert(RadixType v) {
+ RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
+ return __longlong_as_double(v ^ mask);
+ }
+};
+
+template <>
+struct TopKTypeConfig<half> {
+ typedef unsigned int RadixType;
+
+ static inline __device__ RadixType convert(half v) {
+ RadixType x = __half_as_ushort(v);
+ RadixType mask = -((x >> 15)) | 0x8000;
+ return (x ^ mask);
+ }
+
+ static inline __device__ half deconvert(RadixType v) {
+ RadixType mask = ((v >> 15) - 1) | 0x8000;
+ return __ushort_as_half(v ^ mask);
+ }
+};
+
+// This function counts the distribution of all input values in a
+// slice we are selecting by radix digit at `radixDigitPos`, but only
+// those that pass the filter `((v & desiredMask) == desired)`.
+// This produces and broadcasts the seen counts for a single block only.
+// `smem` must have at least `RadixSize` elements.
+template <typename DataType, typename BitDataType,
+ typename IndexType, typename CountType,
+ int RadixSize, int RadixBits>
+__device__ void countRadixUsingMask(CountType counts[RadixSize],
+ CountType* smem,
+ BitDataType desired,
+ BitDataType desiredMask,
+ int radixDigitPos,
+ IndexType sliceSize,
+ IndexType withinSliceStride,
+ DataType* data) {
+ // Clear out per-thread counts from a previous round
+#pragma unroll
+ for (int i = 0; i < RadixSize; ++i) {
+ counts[i] = 0;
+ }
+
+ if (threadIdx.x < RadixSize) {
+ smem[threadIdx.x] = 0;
+ }
+ __syncthreads();
+
+ // Scan over all the data. Upon a read, the warp will accumulate
+  // counts for each digit in the radix using warp voting.
+ for (IndexType i = threadIdx.x; i < sliceSize; i += blockDim.x) {
+ BitDataType val = TopKTypeConfig<DataType>::convert(doLdg(&data[i * withinSliceStride]));
+
+ bool hasVal = ((val & desiredMask) == desired);
+ BitDataType digitInRadix = Bitfield<BitDataType>::getBitfield(val, radixDigitPos, RadixBits);
+
+#pragma unroll
+ for (unsigned int j = 0; j < RadixSize; ++j) {
+ bool vote = hasVal && (digitInRadix == j);
+ counts[j] += __popc(__ballot(vote));
+ }
+ }
+
+ // Now, for each warp, sum values
+ if (getLaneId() == 0) {
+#pragma unroll
+ for (unsigned int i = 0; i < RadixSize; ++i) {
+ atomicAdd(&smem[i], counts[i]);
+ }
+ }
+
+ __syncthreads();
+
+ // For each thread, read in the total counts
+#pragma unroll
+ for (unsigned int i = 0; i < RadixSize; ++i) {
+ counts[i] = smem[i];
+ }
+
+ __syncthreads();
+}
+
+// Over what radix we are selecting values
+#define RADIX_BITS 2 // digits are base-(2 ^ RADIX_BITS)
+#define RADIX_SIZE 4 // 2 ^ RADIX_BITS
+#define RADIX_MASK (RADIX_SIZE - 1)
+
+// This finds the unique value `v` that matches the pattern
+// ((v & desiredMask) == desired) in our sorted int format
+template <typename DataType, typename BitDataType, typename IndexType>
+__device__ DataType findPattern(DataType* smem,
+ DataType* data,
+ IndexType sliceSize,
+ IndexType withinSliceStride,
+ BitDataType desired,
+ BitDataType desiredMask) {
+ if (threadIdx.x < 32) {
+ smem[threadIdx.x] = ScalarConvert<int, DataType>::to(0);
+ }
+ __syncthreads();
+
+ // All threads participate in the loop, in order to sync on the flag
+ IndexType numIterations = THCRoundUp(sliceSize, (IndexType) blockDim.x);
+ for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
+ bool inRange = (i < sliceSize);
+ DataType v = inRange ? doLdg(&data[i * withinSliceStride]) : ScalarConvert<int, DataType>::to(0);
+
+ if (inRange && ((TopKTypeConfig<DataType>::convert(v) & desiredMask) == desired)) {
+ // There should not be conflicts if we are using findPattern,
+ // since the result is unique
+ smem[0] = ScalarConvert<int, DataType>::to(1);
+ smem[1] = v; // can't use val as the flag, since it could be 0
+ }
+
+ __syncthreads();
+
+ DataType found = smem[0];
+ DataType val = smem[1];
+
+ __syncthreads();
+
+ // Check to see if a thread found the value
+ if (THCNumerics<DataType>::ne(found, ScalarConvert<int, DataType>::to(0))) {
+ // all threads return this value
+ return val;
+ }
+ }
+
+ // should not get here
+ assert(false);
+ return ScalarConvert<int, DataType>::to(0);
+}
+
+// Returns the top-Kth element found in the data using radix selection
+template <typename DataType, typename BitDataType, typename IndexType, bool Order>
+__device__ void radixSelect(DataType* data,
+ IndexType k,
+ IndexType sliceSize,
+ IndexType withinSliceStride,
+ int* smem,
+ DataType* topK) {
+ // Per-thread buckets into which we accumulate digit counts in our
+ // radix
+ int counts[RADIX_SIZE];
+
+ // We only consider elements x such that (x & desiredMask) == desired
+ // Initially, we consider all elements of the array, so the above
+ // statement is true regardless of input.
+ BitDataType desired = 0;
+ BitDataType desiredMask = 0;
+
+ // We are looking for the top kToFind-th element when iterating over
+ // digits; this count gets reduced by elimination when counting
+ // successive digits
+ int kToFind = k;
+
+ // We start at the most significant digit in our radix, scanning
+ // through to the least significant digit
+#pragma unroll
+ for (int digitPos = sizeof(DataType) * 8 - RADIX_BITS;
+ digitPos >= 0;
+ digitPos -= RADIX_BITS) {
+
+ // Count radix distribution for the current position and reduce
+ // across all threads
+ countRadixUsingMask<DataType, BitDataType,
+ IndexType, int,
+ RADIX_SIZE, RADIX_BITS>(
+ counts, smem,
+ desired, desiredMask, digitPos,
+ sliceSize, withinSliceStride, data);
+
+ // All threads participate in the comparisons below to know the
+ // final result
+
+#define CHECK_RADIX(i) \
+ int count = counts[i]; \
+ \
+ /* All threads have the same value in counts here, so all */ \
+ /* threads will return from the function. */ \
+ if (count == 1 && kToFind == 1) { \
+ /* There is a unique answer. */ \
+ desired = Bitfield<BitDataType>::setBitfield(desired, i, digitPos, RADIX_BITS); \
+ desiredMask = \
+ Bitfield<BitDataType>::setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS); \
+ \
+ /* The answer is now the unique element v such that: */ \
+ /* (v & desiredMask) == desired */ \
+ /* However, we do not yet know what the actual element is. We */ \
+ /* need to perform a search through the data to find the */ \
+ /* element that matches this pattern. */ \
+ *topK = findPattern<DataType, BitDataType, IndexType>( \
+ (DataType*) smem, data, sliceSize, \
+ withinSliceStride, desired, desiredMask); \
+ return; \
+ } \
+ \
+ if (count >= kToFind) { \
+ desired = Bitfield<BitDataType>::setBitfield(desired, i, digitPos, RADIX_BITS); \
+ desiredMask = \
+ Bitfield<BitDataType>::setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS); \
+ \
+ /* The top-Kth element v must now be one such that: */ \
+ /* (v & desiredMask == desired) */ \
+ /* but we haven't narrowed it down; we must check the next */ \
+ /* least-significant digit */ \
+ break; \
+ } \
+ \
+ kToFind -= count \
+
+ if (Order) {
+ // Process in descending order
+#pragma unroll
+ for (int i = RADIX_SIZE - 1; i >= 0; --i) {
+ CHECK_RADIX(i);
+ }
+ } else {
+ // Process in ascending order
+#pragma unroll
+ for (int i = 0; i < RADIX_SIZE; ++i) {
+ CHECK_RADIX(i);
+ }
+ }
+#undef CHECK_RADIX
+ } // end digitPos for
+
+ // There is no unique result, but there is a non-unique result
+ // matching `desired` exactly
+ *topK = TopKTypeConfig<DataType>::deconvert(desired);
+}
+
+template <typename T, typename IndexType, int Dim, bool Order>
+__global__ void gatherTopK(TensorInfo<T, IndexType> input,
+ IndexType inputSliceSize,
+ IndexType outputSliceSize, // aka `k`
+
+ IndexType numInputSlices,
+ IndexType inputWithinSliceStride,
+
+ TensorInfo<T, IndexType> topK,
+ IndexType numTopKSlices,
+ IndexType topKWithinSliceStride,
+
+ TensorInfo<long, IndexType> indices,
+ IndexType indicesWithinSliceStride) {
+ // Indices are limited to integer fp precision, so counts can fit in
+ // int32, regardless of IndexType
+ __shared__ int smem[32]; // one per each warp, up to warp limit
+
+ IndexType slice = getLinearBlockId<IndexType>();
+ if (slice >= numInputSlices) {
+ return;
+ }
+
+ // Find the start offset for our slice
+ IndexType sliceStartIndex =
+ IndexToOffset<T, IndexType, Dim>::get(slice, input);
+ IndexType topKSliceStartIndex =
+ IndexToOffset<T, IndexType, Dim>::get(slice, topK);
+ IndexType indicesSliceStartIndex =
+ IndexToOffset<long, IndexType, Dim>::get(slice, indices);
+
+ T* inputSliceStart = &input.data[sliceStartIndex];
+ T* topKSliceStart = &topK.data[topKSliceStartIndex];
+ long* indicesSliceStart = &indices.data[indicesSliceStartIndex];
+
+ // Find the k-th highest element in our input
+ T topKValue = ScalarConvert<int, T>::to(0);
+ radixSelect<T, typename TopKTypeConfig<T>::RadixType, IndexType, Order>(
+ inputSliceStart, outputSliceSize,
+ inputSliceSize, inputWithinSliceStride,
+ smem, &topKValue);
+
+ // Every value that is strictly less/greater than `pattern`
+ // (depending on sort dir) in sorted int format is in the top-K.
+ // The top-K value itself might not be unique.
+ //
+ // Since there are a variable number of elements that we see that
+ // are within the top-k, we don't know at what index to write out
+ // the resulting values.
+ // In order to get this, we perform an exclusive prefix sum of
+ // `hasTopK`. This will return the resulting index into which we
+ // need to write the result, if a thread has a result.
+
+ // All threads need to participate in the loop and the prefix sum,
+ // but not necessarily in the load; hence loop bounds being rounded
+ // up to a multiple of the block dim.
+ IndexType numIterations = THCRoundUp(inputSliceSize, (IndexType) blockDim.x);
+ IndexType writeIndexStart = 0;
+
+ for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
+ bool inRange = (i < inputSliceSize);
+ T v =
+ inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
+ bool hasTopK;
+ if (Order) {
+ hasTopK = inRange && (THCNumerics<T>::gt(v, topKValue));
+ } else {
+ hasTopK = inRange && (THCNumerics<T>::lt(v, topKValue));
+ }
+
+ int index;
+ int carry;
+ exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
+
+ if (hasTopK) {
+ int writeIndex = writeIndexStart + index;
+ assert(writeIndex < outputSliceSize);
+
+ IndexType topKOffset = writeIndex * topKWithinSliceStride;
+ IndexType indexOffset = writeIndex * indicesWithinSliceStride;
+
+ topKSliceStart[topKOffset] = v;
+ indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
+ }
+
+ writeIndexStart += carry;
+ }
+
+ // We need to fill in the rest with actual == top-K values.
+ // The number that we need is outputSliceSize -
+ // writeIndexStart. There might be more than that number available,
+ // in which case we have to choose the first seen set. We do this
+ // via a prefix sum to calculate indices for writing results.
+ assert(outputSliceSize >= writeIndexStart);
+ IndexType topKRemaining = (outputSliceSize - writeIndexStart);
+
+ for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
+ bool inRange = (i < inputSliceSize);
+ T v =
+ inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
+ bool hasTopK = inRange && (THCNumerics<T>::eq(v, topKValue));
+
+ int index;
+ int carry;
+ exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
+
+ if (hasTopK && index < topKRemaining) {
+ int writeIndex = writeIndexStart + index;
+ assert(writeIndex < outputSliceSize);
+
+ IndexType topKOffset = writeIndex * topKWithinSliceStride;
+ IndexType indexOffset = writeIndex * indicesWithinSliceStride;
+
+ topKSliceStart[topKOffset] = v;
+ indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
+ }
+
+ if (carry >= topKRemaining) {
+ break;
+ }
+
+ topKRemaining -= carry;
+ writeIndexStart += carry;
+ }
+}
+
+#undef RADIX_BITS
+#undef RADIX_SIZE
+#undef RADIX_MASK
+
+#endif // THC_TENSOR_TOPK_CUH
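
To make the digit-narrowing in radixSelect easier to follow, here is the same algorithm as a serial host-side function over already-converted 32-bit keys (a sketch of the Order == true path only, without the early findPattern exit or the warp-voting counts):

#include <cassert>
#include <cstdint>
#include <vector>

// Returns the k-th largest key (k == 1 selects the maximum). Keys are
// assumed to already be in the order-preserving unsigned representation.
uint32_t radixSelectLargest(const std::vector<uint32_t>& keys, int k) {
  assert(k >= 1 && k <= (int) keys.size());
  uint32_t desired = 0, desiredMask = 0;
  int kToFind = k;

  // Fix two bits per pass, from the most significant digit downward.
  for (int pos = 32 - 2; pos >= 0; pos -= 2) {
    int counts[4] = {0, 0, 0, 0};
    for (uint32_t v : keys) {
      if ((v & desiredMask) == desired) {  // still a candidate?
        ++counts[(v >> pos) & 3u];
      }
    }
    // Scan digits high to low; the k-th largest lives in the first
    // bucket whose count reaches kToFind.
    for (int d = 3; d >= 0; --d) {
      if (counts[d] >= kToFind) {
        desired |= uint32_t(d) << pos;  // lock this digit in
        desiredMask |= 3u << pos;       // and narrow the filter
        break;
      }
      kToFind -= counts[d];  // skip past the larger buckets
    }
  }
  // All bits are fixed now, so `desired` is the k-th largest key itself.
  return desired;
}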
diff --git a/THCTensorTopK.h b/THCTensorTopK.h
deleted file mode 100644
index 711c047..0000000
--- a/THCTensorTopK.h
+++ /dev/null
@@ -1,14 +0,0 @@
-#ifndef TH_CUDA_TENSOR_TOPK_INC
-#define TH_CUDA_TENSOR_TOPK_INC
-
-#include "THCTensor.h"
-
-/* Returns the set of all kth smallest (or largest) elements, depending */
-/* on `dir` */
-THC_API void THCudaTensor_topk(THCState* state,
- THCudaTensor* topK,
- THCudaLongTensor* indices,
- THCudaTensor* input,
- long k, int dim, int dir, int sorted);
-
-#endif
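
With the float-only header gone, the declarations now come from generic/THCTensorTopK.h below, expanded once per type; since float's mangled tensor name is THCudaTensor, the float instantiation keeps the old THCudaTensor_topk signature and existing callers continue to compile. The generated set looks like this (illustration, not part of the patch):

THC_API void THCudaByteTensor_topk(THCState* state, THCudaByteTensor* topK,
                                   THCudaLongTensor* indices,
                                   THCudaByteTensor* input,
                                   long k, int dim, int dir, int sorted);
/* ... likewise for Char, Short, Int, and Long ... */
THC_API void THCudaTensor_topk(THCState* state, THCudaTensor* topK,
                               THCudaLongTensor* indices,
                               THCudaTensor* input,
                               long k, int dim, int dir, int sorted);
/* ... and for Double and Half. */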
diff --git a/generic/THCTensorTopK.cu b/generic/THCTensorTopK.cu
new file mode 100644
index 0000000..83ab1e1
--- /dev/null
+++ b/generic/THCTensorTopK.cu
@@ -0,0 +1,159 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorTopK.cu"
+#else
+
+THC_API void THCTensor_(topk)(THCState* state,
+ THCTensor *topK,
+ THCudaLongTensor *indices,
+ THCTensor *input,
+ long k, int dim, int dir, int sorted) {
+ THAssert(topK != NULL && indices != NULL && input != NULL);
+ THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input));
+ THArgCheck(THCTensor_(nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
+ long dims = THCudaLongTensor_nDimension(state, indices);
+ THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
+ int numDims = THCTensor_(nDimension)(state, input);
+ THArgCheck(numDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
+
+ THArgCheck(dim >= 0 && dim < numDims, 6, "dim not in range");
+
+ long sliceSize = THCTensor_(size)(state, input, dim);
+ THArgCheck(k > 0 && k <= sliceSize, 5, "k not in range for dimension");
+
+ // Build the output size, which is the dim being selected set to
+ // size k
+ THLongStorage* topKSize = THCTensor_(newSizeOf)(state, input);
+ THLongStorage_set(topKSize, dim, k);
+ THCTensor_(resize)(state, topK, topKSize, NULL);
+ THCudaLongTensor_resize(state, indices, topKSize, NULL);
+ THLongStorage_free(topKSize);
+
+#define RUN_K(INDEX_T, DIM, DIR) \
+ gatherTopK<real, INDEX_T, DIM, DIR> \
+ <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+ inputInfo, \
+ sliceSize, \
+ k, \
+ inputSlices, \
+ /* The actual dimension that the k-selection is running in */ \
+ /* may have changed from collapseDims() */ \
+ inputInfo.strides[collapseInputDim], \
+ topKInfo, \
+ topKSlices, \
+ topKInfo.strides[collapseTopKDim], \
+ indicesInfo, \
+ indicesInfo.strides[collapseIndicesDim])
+
+#define RUN_DIR(INDEX_T, DIM) \
+ if (dir) { \
+ RUN_K(INDEX_T, DIM, true); \
+ } else { \
+ RUN_K(INDEX_T, DIM, false); \
+ }
+
+#define RUN_DIM(INDEX_T) \
+ if (allDims == 1) { \
+ RUN_DIR(INDEX_T, 1); \
+ } else if (allDims == 2) { \
+ RUN_DIR(INDEX_T, 2); \
+ } else if (allDims == 3) { \
+ RUN_DIR(INDEX_T, 3); \
+ } else { \
+ RUN_DIR(INDEX_T, -1); \
+ }
+
+#define RUN_T(INDEX_T) \
+ TensorInfo<real, INDEX_T> inputInfo = \
+ getTensorInfo<THCTensor, INDEX_T>(state, input); \
+ TensorInfo<real, INDEX_T> topKInfo = \
+ getTensorInfo<THCTensor, INDEX_T>(state, topK); \
+ TensorInfo<long, INDEX_T> indicesInfo = \
+ getTensorInfo<THCudaLongTensor, INDEX_T>(state, indices); \
+ \
+ /* We use these structures solely to find the offset to */ \
+ /* each slice we are operating on */ \
+ inputInfo.sizes[dim] = 1; \
+ topKInfo.sizes[dim] = 1; \
+ indicesInfo.sizes[dim] = 1; \
+ \
+ /* Collapse all other dims */ \
+ int collapseInputDim = inputInfo.collapseDims(dim); \
+ int collapseTopKDim = topKInfo.collapseDims(dim); \
+ int collapseIndicesDim = indicesInfo.collapseDims(dim); \
+ \
+ long inputSlices = 1; \
+ long topKSlices = 1; \
+ for (int i = 0; i < numDims; ++i) { \
+ inputSlices *= inputInfo.sizes[i]; \
+ topKSlices *= topKInfo.sizes[i]; \
+ } \
+ \
+ dim3 grid; \
+ if (!THC_getGridFromTiles(inputSlices, grid)) { \
+ THError("Slice to sort is too large"); \
+ } \
+ \
+ dim3 block(std::min(THCRoundUp(sliceSize, 32L), 1024L)); \
+ \
+ /* This is used as a template parameter to calculate indices. */ \
+ /* We only specialize it if all collapsed dim sizes are the */ \
+ /* same; otherwise, we use -1 which is the specialization */ \
+ /* parameter for arbitrary dimensions */ \
+ int allDims = inputInfo.dims; \
+ if (topKInfo.dims != allDims || indicesInfo.dims != allDims) { \
+ allDims = -1; \
+ } \
+ \
+ RUN_DIM(INDEX_T);
+
+ // Based on required index size, run the algorithm with the
+ // appropriate index type
+ if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, input) &&
+ TensorUtils<THCTensor>::canUse32BitIndexMath(state, topK) &&
+ TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, indices)) {
+ RUN_T(unsigned int);
+ } else {
+ RUN_T(unsigned long);
+ }
+#undef RUN_T
+#undef RUN_DIM
+#undef RUN_DIR
+#undef RUN_K
+
+ // Sort the results if the user wants them sorted, since our
+ // selection routine does not ensure sorting
+ if (sorted) {
+ // FIXME: the k/v inplace sort along slice only works for size <=
+ // 2048 at the moment
+ if (sliceSize <= 2048) {
+ // This avoids any memory allocations and performs all sorting
+ // work inplace along the slice
+ THCTensor_(sortKeyValueInplace)(state, topK, indices, dim, dir);
+ } else {
+ // Depend upon the backup sort that returns indices, which we
+ // can use in conjunction with gather to produce the original
+ // indices.
+ // This is not the most efficient implementation, especially since
+ // there are memory allocations performed here. If the user desires
+ // greater performance, they should torch.gather() the results
+ // themselves using the reported indices, providing previously
+ // allocated tensors to receive the results.
+ THCTensor* sortedTopK = THCTensor_(new)(state);
+ THCudaLongTensor* sortedIndices = THCudaLongTensor_new(state);
+ THCTensor_(sort)(state, sortedTopK, sortedIndices, topK, dim, dir);
+
+ THCudaLongTensor* sortedTopKIndices = THCudaLongTensor_new(state);
+
+ THCudaLongTensor_resizeAs(state, sortedTopKIndices, indices);
+ THCudaLongTensor_gather(state, sortedTopKIndices, indices, dim, sortedIndices);
+
+ THCTensor_(freeCopyTo)(state, sortedTopK, topK);
+ THCudaLongTensor_freeCopyTo(state, sortedTopKIndices, indices);
+ THCudaLongTensor_free(state, sortedIndices);
+ }
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+#endif // THC_GENERIC_FILE
diff --git a/generic/THCTensorTopK.h b/generic/THCTensorTopK.h
new file mode 100644
index 0000000..2c281b5
--- /dev/null
+++ b/generic/THCTensorTopK.h
@@ -0,0 +1,13 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorTopK.h"
+#else
+
+/* Returns the k smallest (or largest) elements along dimension `dim`, */
+/* depending on `dir` */
+THC_API void THCTensor_(topk)(THCState* state,
+ THCTensor* topK,
+ THCudaLongTensor* indices,
+ THCTensor* input,
+ long k, int dim, int dir, int sorted);
+
+#endif // THC_GENERIC_FILE
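
As a semantic reference for the declaration above (illustrative host code, not part of the patch): dir selects the k largest (true) or k smallest (false) values along dim, sorted additionally orders the k results, and indices are reported 1-based to match TH_INDEX_BASE for Lua callers. Per slice, the behavior is:

#include <algorithm>
#include <numeric>
#include <vector>

// Reference top-k over one slice; the CUDA kernel computes the same thing
// per slice using radix selection instead of partial sorting.
void topkReference(const std::vector<float>& in, int k, bool dir, bool sorted,
                   std::vector<float>& values, std::vector<long>& indices) {
  std::vector<long> order(in.size());
  std::iota(order.begin(), order.end(), 0L);
  auto better = [&](long a, long b) {
    return dir ? in[a] > in[b] : in[a] < in[b];
  };
  // Partition so the first k positions hold the top-k (their order is
  // unspecified, just as the unsorted kernel output is).
  std::nth_element(order.begin(), order.begin() + (k - 1), order.end(), better);
  if (sorted) {
    std::sort(order.begin(), order.begin() + k, better);
  }
  values.resize(k);
  indices.resize(k);
  for (int i = 0; i < k; ++i) {
    values[i] = in[order[i]];
    indices[i] = order[i] + 1;  // 1-based, mirroring TH_INDEX_BASE
  }
}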