blob: 0158f254a2d0141e5731c2727020043aa527f0af [file] [log] [blame]
#ifndef THC_TENSOR_MODE_CUH
#define THC_TENSOR_MODE_CUH
#include "THCNumerics.cuh"
#include "THCSortUtils.cuh"
#include "THCScanUtils.cuh"
struct ThrustHalfLess
{
__host__ __device__ inline bool operator()(const half& lhs, const half& rhs) {
return THCNumerics<half>::lt(lhs, rhs);
}
};
struct ThrustHalfNotEqualTo
{
__host__ __device__ inline bool operator()(const half& lhs, const half& rhs) {
return THCNumerics<half>::ne(lhs, rhs);
}
};
struct ThrustHalfEqualTo
{
__host__ __device__ inline bool operator()(const half& lhs, const half& rhs) {
return THCNumerics<half>::eq(lhs, rhs);
}
};
struct ThrustHalfEqualToPredicate
{
ThrustHalfEqualToPredicate(half val): val_(val) {}
__host__ __device__ inline bool operator()(half x) {
return THCNumerics<half>::eq(val_, x);
}
half val_;
};
template <typename T>
struct BinaryAddOp {
__host__ __device__ inline T operator()(const T a, const T b) {
return THCNumerics<T>::add(a, b);
}
};
template <>
struct BinaryAddOp<unsigned int> {
__host__ __device__ inline unsigned int operator()(const unsigned int a, const unsigned int b) {
return a + b;
}
};
// Used for a segmented reduction
struct ModeUnsignedBoolPair {
unsigned int val;
bool flag;
};
// In the kernel below, we have a common pattern of reducing (unsigned int, unsigned int)
// pairs of data
struct ModeUnsignedPair {
unsigned int val;
unsigned int index;
};
template <typename T>
struct MaxReduceOp {
__host__ __device__ inline T operator()(const T& a, const T& b) {
return b.val > a.val ? b : a;
}
};
template <typename T>
struct MatchReduceOp {
__host__ __device__ inline T operator()(const T& a, const T& b) {
return b.flag ? b : a;
}
};
// The mode kernel has the following characteristics: It uses internal shared memory
// buffers of Power2Size, which must be greater than the number of elements. Additionally,
// there is one block for every slice to calculate the mode for, and in each block there
// is one thread for every two elements.
//
// Both sorted and positions are assumed to be contiguous Tensors with the mode dimension
// as the innermost dim, such that we can get the particular slice for a Tensor via its
// linear block dimension * the slice size.
template <typename T, unsigned int Power2Size>
__global__ void computeMode(
T *input,
TensorInfo<T, unsigned int> values,
TensorInfo<int64_t, unsigned int> indices,
int64_t sliceSize)
{
int tidx = threadIdx.x;
int stidx = blockDim.x + threadIdx.x; // Second index this thread responsible for
// First, we need to calculate the offset into the sorted Tensor that represents
// the start of the slice for this block to calculate the mode for. This offset
// is a combination of the gridIndices, and the number of elements in the slice.
unsigned int blockId = getLinearBlockId<unsigned int>();
unsigned int linearOffset = blockId * sliceSize;
// shmem is a dynamically sized buffer we will use throughout the kernel to
// handle computation efficiently. The size of this shmem must be
// sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size)
//
// Initially, the buffer will be organized as follows:
//
// [smem (slice elements) | bmem (valid indices) | <scratch space>]
extern __shared__ char shmem[];
// smem represents a proportion of the shared memory buffer that is used to store
// the elements from the slice:
T *smem = reinterpret_cast<T *>(shmem);
// Each thread loads up to two elements from the Tensor into shared memory
if (tidx < sliceSize) {
smem[tidx] = input[linearOffset + tidx];
}
if (stidx < sliceSize) {
smem[stidx] = input[linearOffset + stidx];
}
// Next, we initialize a boolean region of the buffer, offset by the loaded element
// smem region
bool *bmem = reinterpret_cast<bool *>(&smem[Power2Size]);
// The first use of this region stores bmem[i] = i < sliceSize to mark the valid
// components in the smem buffer
bmem[tidx] = tidx < sliceSize;
bmem[stidx] = stidx < sliceSize;
__syncthreads(); // barrier for smem, bmem initialization
// First, sort the input slice in ascending order. smem contains the input
// elements, and bmem marks the valid indices
bitonicSortKeys<LTComp<T>, T, unsigned int, Power2Size>(smem, bmem, LTComp<T>());
__syncthreads(); // make no assumptions that the sort syncs at end
// The next step of our algorithm is performing a block-wide comparison of
// neighboring elements. In particular, given an sorted input slice A, we
// produce an output slice B, such that B[i] = 1 if A[i-i] != A[i], otherwise 0.
//
// Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8]
// B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
//
// In particular, we can think of B[i] true indicating the start of a sequence of
// equal values in the sorted list. Similarly, we will also store the negation of B,
// which we'll call C. In particular, we can think of C[i] = true iff A[i-1] == A[i]
// in our original sorted slice.
//
// C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
// We overwrite bmem, and treat the rest of shared memory as a buffer of (index, flag) pairs
// where the index represents values from C, and the flag represents values from B.
//
// [smem (sorted slice) | ubpmem (index, flag pairs)]
struct ModeUnsignedBoolPair *ubpmem = reinterpret_cast<struct ModeUnsignedBoolPair *>(
&smem[Power2Size]);
if (tidx == 0) {
ubpmem[0].flag = true;
ubpmem[0].val = 0;
}
// Compares elements (0, 1), (2, 3), ... and sets 1, 3, ...
ubpmem[tidx * 2 + 1].flag = THCNumerics<T>::ne(smem[tidx * 2], smem[tidx * 2 + 1]); // (0, 1), (1, 2), etc.
ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag;
// Compares elements (1, 2), (3, 4), ... and sets 2, 4, ...
if (((tidx + 1) * 2) < Power2Size) {
ubpmem[(tidx + 1) * 2].flag = THCNumerics<T>::ne(smem[((tidx + 1) * 2) - 1], smem[(tidx + 1) * 2]);
ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag;
}
__syncthreads(); // barrier for ubpmem initialization
// Next, we perform a segmented prefix sum on the neighboring elements, where
// the presence of a one indicates the start of a segment. In this case B acts
// as the segment start flags, and C is the buffer to be summed:
//
// Input (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
// Flag (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
// Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0]
//
// Afterwards, the (index) components of the ubpmem buffer contain the lengths of the
// segments (minus 1), i.e. the counts of each element in the original input.
inclusivePrefixScan<
struct ModeUnsignedBoolPair,
struct SegmentedScanOp<struct ModeUnsignedBoolPair, BinaryAddOp<unsigned int> >,
Power2Size>(
ubpmem,
SegmentedScanOp<struct ModeUnsignedBoolPair, BinaryAddOp<unsigned int> >(BinaryAddOp<unsigned int>()));
// assumes scan syncs at the end
// Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e. we treat the
// boolean flag regions as integers). We initialize these to represent indices, and we'll call
// this buffer I
struct ModeUnsignedPair *uupmem = reinterpret_cast<struct ModeUnsignedPair *>(ubpmem);
// At this point, we need to find the maximum element in lengths buffer C.
// This element will represent the count (-1) of the mode. Because of the
// way we have set up the problem, the index where this mode occurs will
// also be the location of the mode value in the sorted array, e.g.
//
// smem = [0, 0, 1, 1, 1, 2]
// C = [0, 1, 0, 1, 2, 0]
// I = [0, 1, 2, 3, 4, 5]
// ^
// maximum value, also aligned with mode = 1
//
// We perform a block wide max-reduction of the C buffer, but we also need the
// indices to come along with it, so we utilize the uupmem construction.
//
// At the end we need to return the ModeUnsignedPair containing index = 4, val = 2,
// which represents the max
// In practice, we will make each thread locally reduce 2 values in its registers prior
// to the global block-wide reduction. Note that instead of tidx/stidx, we utilize tidx * 2,
// tidx * 2 + 1, so each thread deals with adjacent elements. This is because the reduce
// code below relies on thread elements to be adjacent.
struct ModeUnsignedPair uup[2];
uup[0].index = tidx * 2;
uup[0].val = ubpmem[tidx * 2].val;
uup[1].index = tidx * 2 + 1;
uup[1].val = ubpmem[tidx * 2 + 1].val;
__syncthreads();
struct ModeUnsignedPair max = {0, 0};
max = reduceBlockWithNThreadLocalReductions<struct ModeUnsignedPair, MaxReduceOp<struct ModeUnsignedPair>, 2>
(uupmem, uup, sliceSize, MaxReduceOp<struct ModeUnsignedPair>(), max);
// Store the mode in shared memory for use in finding the mode in the input slice
__shared__ T mode;
// Given the above constraints, the mode is the value at the reduced index in the
// original sorted element buffer
if (tidx == 0) {
mode = smem[max.index];
}
__syncthreads(); // broadcast mode
// Finally, we need to find the "an" index of the mode in the input Tensor. The API does
// not constrain which index we pick, so it can be any of the indices that contain the mode.
// We will do a reduction to find the index. We go back to using the (index, flag) buffer
// arrangement. First, we mark indices that are equal to the mode, i.e B[i] = true if
// input[i] == mode, and initialize C[i] to be the index
//
// Again we reduce 2 elements in the thread's registers prior to the block-wide reduction
struct ModeUnsignedBoolPair ubpp[2];
if (tidx * 2 < sliceSize) {
ubpp[0].flag = THCNumerics<T>::eq(input[linearOffset + (tidx * 2)], mode);
ubpp[0].val = tidx * 2;
}
if (tidx * 2 + 1 < sliceSize) {
ubpp[1].flag = THCNumerics<T>::eq(input[linearOffset + (tidx * 2 + 1)], mode);
ubpp[1].val = tidx * 2 + 1;
}
// Then we perform a similar reduction to the one above, except this time we update
// the element if the element at the base position is not equal to the mode and
// the element at the offset position is. At the end, C[0] will contain an index
// with the mode.
struct ModeUnsignedBoolPair match = {0, false};
match = reduceBlockWithNThreadLocalReductions<struct ModeUnsignedBoolPair, MatchReduceOp<struct ModeUnsignedBoolPair>, 2>
(ubpmem, ubpp, sliceSize, MatchReduceOp<struct ModeUnsignedBoolPair>(), match);
// Finally, we have the mode, and an index where it occurs. We use a single thread
// to place this in the appropriate output position
if (tidx == 0) {
int64_t index = TH_INDEX_BASE + match.val;
unsigned int outputOffset = IndexToOffset<T, unsigned int, -1>::get(blockId, values);
values.data[outputOffset] = mode;
indices.data[outputOffset] = index;
}
}
#endif // THC_TENSOR_MODE_CUH