Generic TopK implementation (#744)

* move TopK to generic

* partial genericization of kernel code

* introduce TopKTypeConfig, specialize radix type and conversion for floats

* implement topk for byte tensor

* implement for char tensor

* implement for int tensor, extend test to check indices as well

* works for longs too

* make bitfield set/get a struct, add support for 64-bit types

* extend to double tensor

* implement for half tensor

* asserts; test fix
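
For reviewers: the heart of the genericization is `TopKTypeConfig<T>`, which maps each tensor type to an unsigned "radix type" whose integer ordering matches the ordering of the original values, so a single radix-selection kernel works for all types. A minimal host-side sketch of the float case (a sketch only: `memcpy` stands in for the device intrinsics `__float_as_int`/`__int_as_float`; names otherwise mirror the diff):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Host-side stand-in for TopKTypeConfig<float>::convert/deconvert.
static uint32_t convertFloat(float v) {
  uint32_t x;
  std::memcpy(&x, &v, sizeof(x));  // bit-cast, like __float_as_int
  // Negatives: flip all bits (reverses their order); positives: set the
  // sign bit (moves them above all negatives). Ordering is preserved.
  uint32_t mask = (x & 0x80000000u) ? 0xffffffffu : 0x80000000u;
  return x ^ mask;
}

static float deconvertFloat(uint32_t v) {
  uint32_t mask = (v & 0x80000000u) ? 0x80000000u : 0xffffffffu;
  v ^= mask;
  float f;
  std::memcpy(&f, &v, sizeof(f));
  return f;
}

int main() {
  // f1 < f2 implies convert(f1) < convert(f2) ...
  assert(convertFloat(-2.0f) < convertFloat(-1.0f));
  assert(convertFloat(-1.0f) < convertFloat(0.0f));
  assert(convertFloat(0.0f) < convertFloat(1.0f));
  // ... and the mapping round-trips exactly.
  assert(deconvertFloat(convertFloat(3.25f)) == 3.25f);
  return 0;
}
```

The integer specializations just add a bias (e.g. `128u + v` for char), and the 64-bit types (long, double) only differ in using `unsigned long long int` as the radix type, which is why `Bitfield` below grows a 64-bit specialization.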
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 935098a..1ea6039 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -258,7 +258,6 @@
           THCTensorRandom.h
           THCTensorMath.h
           THCTensorConv.h
-          THCTensorTopK.h
           THCApply.cuh
           THCReduce.cuh
           THCReduceAll.cuh
@@ -295,6 +294,7 @@
           THCTensorMathMagma.cuh
           THCThrustAllocator.cuh
           THCTensorMode.cuh
+          THCTensorTopK.cuh
           DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC")
 
 INSTALL(FILES
@@ -341,4 +341,6 @@
           generic/THCTensorRandom.cu
           generic/THCTensorMode.h
           generic/THCTensorMode.cu
+          generic/THCTensorTopK.h
+          generic/THCTensorTopK.cu
           DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
diff --git a/THC.h b/THC.h
index e3840dc..90a3a53 100644
--- a/THC.h
+++ b/THC.h
@@ -15,6 +15,5 @@
 #include "THCTensorRandom.h"
 #include "THCTensorMath.h"
 #include "THCTensorConv.h"
-#include "THCTensorTopK.h"
 
 #endif
diff --git a/THCAsmUtils.cuh b/THCAsmUtils.cuh
index 7015d20..f0dc90b 100644
--- a/THCAsmUtils.cuh
+++ b/THCAsmUtils.cuh
@@ -3,20 +3,44 @@
 
 // Collection of direct PTX functions
 
-__device__ __forceinline__
-unsigned int getBitfield(unsigned int val, int pos, int len) {
-  unsigned int ret;
-  asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
-  return ret;
-}
+template <typename T>
+struct Bitfield {};
 
-__device__ __forceinline__
-unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
-  unsigned int ret;
-  asm("bfi.b32 %0, %1, %2, %3, %4;" :
-      "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
-  return ret;
-}
+template <>
+struct Bitfield<unsigned int> {
+  static __device__ __forceinline__
+  unsigned int getBitfield(unsigned int val, int pos, int len) {
+    unsigned int ret;
+    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+
+  static __device__ __forceinline__
+  unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
+    unsigned int ret;
+    asm("bfi.b32 %0, %1, %2, %3, %4;" :
+        "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+};
+
+template <>
+struct Bitfield<unsigned long long int> {
+  static __device__ __forceinline__
+  unsigned long long int getBitfield(unsigned long long int val, int pos, int len) {
+    unsigned long long int ret;
+    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+
+  static __device__ __forceinline__
+  unsigned long long int setBitfield(unsigned long long int val, unsigned long long int toInsert, int pos, int len) {
+    unsigned long long int ret;
+    asm("bfi.b64 %0, %1, %2, %3, %4;" :
+        "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+};
 
 __device__ __forceinline__ int getLaneId() {
   int laneId;
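
The `Bitfield` struct above wraps the PTX bit-field extract (`bfe`) and insert (`bfi`) instructions; making it a template with a 64-bit specialization is what lets radix selection run over `unsigned long long int` keys for long/double tensors. For reference, a portable sketch of the same semantics using plain shifts and masks (assumes `len > 0` and `pos + len` within the word width, which is how the radix-select code calls it):

```cpp
#include <cassert>
#include <cstdint>

// Reference model of Bitfield<T>::getBitfield / setBitfield.
template <typename T>
T getBitfield(T val, int pos, int len) {
  T mask = (static_cast<T>(1) << len) - 1;
  return (val >> pos) & mask;  // the `len` bits starting at bit `pos`
}

template <typename T>
T setBitfield(T val, T toInsert, int pos, int len) {
  T mask = ((static_cast<T>(1) << len) - 1) << pos;
  return (val & ~mask) | ((toInsert << pos) & mask);
}

int main() {
  // Extract the 2-bit digit at bit 4 of 0x70 (0b01110000): 0b11
  assert(getBitfield<uint32_t>(0x70u, 4, 2) == 0x3u);
  // Insert digit 0b10 at bit 2 of zero: 0b1000
  assert(setBitfield<uint32_t>(0x0u, 0x2u, 2, 2) == 0x8u);
  // Identical code works on 64-bit words -- the point of the template
  assert(getBitfield<uint64_t>(0xF000000000ull, 36, 4) == 0xFull);
  return 0;
}
```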
diff --git a/THCTensorMath.h b/THCTensorMath.h
index 8189f4e..b888672 100644
--- a/THCTensorMath.h
+++ b/THCTensorMath.h
@@ -46,6 +46,9 @@
 #include "generic/THCTensorMode.h"
 #include "THCGenerateAllTypes.h"
 
+#include "generic/THCTensorTopK.h"
+#include "THCGenerateAllTypes.h"
+
 THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
 THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
 
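This include pair is the standard THC stamping pattern: the first inclusion of the generic header only records its own path in `THC_GENERIC_FILE`, and `THCGenerateAllTypes.h` then re-includes it once per type with `real` and the `THCTensor_(NAME)` macro rebound. A stripped-down, self-including sketch of the mechanism (toy names, two types instead of eight; the file must actually be saved as `add_decls.h` for the self-include to resolve):

```c
/* add_decls.h -- toy model of generic/THCTensorTopK.h plus the
   THCGenerateAllTypes.h driver collapsed into one file. */
#ifndef GENERIC_FILE
#define GENERIC_FILE "add_decls.h"

#define real float
#define Tensor_(NAME) THFloatTensor_##NAME
#include GENERIC_FILE            /* declares THFloatTensor_add */
#undef Tensor_
#undef real

#define real double
#define Tensor_(NAME) THDoubleTensor_##NAME
#include GENERIC_FILE            /* declares THDoubleTensor_add */
#undef Tensor_
#undef real

#undef GENERIC_FILE
#else

/* The "generic" body: compiled once per type with real/Tensor_ rebound. */
void Tensor_(add)(real* dst, const real* a, const real* b, int n);

#endif
```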
diff --git a/THCTensorTopK.cu b/THCTensorTopK.cu
index e2b817b..325d560 100644
--- a/THCTensorTopK.cu
+++ b/THCTensorTopK.cu
@@ -12,525 +12,8 @@
 #include <thrust/system/cuda/execution_policy.h>
 #endif
 
-// Converts a float to an integer representation with the same
-// sorting; i.e., for floats f1, f2:
-// if f1 < f2 then convert(f1) < convert(f2)
-// We use this to enable radix selection of floating-point values.
-// This also gives a relative order for NaNs, but that's ok, as they
-// will all be adjacent
-struct FloatToSortedInt {
-  inline __host__ __device__ FloatToSortedInt() {}
+#include "THCTensorTopK.cuh"
 
-  inline __device__ unsigned int convert(float v) const {
-    unsigned int x = __float_as_int(v);
-    unsigned int mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
-
-    return (x ^ mask);
-  }
-
-  inline __device__ float deconvert(unsigned int v) const {
-    unsigned int mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
-
-    return __int_as_float(v ^ mask);
-  }
-};
-
-// This function counts the distribution of all input values in a
-// slice we are selecting by radix digit at `radixDigitPos`, but only
-// those that pass the filter `((v & desiredMask) == desired)`.
-// This produces and broadcasts the seen counts for a single block only.
-// `smem` must have at least `RadixSize` elements.
-template <typename DataType, typename BitDataType,
-          typename IndexType, typename CountType,
-          typename RadixConverter, int RadixSize, int RadixBits>
-__device__ void countRadixUsingMask(const RadixConverter& conv,
-                                    CountType counts[RadixSize],
-                                    CountType* smem,
-                                    BitDataType desired,
-                                    BitDataType desiredMask,
-                                    int radixDigitPos,
-                                    IndexType sliceSize,
-                                    IndexType withinSliceStride,
-                                    DataType* data) {
-  // Clear out per-thread counts from a previous round
-#pragma unroll
-  for (int i = 0; i < RadixSize; ++i) {
-    counts[i] = 0;
-  }
-
-  if (threadIdx.x < RadixSize) {
-    smem[threadIdx.x] = 0;
-  }
-  __syncthreads();
-
-  // Scan over all the data. Upon a read, the warp will accumulate
-  // counts per each digit in the radix using warp voting.
-  for (IndexType i = threadIdx.x; i < sliceSize; i += blockDim.x) {
-    BitDataType val = conv.convert(doLdg(&data[i * withinSliceStride]));
-
-    bool hasVal = ((val & desiredMask) == desired);
-    unsigned int digitInRadix = getBitfield(val, radixDigitPos, RadixBits);
-
-#pragma unroll
-    for (unsigned int j = 0; j < RadixSize; ++j) {
-      bool vote = hasVal && (digitInRadix == j);
-      counts[j] += __popc(__ballot(vote));
-    }
-  }
-
-  // Now, for each warp, sum values
-  if (getLaneId() == 0) {
-#pragma unroll
-    for (unsigned int i = 0; i < RadixSize; ++i) {
-      atomicAdd(&smem[i], counts[i]);
-    }
-  }
-
-  __syncthreads();
-
-  // For each thread, read in the total counts
-#pragma unroll
-  for (unsigned int i = 0; i < RadixSize; ++i) {
-    counts[i] = smem[i];
-  }
-
-  __syncthreads();
-}
-
-// Over what radix we are selecting values
-#define RADIX_BITS 2 // digits are base-(2 ^ RADIX_BITS)
-#define RADIX_SIZE 4 // 2 ^ RADIX_BITS
-#define RADIX_MASK (RADIX_SIZE - 1)
-
-// This finds the unique value `v` that matches the pattern
-// ((v & desired) == desiredMask) in our sorted int format
-template <typename DataType, typename IndexType, typename RadixConverter>
-__device__ float findPattern(const RadixConverter& conv,
-                             DataType* smem,
-                             DataType* data,
-                             IndexType sliceSize,
-                             IndexType withinSliceStride,
-                             unsigned int desired,
-                             unsigned int desiredMask) {
-  if (threadIdx.x < 32) {
-    smem[threadIdx.x] = (DataType) 0;
-  }
-  __syncthreads();
-
-  // All threads participate in the loop, in order to sync on the flag
-  IndexType numIterations = THCRoundUp(sliceSize, (IndexType) blockDim.x);
-  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
-    bool inRange = (i < sliceSize);
-    DataType v = inRange ? doLdg(&data[i * withinSliceStride]) : (DataType) 0;
-
-    if (inRange && ((conv.convert(v) & desiredMask) == desired)) {
-      // There should not be conflicts if we are using findPattern,
-      // since the result is unique
-      smem[0] = (DataType) 1;
-      smem[1] = v; // can't use val as the flag, since it could be 0
-    }
-
-    __syncthreads();
-
-    DataType found = smem[0];
-    DataType val = smem[1];
-
-    __syncthreads();
-
-    // Check to see if a thread found the value
-    if (found != (DataType) 0) {
-      // all threads return this value
-      return val;
-    }
-  }
-
-  // should not get here
-  assert(false);
-  return (DataType) 0;
-}
-
-// Returns the top-Kth element found in the data using radix selection
-template <typename DataType, typename BitDataType, typename IndexType,
-          typename RadixConverter, bool Order>
-__device__ void radixSelect(const RadixConverter& conv,
-                            DataType* data,
-                            IndexType k,
-                            IndexType sliceSize,
-                            IndexType withinSliceStride,
-                            int* smem,
-                            DataType* topK) {
-  // Per-thread buckets into which we accumulate digit counts in our
-  // radix
-  int counts[RADIX_SIZE];
-
-  // We only consider elements x such that (x & desiredMask) == desired
-  // Initially, we consider all elements of the array, so the above
-  // statement is true regardless of input.
-  unsigned int desired = 0;
-  unsigned int desiredMask = 0;
-
-  // We are looking for the top kToFind-th element when iterating over
-  // digits; this count gets reduced by elimination when counting
-  // successive digits
-  int kToFind = k;
-
-  // We start at the most significant digit in our radix, scanning
-  // through to the least significant digit
-#pragma unroll
-  for (int digitPos = sizeof(BitDataType) * 8 - RADIX_BITS;
-       digitPos >= 0;
-       digitPos -= RADIX_BITS) {
-
-    // Count radix distribution for the current position and reduce
-    // across all threads
-    countRadixUsingMask<DataType, BitDataType,
-                        IndexType, int, RadixConverter,
-                        RADIX_SIZE, RADIX_BITS>(
-                          conv, counts, smem,
-                          desired, desiredMask, digitPos,
-                          sliceSize, withinSliceStride, data);
-
-    // All threads participate in the comparisons below to know the
-    // final result
-
-#define CHECK_RADIX(i)                                                  \
-    int count = counts[i];                                              \
-                                                                        \
-    /* All threads have the same value in counts here, so all */        \
-    /* threads will return from the function. */                        \
-    if (count == 1 && kToFind == 1) {                                   \
-      /* There is a unique answer. */                                   \
-      desired = setBitfield(desired, i, digitPos, RADIX_BITS);          \
-      desiredMask =                                                     \
-        setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS);     \
-                                                                        \
-      /* The answer is now the unique element v such that: */           \
-      /* (v & desiredMask) == desired */                                \
-      /* However, we do not yet know what the actual element is. We */  \
-      /* need to perform a search through the data to find the */       \
-      /* element that matches this pattern. */                          \
-      *topK = findPattern<DataType, IndexType, RadixConverter>(         \
-        conv, (float*) smem, data, sliceSize,                           \
-        withinSliceStride, desired, desiredMask);                       \
-      return;                                                           \
-    }                                                                   \
-                                                                        \
-    if (count >= kToFind) {                                             \
-      desired = setBitfield(desired, i, digitPos, RADIX_BITS);          \
-      desiredMask =                                                     \
-        setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS);     \
-                                                                        \
-      /* The top-Kth element v must now be one such that: */            \
-      /* (v & desiredMask == desired) */                                \
-      /* but we haven't narrowed it down; we must check the next */     \
-      /* least-significant digit */                                     \
-      break;                                                            \
-    }                                                                   \
-                                                                        \
-    kToFind -= count                                                    \
-
-    if (Order) {
-      // Process in descending order
-#pragma unroll
-      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
-        CHECK_RADIX(i);
-      }
-    } else {
-      // Process in ascending order
-#pragma unroll
-      for (int i = 0; i < RADIX_SIZE; ++i) {
-        CHECK_RADIX(i);
-      }
-    }
-#undef CHECK_RADIX
-  } // end digitPos for
-
-  // There is no unique result, but there is a non-unique result
-  // matching `desired` exactly
-  *topK = conv.deconvert(desired);
-}
-
-template <typename IndexType, int Dim, bool Order>
-__global__ void gatherTopK(TensorInfo<float, IndexType> input,
-                           IndexType inputSliceSize,
-                           IndexType outputSliceSize, // aka `k`
-
-                           IndexType numInputSlices,
-                           IndexType inputWithinSliceStride,
-
-                           TensorInfo<float, IndexType> topK,
-                           IndexType numTopKSlices,
-                           IndexType topKWithinSliceStride,
-
-                           TensorInfo<long, IndexType> indices,
-                           IndexType indicesWithinSliceStride) {
-  // Indices are limited to integer fp precision, so counts can fit in
-  // int32, regardless of IndexType
-  __shared__ int smem[32]; // one per each warp, up to warp limit
-
-  IndexType slice = getLinearBlockId<IndexType>();
-  if (slice >= numInputSlices) {
-    return;
-  }
-
-  // Find the start offset for our slice
-  IndexType sliceStartIndex =
-    IndexToOffset<float, IndexType, Dim>::get(slice, input);
-  IndexType topKSliceStartIndex =
-    IndexToOffset<float, IndexType, Dim>::get(slice, topK);
-  IndexType indicesSliceStartIndex =
-    IndexToOffset<long, IndexType, Dim>::get(slice, indices);
-
-  float* inputSliceStart = &input.data[sliceStartIndex];
-  float* topKSliceStart = &topK.data[topKSliceStartIndex];
-  long* indicesSliceStart = &indices.data[indicesSliceStartIndex];
-
-  // Find the k-th highest element in our input
-  float topKValue = -1.0f;
-  radixSelect<float, unsigned int, IndexType, FloatToSortedInt, Order>(
-    FloatToSortedInt(),
-    inputSliceStart, outputSliceSize,
-    inputSliceSize, inputWithinSliceStride,
-    smem, &topKValue);
-
-  // Every value that is strictly less/greater than `pattern`
-  // (depending on sort dir) in sorted int format is in the top-K.
-  // The top-K value itself might not be unique.
-  //
-  // Since there are a variable number of elements that we see that
-  // are within the top-k, we don't know at what index to write out
-  // the resulting values.
-  // In order to get this, we perform an exclusive prefix sum of
-  // `hasTopK`. This will return the resulting index into which we
-  // need to write the result, if a thread has a result.
-
-  // All threads need to participate in the loop and the prefix sum,
-  // but not necessarily in the load; hence loop bounds being rounded
-  // up to a multiple of the block dim.
-  IndexType numIterations = THCRoundUp(inputSliceSize, (IndexType) blockDim.x);
-  IndexType writeIndexStart = 0;
-
-  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
-    bool inRange = (i < inputSliceSize);
-    float v =
-      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : 0.0f;
-    bool hasTopK;
-    if (Order) {
-      hasTopK = inRange && (v > topKValue);
-    } else {
-      hasTopK = inRange && (v < topKValue);
-    }
-
-    int index;
-    int carry;
-    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
-
-    if (hasTopK) {
-      int writeIndex = writeIndexStart + index;
-      assert(writeIndex < outputSliceSize);
-
-      IndexType topKOffset = writeIndex * topKWithinSliceStride;
-      IndexType indexOffset = writeIndex * indicesWithinSliceStride;
-
-      topKSliceStart[topKOffset] = v;
-      indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
-    }
-
-    writeIndexStart += carry;
-  }
-
-  // We need to fill in the rest with actual == top-K values.
-  // The number that we need is outputSliceSize -
-  // writeIndexStart. There might be more than that number available,
-  // in which case we have to choose the first seen set. We do this
-  // via a prefix sum to calculate indices for writing results.
-  assert(outputSliceSize >= writeIndexStart);
-  IndexType topKRemaining = (outputSliceSize - writeIndexStart);
-
-  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
-    bool inRange = (i < inputSliceSize);
-    float v =
-      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : 0.0f;
-    bool hasTopK = inRange && (v == topKValue);
-
-    int index;
-    int carry;
-    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
-
-    if (hasTopK && index < topKRemaining) {
-      int writeIndex = writeIndexStart + index;
-      assert(writeIndex < outputSliceSize);
-
-      IndexType topKOffset = writeIndex * topKWithinSliceStride;
-      IndexType indexOffset = writeIndex * indicesWithinSliceStride;
-
-      topKSliceStart[topKOffset] = v;
-      indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
-    }
-
-    if (carry >= topKRemaining) {
-      break;
-    }
-
-    topKRemaining -= carry;
-    writeIndexStart += carry;
-  }
-}
-
-#undef RADIX_BITS
-#undef RADIX_SIZE
-#undef RADIX_MASK
-
-THC_API void THCudaTensor_topk(THCState* state,
-                               THCudaTensor *topK,
-                               THCudaLongTensor *indices,
-                               THCudaTensor *input,
-                               long k, int dim, int dir, int sorted) {
-  THAssert(topK != NULL && indices != NULL && input != NULL);
-  THCAssertSameGPU(THCudaTensor_checkGPU(state, 3, topK, indices, input));
-  THCCheckTensorDims(state, topK, 2);
-  long dims = THCudaLongTensor_nDimension(state, indices);
-  THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
-  THCCheckTensorDims(state, input, 2);
-
-  int numDims = THCudaTensor_nDimension(state, input);
-  THArgCheck(dim >= 0 && dim < numDims, 3, "dim not in range");
-
-  long sliceSize = THCudaTensor_size(state, input, dim);
-  THArgCheck(k > 0 && k <= sliceSize, 2, "k not in range for dimension");
-
-  // Build the output size, which is the dim being selected set to
-  // size k
-  THLongStorage* topKSize = THCudaTensor_newSizeOf(state, input);
-  THLongStorage_set(topKSize, dim, k);
-  THCudaTensor_resize(state, topK, topKSize, NULL);
-  THCudaLongTensor_resize(state, indices, topKSize, NULL);
-  THLongStorage_free(topKSize);
-
-#define RUN_K(INDEX_T, DIM, DIR)                                        \
-  gatherTopK<INDEX_T, DIM, DIR>                                         \
-    <<<grid, block, 0, THCState_getCurrentStream(state)>>>(             \
-      inputInfo,                                                        \
-      sliceSize,                                                        \
-      k,                                                                \
-      inputSlices,                                                      \
-      /* The actual dimension that the k-selection is running in */     \
-      /* may have changed from collapseDims() */                        \
-      inputInfo.strides[collapseInputDim],                              \
-      topKInfo,                                                         \
-      topKSlices,                                                       \
-      topKInfo.strides[collapseTopKDim],                                \
-      indicesInfo,                                                      \
-      indicesInfo.strides[collapseIndicesDim])
-
-#define RUN_DIR(INDEX_T, DIM)                   \
-  if (dir) {                                    \
-    RUN_K(INDEX_T, DIM, true);                  \
-  } else {                                      \
-    RUN_K(INDEX_T, DIM, false);                 \
-  }
-
-#define RUN_DIM(INDEX_T)                        \
-  if (allDims == 1) {                           \
-    RUN_DIR(INDEX_T, 1);                        \
-  } else if (allDims == 2) {                    \
-    RUN_DIR(INDEX_T, 2);                        \
-  } else if (allDims == 3) {                    \
-    RUN_DIR(INDEX_T, 3);                        \
-  } else {                                      \
-    RUN_DIR(INDEX_T, -1);                       \
-  }
-
-#define RUN_T(INDEX_T)                                                  \
-  TensorInfo<float, INDEX_T> inputInfo =                                \
-    getTensorInfo<THCudaTensor, INDEX_T>(state, input);                 \
-  TensorInfo<float, INDEX_T> topKInfo =                                 \
-    getTensorInfo<THCudaTensor, INDEX_T>(state, topK);                  \
-  TensorInfo<long, INDEX_T> indicesInfo =                               \
-    getTensorInfo<THCudaLongTensor, INDEX_T>(state, indices);           \
-                                                                        \
-  /* We use these structures solely to find the offset to */            \
-  /* each slice we are operating on */                                  \
-  inputInfo.sizes[dim] = 1;                                             \
-  topKInfo.sizes[dim] = 1;                                              \
-  indicesInfo.sizes[dim] = 1;                                           \
-                                                                        \
-  /* Collapse all other dims */                                         \
-  int collapseInputDim = inputInfo.collapseDims(dim);                   \
-  int collapseTopKDim = topKInfo.collapseDims(dim);                     \
-  int collapseIndicesDim = indicesInfo.collapseDims(dim);               \
-                                                                        \
-  long inputSlices = 1;                                                 \
-  long topKSlices = 1;                                                  \
-  for (int i = 0; i < numDims; ++i) {                                   \
-    inputSlices *= inputInfo.sizes[i];                                  \
-    topKSlices *= topKInfo.sizes[i];                                    \
-  }                                                                     \
-                                                                        \
-  dim3 grid;                                                            \
-  if (!THC_getGridFromTiles(inputSlices, grid)) {                       \
-    THError("Slice to sort is too large");                              \
-  }                                                                     \
-                                                                        \
-  dim3 block(std::min(THCRoundUp(sliceSize, 32L), 1024L));              \
-                                                                        \
-  /* This is used as a template parameter to calculate indices. */      \
-  /* We only specialize it if all collapsed dim sizes are the */        \
-  /* same; otherwise, we use -1 which is the specialization */          \
-  /* parameter for arbitrary dimensions */                              \
-  int allDims = inputInfo.dims;                                         \
-  if (topKInfo.dims != allDims || indicesInfo.dims != allDims) {        \
-    allDims = -1;                                                       \
-  }                                                                     \
-                                                                        \
-  RUN_DIM(INDEX_T);
-
-  // Based on required index size, run the algorithm with the
-  // appropriate index type
-  if (TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input) &&
-      TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, topK) &&
-      TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, indices)) {
-    RUN_T(unsigned int);
-  } else {
-    RUN_T(unsigned long);
-  }
-#undef RUN_T
-#undef RUN_DIM
-#undef RUN_DIR
-#undef RUN_K
-
-  // Sort the results if the user wants them sorted, since our
-  // selection routine does not ensure sorting
-  if (sorted) {
-    // FIXME: the k/v inplace sort along slice only works for size <=
-    // 2048 at the moment
-    if (sliceSize <= 2048) {
-      // This avoids any memory allocations and performs all sorting
-      // work inplace along the slice
-      THCudaTensor_sortKeyValueInplace(state, topK, indices, dim, dir);
-    } else {
-      // Depend upon the backup sort that returns indices, which we
-      // can use in conjunction with gather to produce the original
-      // indices.
-      // This is not the most efficient implementation, especially since
-      // there are memory allocations performed here. If the user desires
-      // greater performance, they should torch.gather() the results
-      // themselves using the reported indices, providing previously
-      // allocated tensors to receive the results.
-      THCudaTensor* sortedTopK = THCudaTensor_new(state);
-      THCudaLongTensor* sortedIndices = THCudaLongTensor_new(state);
-      THCudaTensor_sort(state, sortedTopK, sortedIndices, topK, dim, dir);
-
-      THCudaLongTensor* sortedTopKIndices = THCudaLongTensor_new(state);
-
-      THCudaLongTensor_resizeAs(state, sortedTopKIndices, indices);
-      THCudaLongTensor_gather(state, sortedTopKIndices, indices, dim, sortedIndices);
-
-      THCudaTensor_freeCopyTo(state, sortedTopK, topK);
-      THCudaLongTensor_freeCopyTo(state, sortedTopKIndices, indices);
-      THCudaLongTensor_free(state, sortedIndices);
-    }
-  }
+#include "generic/THCTensorTopK.cu"
+#include "THCGenerateAllTypes.h"
 
-  THCudaCheck(cudaGetLastError());
-}
diff --git a/THCTensorTopK.cuh b/THCTensorTopK.cuh
new file mode 100644
index 0000000..32041e3
--- /dev/null
+++ b/THCTensorTopK.cuh
@@ -0,0 +1,472 @@
+#ifndef THC_TENSOR_TOPK_CUH
+#define THC_TENSOR_TOPK_CUH
+
+template <typename T>
+struct TopKTypeConfig {};
+
+template <>
+struct TopKTypeConfig<float> {
+  typedef unsigned int RadixType;
+
+  // Converts a float to an integer representation with the same
+  // sorting; i.e., for floats f1, f2:
+  // if f1 < f2 then convert(f1) < convert(f2)
+  // We use this to enable radix selection of floating-point values.
+  // This also gives a relative order for NaNs, but that's ok, as they
+  // will all be adjacent
+  static inline __device__ RadixType convert(float v) {
+    RadixType x = __float_as_int(v);
+    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
+
+    return (x ^ mask);
+  }
+
+  static inline __device__ float deconvert(RadixType v) {
+    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
+
+    return __int_as_float(v ^ mask);
+  }
+};
+
+template <>
+struct TopKTypeConfig<unsigned char> {
+  typedef unsigned int RadixType;
+
+  static inline __device__ RadixType convert(unsigned char v) {
+    return v;
+  }
+
+  static inline __device__ unsigned char deconvert(RadixType v) {
+    return v;
+  }
+};
+
+template <>
+struct TopKTypeConfig<char> {
+  typedef unsigned int RadixType;
+
+  static inline __device__ RadixType convert(char v) {
+    return 128u + v;
+  }
+
+  static inline __device__ char deconvert(RadixType v) {
+    return v - 128;
+  }
+};
+
+template <>
+struct TopKTypeConfig<short> {
+  typedef unsigned int RadixType;
+
+  static inline __device__ RadixType convert(short v) {
+    assert(sizeof(short) == 2);
+    return 32768u + v;
+  }
+
+  static inline __device__ short deconvert(RadixType v) {
+    return v - 32768;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int> {
+  typedef unsigned int RadixType;
+
+  static inline __device__ RadixType convert(int v) {
+    assert(sizeof(int) == 4);
+    return 2147483648u + v;
+  }
+
+  static inline __device__ int deconvert(RadixType v) {
+    return v - 2147483648u;
+  }
+};
+
+template <>
+struct TopKTypeConfig<long> {
+  typedef unsigned long long int RadixType;
+
+  static inline __device__ RadixType convert(long v) {
+    assert(sizeof(long) == 8);
+    return 9223372036854775808ull + v;
+  }
+
+  static inline __device__ long deconvert(RadixType v) {
+    return v - 9223372036854775808ull;
+  }
+};
+
+template <>
+struct TopKTypeConfig<double> {
+  typedef unsigned long long int RadixType;
+
+  static inline __device__ RadixType convert(double v) {
+    RadixType x = __double_as_longlong(v);
+    RadixType mask = -((x >> 63)) | 0x8000000000000000;
+    return (x ^ mask);
+  }
+
+  static inline __device__ double deconvert(RadixType v) {
+    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
+    return __longlong_as_double(v ^ mask);
+  }
+};
+
+template <>
+struct TopKTypeConfig<half> {
+  typedef unsigned int RadixType;
+
+  static inline __device__ RadixType convert(half v) {
+    RadixType x = __half_as_ushort(v);
+    RadixType mask = -((x >> 15)) | 0x8000;
+    return (x ^ mask);
+  }
+
+  static inline __device__ half deconvert(RadixType v) {
+    RadixType mask = ((v >> 15) - 1) | 0x8000;
+    return __ushort_as_half(v ^ mask);
+  }
+};
+
+// This function counts the distribution of all input values in a
+// slice we are selecting by radix digit at `radixDigitPos`, but only
+// those that pass the filter `((v & desiredMask) == desired)`.
+// This produces and broadcasts the seen counts for a single block only.
+// `smem` must have at least `RadixSize` elements.
+template <typename DataType, typename BitDataType,
+          typename IndexType, typename CountType,
+          int RadixSize, int RadixBits>
+__device__ void countRadixUsingMask(CountType counts[RadixSize],
+                                    CountType* smem,
+                                    BitDataType desired,
+                                    BitDataType desiredMask,
+                                    int radixDigitPos,
+                                    IndexType sliceSize,
+                                    IndexType withinSliceStride,
+                                    DataType* data) {
+  // Clear out per-thread counts from a previous round
+#pragma unroll
+  for (int i = 0; i < RadixSize; ++i) {
+    counts[i] = 0;
+  }
+
+  if (threadIdx.x < RadixSize) {
+    smem[threadIdx.x] = 0;
+  }
+  __syncthreads();
+
+  // Scan over all the data. Upon a read, the warp will accumulate
+  // counts per each digit in the radix using warp voting.
+  for (IndexType i = threadIdx.x; i < sliceSize; i += blockDim.x) {
+    BitDataType val = TopKTypeConfig<DataType>::convert(doLdg(&data[i * withinSliceStride]));
+
+    bool hasVal = ((val & desiredMask) == desired);
+    BitDataType digitInRadix = Bitfield<BitDataType>::getBitfield(val, radixDigitPos, RadixBits);
+
+#pragma unroll
+    for (unsigned int j = 0; j < RadixSize; ++j) {
+      bool vote = hasVal && (digitInRadix == j);
+      counts[j] += __popc(__ballot(vote));
+    }
+  }
+
+  // Now, for each warp, sum values
+  if (getLaneId() == 0) {
+#pragma unroll
+    for (unsigned int i = 0; i < RadixSize; ++i) {
+      atomicAdd(&smem[i], counts[i]);
+    }
+  }
+
+  __syncthreads();
+
+  // For each thread, read in the total counts
+#pragma unroll
+  for (unsigned int i = 0; i < RadixSize; ++i) {
+    counts[i] = smem[i];
+  }
+
+  __syncthreads();
+}
+
+// Over what radix we are selecting values
+#define RADIX_BITS 2 // digits are base-(2 ^ RADIX_BITS)
+#define RADIX_SIZE 4 // 2 ^ RADIX_BITS
+#define RADIX_MASK (RADIX_SIZE - 1)
+
+// This finds the unique value `v` that matches the pattern
+// ((v & desired) == desiredMask) in our sorted int format
+template <typename DataType, typename BitDataType, typename IndexType>
+__device__ DataType findPattern(DataType* smem,
+                                DataType* data,
+                                IndexType sliceSize,
+                                IndexType withinSliceStride,
+                                BitDataType desired,
+                                BitDataType desiredMask) {
+  if (threadIdx.x < 32) {
+    smem[threadIdx.x] = ScalarConvert<int, DataType>::to(0);
+  }
+  __syncthreads();
+
+  // All threads participate in the loop, in order to sync on the flag
+  IndexType numIterations = THCRoundUp(sliceSize, (IndexType) blockDim.x);
+  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
+    bool inRange = (i < sliceSize);
+    DataType v = inRange ? doLdg(&data[i * withinSliceStride]) : ScalarConvert<int, DataType>::to(0);
+
+    if (inRange && ((TopKTypeConfig<DataType>::convert(v) & desiredMask) == desired)) {
+      // There should not be conflicts if we are using findPattern,
+      // since the result is unique
+      smem[0] = ScalarConvert<int, DataType>::to(1);
+      smem[1] = v; // can't use val as the flag, since it could be 0
+    }
+
+    __syncthreads();
+
+    DataType found = smem[0];
+    DataType val = smem[1];
+
+    __syncthreads();
+
+    // Check to see if a thread found the value
+    if (THCNumerics<DataType>::ne(found, ScalarConvert<int, DataType>::to(0))) {
+      // all threads return this value
+      return val;
+    }
+  }
+
+  // should not get here
+  assert(false);
+  return ScalarConvert<int, DataType>::to(0);
+}
+
+// Returns the top-Kth element found in the data using radix selection
+template <typename DataType, typename BitDataType, typename IndexType, bool Order>
+__device__ void radixSelect(DataType* data,
+                            IndexType k,
+                            IndexType sliceSize,
+                            IndexType withinSliceStride,
+                            int* smem,
+                            DataType* topK) {
+  // Per-thread buckets into which we accumulate digit counts in our
+  // radix
+  int counts[RADIX_SIZE];
+
+  // We only consider elements x such that (x & desiredMask) == desired
+  // Initially, we consider all elements of the array, so the above
+  // statement is true regardless of input.
+  BitDataType desired = 0;
+  BitDataType desiredMask = 0;
+
+  // We are looking for the top kToFind-th element when iterating over
+  // digits; this count gets reduced by elimination when counting
+  // successive digits
+  int kToFind = k;
+
+  // We start at the most significant digit in our radix, scanning
+  // through to the least significant digit
+#pragma unroll
+  for (int digitPos = sizeof(DataType) * 8 - RADIX_BITS;
+       digitPos >= 0;
+       digitPos -= RADIX_BITS) {
+
+    // Count radix distribution for the current position and reduce
+    // across all threads
+    countRadixUsingMask<DataType, BitDataType,
+                        IndexType, int,
+                        RADIX_SIZE, RADIX_BITS>(
+                          counts, smem,
+                          desired, desiredMask, digitPos,
+                          sliceSize, withinSliceStride, data);
+
+    // All threads participate in the comparisons below to know the
+    // final result
+
+#define CHECK_RADIX(i)                                                  \
+    int count = counts[i];                                              \
+                                                                        \
+    /* All threads have the same value in counts here, so all */        \
+    /* threads will return from the function. */                        \
+    if (count == 1 && kToFind == 1) {                                   \
+      /* There is a unique answer. */                                   \
+      desired = Bitfield<BitDataType>::setBitfield(desired, i, digitPos, RADIX_BITS);          \
+      desiredMask =                                                     \
+        Bitfield<BitDataType>::setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS);     \
+                                                                        \
+      /* The answer is now the unique element v such that: */           \
+      /* (v & desiredMask) == desired */                                \
+      /* However, we do not yet know what the actual element is. We */  \
+      /* need to perform a search through the data to find the */       \
+      /* element that matches this pattern. */                          \
+      *topK = findPattern<DataType, BitDataType, IndexType>(                         \
+        (DataType*) smem, data, sliceSize,                              \
+        withinSliceStride, desired, desiredMask);                       \
+      return;                                                           \
+    }                                                                   \
+                                                                        \
+    if (count >= kToFind) {                                             \
+      desired = Bitfield<BitDataType>::setBitfield(desired, i, digitPos, RADIX_BITS);          \
+      desiredMask =                                                     \
+        Bitfield<BitDataType>::setBitfield(desiredMask, RADIX_MASK, digitPos, RADIX_BITS);     \
+                                                                        \
+      /* The top-Kth element v must now be one such that: */            \
+      /* (v & desiredMask == desired) */                                \
+      /* but we haven't narrowed it down; we must check the next */     \
+      /* least-significant digit */                                     \
+      break;                                                            \
+    }                                                                   \
+                                                                        \
+    kToFind -= count                                                    \
+
+    if (Order) {
+      // Process in descending order
+#pragma unroll
+      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
+        CHECK_RADIX(i);
+      }
+    } else {
+      // Process in ascending order
+#pragma unroll
+      for (int i = 0; i < RADIX_SIZE; ++i) {
+        CHECK_RADIX(i);
+      }
+    }
+#undef CHECK_RADIX
+  } // end digitPos for
+
+  // There is no unique result, but there is a non-unique result
+  // matching `desired` exactly
+  *topK = TopKTypeConfig<DataType>::deconvert(desired);
+}
+
+template <typename T, typename IndexType, int Dim, bool Order>
+__global__ void gatherTopK(TensorInfo<T, IndexType> input,
+                           IndexType inputSliceSize,
+                           IndexType outputSliceSize, // aka `k`
+
+                           IndexType numInputSlices,
+                           IndexType inputWithinSliceStride,
+
+                           TensorInfo<T, IndexType> topK,
+                           IndexType numTopKSlices,
+                           IndexType topKWithinSliceStride,
+
+                           TensorInfo<long, IndexType> indices,
+                           IndexType indicesWithinSliceStride) {
+  // Indices are limited to integer fp precision, so counts can fit in
+  // int32, regardless of IndexType
+  __shared__ int smem[32]; // one per each warp, up to warp limit
+
+  IndexType slice = getLinearBlockId<IndexType>();
+  if (slice >= numInputSlices) {
+    return;
+  }
+
+  // Find the start offset for our slice
+  IndexType sliceStartIndex =
+    IndexToOffset<T, IndexType, Dim>::get(slice, input);
+  IndexType topKSliceStartIndex =
+    IndexToOffset<T, IndexType, Dim>::get(slice, topK);
+  IndexType indicesSliceStartIndex =
+    IndexToOffset<long, IndexType, Dim>::get(slice, indices);
+
+  T* inputSliceStart = &input.data[sliceStartIndex];
+  T* topKSliceStart = &topK.data[topKSliceStartIndex];
+  long* indicesSliceStart = &indices.data[indicesSliceStartIndex];
+
+  // Find the k-th highest element in our input
+  T topKValue = ScalarConvert<int, T>::to(0);
+  radixSelect<T, typename TopKTypeConfig<T>::RadixType, IndexType, Order>(
+    inputSliceStart, outputSliceSize,
+    inputSliceSize, inputWithinSliceStride,
+    smem, &topKValue);
+
+  // Every value that is strictly less/greater than `pattern`
+  // (depending on sort dir) in sorted int format is in the top-K.
+  // The top-K value itself might not be unique.
+  //
+  // Since there are a variable number of elements that we see that
+  // are within the top-k, we don't know at what index to write out
+  // the resulting values.
+  // In order to get this, we perform an exclusive prefix sum of
+  // `hasTopK`. This will return the resulting index into which we
+  // need to write the result, if a thread has a result.
+
+  // All threads need to participate in the loop and the prefix sum,
+  // but not necessarily in the load; hence loop bounds being rounded
+  // up to a multiple of the block dim.
+  IndexType numIterations = THCRoundUp(inputSliceSize, (IndexType) blockDim.x);
+  IndexType writeIndexStart = 0;
+
+  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
+    bool inRange = (i < inputSliceSize);
+    T v =
+      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
+    bool hasTopK;
+    if (Order) {
+      hasTopK = inRange && (THCNumerics<T>::gt(v, topKValue));
+    } else {
+      hasTopK = inRange && (THCNumerics<T>::lt(v, topKValue));
+    }
+
+    int index;
+    int carry;
+    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
+
+    if (hasTopK) {
+      int writeIndex = writeIndexStart + index;
+      assert(writeIndex < outputSliceSize);
+
+      IndexType topKOffset = writeIndex * topKWithinSliceStride;
+      IndexType indexOffset = writeIndex * indicesWithinSliceStride;
+
+      topKSliceStart[topKOffset] = v;
+      indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
+    }
+
+    writeIndexStart += carry;
+  }
+
+  // We need to fill in the rest with actual == top-K values.
+  // The number that we need is outputSliceSize -
+  // writeIndexStart. There might be more than that number available,
+  // in which case we have to choose the first seen set. We do this
+  // via a prefix sum to calculate indices for writing results.
+  assert(outputSliceSize >= writeIndexStart);
+  IndexType topKRemaining = (outputSliceSize - writeIndexStart);
+
+  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
+    bool inRange = (i < inputSliceSize);
+    T v =
+      inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
+    bool hasTopK = inRange && (THCNumerics<T>::eq(v, topKValue));
+
+    int index;
+    int carry;
+    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());
+
+    if (hasTopK && index < topKRemaining) {
+      int writeIndex = writeIndexStart + index;
+      assert(writeIndex < outputSliceSize);
+
+      IndexType topKOffset = writeIndex * topKWithinSliceStride;
+      IndexType indexOffset = writeIndex * indicesWithinSliceStride;
+
+      topKSliceStart[topKOffset] = v;
+      indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
+    }
+
+    if (carry >= topKRemaining) {
+      break;
+    }
+
+    topKRemaining -= carry;
+    writeIndexStart += carry;
+  }
+}
+
+#undef RADIX_BITS
+#undef RADIX_SIZE
+#undef RADIX_MASK
+
+#endif // THC_TENSOR_TOPK_CUH
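
The control flow of `radixSelect` is easier to follow stripped of the warp ballots and shared-memory reductions. A serial host-side model of the same selection (descending order, 2-bit digits, raw `uint32_t` keys so convert/deconvert are the identity; `radixSelectKthLargest` is a hypothetical helper for illustration, not part of the diff):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Serial model of radixSelect<..., Order=true>: find the k-th largest key
// by pinning a (desired, desiredMask) bit pattern two bits at a time.
uint32_t radixSelectKthLargest(const std::vector<uint32_t>& data, int k) {
  const int RADIX_BITS = 2, RADIX_SIZE = 4;
  uint32_t desired = 0, desiredMask = 0;
  int kToFind = k;

  for (int pos = 32 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) {
    // Count digit occurrences among keys still matching the pattern
    int counts[RADIX_SIZE] = {0};
    for (uint32_t v : data) {
      if ((v & desiredMask) == desired) {
        counts[(v >> pos) & (RADIX_SIZE - 1)]++;
      }
    }
    // Scan digits from largest to smallest (descending order)
    for (int d = RADIX_SIZE - 1; d >= 0; --d) {
      if (counts[d] >= kToFind) {
        // The k-th largest has digit d here; pin these bits and descend
        desired |= uint32_t(d) << pos;
        desiredMask |= uint32_t(RADIX_SIZE - 1) << pos;
        break;
      }
      kToFind -= counts[d];  // every key with digit d outranks the k-th
    }
  }
  return desired;  // all bits pinned: the k-th largest key itself
}

int main() {
  std::vector<uint32_t> v = {5, 9, 1, 9, 7, 3};
  assert(radixSelectKthLargest(v, 1) == 9);  // largest
  assert(radixSelectKthLargest(v, 3) == 7);  // third largest
  assert(radixSelectKthLargest(v, 6) == 1);  // smallest
  return 0;
}
```

The kernel adds two things on top of this skeleton: the `count == 1 && kToFind == 1` early-out, which switches to `findPattern` to fetch the unique matching element directly, and the final deconvert that maps the pinned bit pattern back to the original scalar type.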
diff --git a/THCTensorTopK.h b/THCTensorTopK.h
deleted file mode 100644
index 711c047..0000000
--- a/THCTensorTopK.h
+++ /dev/null
@@ -1,14 +0,0 @@
-#ifndef TH_CUDA_TENSOR_TOPK_INC
-#define TH_CUDA_TENSOR_TOPK_INC
-
-#include "THCTensor.h"
-
-/* Returns the set of all kth smallest (or largest) elements, depending */
-/* on `dir` */
-THC_API void THCudaTensor_topk(THCState* state,
-                               THCudaTensor* topK,
-                               THCudaLongTensor* indices,
-                               THCudaTensor* input,
-                               long k, int dim, int dir, int sorted);
-
-#endif
diff --git a/generic/THCTensorTopK.cu b/generic/THCTensorTopK.cu
new file mode 100644
index 0000000..83ab1e1
--- /dev/null
+++ b/generic/THCTensorTopK.cu
@@ -0,0 +1,159 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorTopK.cu"
+#else
+
+THC_API void THCTensor_(topk)(THCState* state,
+                               THCTensor *topK,
+                               THCudaLongTensor *indices,
+                               THCTensor *input,
+                               long k, int dim, int dir, int sorted) {
+  THAssert(topK != NULL && indices != NULL && input != NULL);
+  THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input));
+  THArgCheck(THCTensor_(nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);
+  long dims = THCudaLongTensor_nDimension(state, indices);
+  THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
+  int numDims = THCTensor_(nDimension)(state, input);
+  THArgCheck(numDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);
+
+  THArgCheck(dim >= 0 && dim < numDims, 6, "dim not in range");
+
+  long sliceSize = THCTensor_(size)(state, input, dim);
+  THArgCheck(k > 0 && k <= sliceSize, 5, "k not in range for dimension");
+
+  // Build the output size, which is the dim being selected set to
+  // size k
+  THLongStorage* topKSize = THCTensor_(newSizeOf)(state, input);
+  THLongStorage_set(topKSize, dim, k);
+  THCTensor_(resize)(state, topK, topKSize, NULL);
+  THCudaLongTensor_resize(state, indices, topKSize, NULL);
+  THLongStorage_free(topKSize);
+
+#define RUN_K(INDEX_T, DIM, DIR)                                        \
+  gatherTopK<real, INDEX_T, DIM, DIR>                                         \
+    <<<grid, block, 0, THCState_getCurrentStream(state)>>>(             \
+      inputInfo,                                                        \
+      sliceSize,                                                        \
+      k,                                                                \
+      inputSlices,                                                      \
+      /* The actual dimension that the k-selection is running in */     \
+      /* may have changed from collapseDims() */                        \
+      inputInfo.strides[collapseInputDim],                              \
+      topKInfo,                                                         \
+      topKSlices,                                                       \
+      topKInfo.strides[collapseTopKDim],                                \
+      indicesInfo,                                                      \
+      indicesInfo.strides[collapseIndicesDim])
+
+#define RUN_DIR(INDEX_T, DIM)                   \
+  if (dir) {                                    \
+    RUN_K(INDEX_T, DIM, true);                  \
+  } else {                                      \
+    RUN_K(INDEX_T, DIM, false);                 \
+  }
+
+#define RUN_DIM(INDEX_T)                        \
+  if (allDims == 1) {                           \
+    RUN_DIR(INDEX_T, 1);                        \
+  } else if (allDims == 2) {                    \
+    RUN_DIR(INDEX_T, 2);                        \
+  } else if (allDims == 3) {                    \
+    RUN_DIR(INDEX_T, 3);                        \
+  } else {                                      \
+    RUN_DIR(INDEX_T, -1);                       \
+  }
+
+#define RUN_T(INDEX_T)                                                  \
+  TensorInfo<real, INDEX_T> inputInfo =                                \
+    getTensorInfo<THCTensor, INDEX_T>(state, input);                 \
+  TensorInfo<real, INDEX_T> topKInfo =                                 \
+    getTensorInfo<THCTensor, INDEX_T>(state, topK);                  \
+  TensorInfo<long, INDEX_T> indicesInfo =                               \
+    getTensorInfo<THCudaLongTensor, INDEX_T>(state, indices);           \
+                                                                        \
+  /* We use these structures solely to find the offset to */            \
+  /* each slice we are operating on */                                  \
+  inputInfo.sizes[dim] = 1;                                             \
+  topKInfo.sizes[dim] = 1;                                              \
+  indicesInfo.sizes[dim] = 1;                                           \
+                                                                        \
+  /* Collapse all other dims */                                         \
+  int collapseInputDim = inputInfo.collapseDims(dim);                   \
+  int collapseTopKDim = topKInfo.collapseDims(dim);                     \
+  int collapseIndicesDim = indicesInfo.collapseDims(dim);               \
+                                                                        \
+  long inputSlices = 1;                                                 \
+  long topKSlices = 1;                                                  \
+  for (int i = 0; i < numDims; ++i) {                                   \
+    inputSlices *= inputInfo.sizes[i];                                  \
+    topKSlices *= topKInfo.sizes[i];                                    \
+  }                                                                     \
+                                                                        \
+  dim3 grid;                                                            \
+  if (!THC_getGridFromTiles(inputSlices, grid)) {                       \
+    THError("Slice to sort is too large");                              \
+  }                                                                     \
+                                                                        \
+  dim3 block(std::min(THCRoundUp(sliceSize, 32L), 1024L));              \
+                                                                        \
+  /* This is used as a template parameter to calculate indices. */      \
+  /* We only specialize it if all collapsed dim sizes are the */        \
+  /* same; otherwise, we use -1 which is the specialization */          \
+  /* parameter for arbitrary dimensions */                              \
+  int allDims = inputInfo.dims;                                         \
+  if (topKInfo.dims != allDims || indicesInfo.dims != allDims) {        \
+    allDims = -1;                                                       \
+  }                                                                     \
+                                                                        \
+  RUN_DIM(INDEX_T);
+
+  // Based on required index size, run the algorithm with the
+  // appropriate index type
+  if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, input) &&
+      TensorUtils<THCTensor>::canUse32BitIndexMath(state, topK) &&
+      TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, indices)) {
+    RUN_T(unsigned int);
+  } else {
+    RUN_T(unsigned long);
+  }
+#undef RUN_T
+#undef RUN_DIM
+#undef RUN_DIR
+#undef RUN_K
+
+  // Sort the results if the user wants them sorted, since our
+  // selection routine does not ensure sorting
+  if (sorted) {
+    // FIXME: the k/v inplace sort along slice only works for size <=
+    // 2048 at the moment
+    if (sliceSize <= 2048) {
+      // This avoids any memory allocations and performs all sorting
+      // work inplace along the slice
+      THCTensor_(sortKeyValueInplace)(state, topK, indices, dim, dir);
+    } else {
+      // Depend upon the backup sort that returns indices, which we
+      // can use in conjunction with gather to produce the original
+      // indices.
+      // This is not the most efficient implementation, especially since
+      // there are memory allocations performed here. If the user desires
+      // greater performance, they should torch.gather() the results
+      // themselves using the reported indices, providing previously
+      // allocated tensors to receive the results.
+      THCTensor* sortedTopK = THCTensor_(new)(state);
+      THCudaLongTensor* sortedIndices = THCudaLongTensor_new(state);
+      THCTensor_(sort)(state, sortedTopK, sortedIndices, topK, dim, dir);
+
+      THCudaLongTensor* sortedTopKIndices = THCudaLongTensor_new(state);
+
+      THCudaLongTensor_resizeAs(state, sortedTopKIndices, indices);
+      THCudaLongTensor_gather(state, sortedTopKIndices, indices, dim, sortedIndices);
+
+      THCTensor_(freeCopyTo)(state, sortedTopK, topK);
+      THCudaLongTensor_freeCopyTo(state, sortedTopKIndices, indices);
+      THCudaLongTensor_free(state, sortedIndices);
+    }
+  }
+
+  THCudaCheck(cudaGetLastError());
+}
+
+#endif // THC_GENERIC_FILE
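
The `RUN_T`/`RUN_DIM`/`RUN_DIR`/`RUN_K` nest above is a compile-time dispatch: per tensor type it instantiates `gatherTopK` for {unsigned int, unsigned long} index math x {1, 2, 3, -1} collapsed dims x {ascending, descending}, and picks one instantiation at runtime. Expanded by hand for a single leaf (32-bit indices, 2-D collapsed info, descending), the launch the macros emit is roughly:

```cpp
// What RUN_T(unsigned int) -> RUN_DIM -> RUN_DIR(_, 2) -> RUN_K(_, 2, true)
// boils down to, with the locals bound inside the RUN_T body:
gatherTopK<real, unsigned int, 2, true>
  <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
    inputInfo, sliceSize, k,
    inputSlices, inputInfo.strides[collapseInputDim],
    topKInfo, topKSlices, topKInfo.strides[collapseTopKDim],
    indicesInfo, indicesInfo.strides[collapseIndicesDim]);
```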
diff --git a/generic/THCTensorTopK.h b/generic/THCTensorTopK.h
new file mode 100644
index 0000000..2c281b5
--- /dev/null
+++ b/generic/THCTensorTopK.h
@@ -0,0 +1,13 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorTopK.h"
+#else
+
+/* Returns the k smallest (or largest) elements of each slice along */
+/* dimension `dim`, depending on `dir` */
+THC_API void THCTensor_(topk)(THCState* state,
+                               THCTensor* topK,
+                               THCudaLongTensor* indices,
+                               THCTensor* input,
+                               long k, int dim, int dir, int sorted);
+
+#endif // THC_GENERIC_FILE
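
After the generic stamping, every type gets its own entry point (`THCudaByteTensor_topk`, `THCudaIntTensor_topk`, `THCudaDoubleTensor_topk`, ..., with float keeping the old `THCudaTensor_topk` name). A hedged C-side usage sketch for the int variant (error handling omitted; assumes an initialized `THCState` as in the cutorch tests):

```c
#include "THC.h"

/* Take the 5 largest values along dimension 1 of a CUDA int tensor,
   sorted, together with their (1-based, i.e. Lua-style) indices. */
void topkExample(THCState* state, THCudaIntTensor* input) {
  THCudaIntTensor* vals = THCudaIntTensor_new(state);
  THCudaLongTensor* inds = THCudaLongTensor_new(state);

  /* k = 5, dim = 1, dir = 1 (largest first), sorted = 1 */
  THCudaIntTensor_topk(state, vals, inds, input, 5, 1, 1, 1);

  /* ... consume vals/inds ... */
  THCudaLongTensor_free(state, inds);
  THCudaIntTensor_free(state, vals);
}
```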