#ifndef THC_TENSOR_RANDOM_CUH
#define THC_TENSOR_RANDOM_CUH
#include "THCNumerics.cuh"
#include "THCReduceApplyUtils.cuh"
#include "THCTensorMathReduce.cuh"
#include <curand_kernel.h>
#define MAX_NUM_BLOCKS 64
#define BLOCK_SIZE 256
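// These constants bound the launch configuration assumed by the generator
// kernels below: each block reads its own curandStateMtgp32 (state[blockIdx.x]),
// and the grid-stride loops advance by BLOCK_SIZE * MAX_NUM_BLOCKS per iteration.
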
/* Separate kernel because curand_log_normal gets extra parameters. */
template <typename T>
__global__ void generateLogNormal(curandStateMtgp32 *state, int size, T *result, double mean, double stddev)
{
  int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
  int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
  for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
    float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
    if (i < size) {
      result[i] = ScalarConvert<float, T>::to(x);
    }
  }
}
template <>
__global__ void generateLogNormal<double>(curandStateMtgp32 *state, int size, double *result, double mean, double stddev)
{
  int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
  int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
  for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
    double x = curand_log_normal_double(&state[blockIdx.x], mean, stddev);
    if (i < size) {
      result[i] = x;
    }
  }
}
#undef MAX_NUM_BLOCKS
#undef BLOCK_SIZE
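
/*
 * Illustrative only: a host-side launch of generateLogNormal might look like
 * the sketch below. The names (numBlocks, gen_states, stream) are placeholders
 * for this example, not declarations from this header; the actual launch lives
 * in the corresponding .cu file. The grid must not exceed the number of MTGP32
 * generator states (64 here, matching MAX_NUM_BLOCKS above).
 *
 *   int numBlocks = min(THCCeilDiv(size, 256), 64);
 *   generateLogNormal<float><<<numBlocks, 256, 0, stream>>>(
 *       gen_states, size, result, mean, stddev);
 */
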
// Normalizes the L1 norm of every row to 1; used by multinomial
template <typename T>
__global__ void renormRowsL1(T* dist, long rows, long cols) {
  extern __shared__ __align__(sizeof(T)) unsigned char my_smem[];
  T *smem = reinterpret_cast<T *>(my_smem);

  for (long row = blockIdx.x; row < rows; row += gridDim.x) {
    T sum = ScalarConvert<int, T>::to(0);
    for (long col = threadIdx.x; col < cols; col += blockDim.x) {
      sum = THCNumerics<T>::add(sum, dist[row * cols + col]);
    }

    sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), ScalarConvert<int, T>::to(0));
    if (threadIdx.x == 0) {
      smem[0] = sum;
    }
    __syncthreads();

    sum = smem[0];
    if (THCNumerics<T>::gt(sum, ScalarConvert<int, T>::to(0))) {
      for (long col = threadIdx.x; col < cols; col += blockDim.x) {
        dist[row * cols + col] = THCNumerics<T>::div(dist[row * cols + col], sum);
      }
    }
  }
}
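
/*
 * Illustrative only: each block normalizes one row at a time, so a host-side
 * launch would typically look like the sketch below (maxGridBlocks and
 * threadsPerBlock are placeholders, not part of this header). The dynamic
 * shared memory must hold at least one T per thread for reduceBlock.
 *
 *   dim3 grid(min(rows, (long) maxGridBlocks));
 *   dim3 block(threadsPerBlock);
 *   renormRowsL1<T><<<grid, block, block.x * sizeof(T), stream>>>(dist, rows, cols);
 */
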
template <typename T>
__global__ void
sampleMultinomialOnce(T* dest,
                      long distributions,
                      int categories,
                      T* dist) {
  extern __shared__ __align__(sizeof(T)) unsigned char my_smem[];
  T *smem = reinterpret_cast<T *>(my_smem);
  T zero = ScalarConvert<int, T>::to(0);

  for (long curDist = blockIdx.x;
       curDist < distributions; curDist += gridDim.x) {
    // Each block handles one distribution
    // First pass, find the total sum of the distribution
    T sum = zero;
    for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
      sum = THCNumerics<T>::add(sum, dist[curDist * categories + cat]);
    }
    // After the reduction, threadIdx.x == 0 holds the block-wide sum
    sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), zero);

    // Broadcast sum and sample value
    if (threadIdx.x == 0) {
      smem[0] = sum;
      smem[1] = dest[curDist];
    }
    __syncthreads();

    sum = smem[0];
    T sample = smem[1];
    __syncthreads();

    if (THCNumerics<T>::eq(sum, zero) || THCNumerics<T>::eq(sample, zero)) {
      // Choose the first element
      if (threadIdx.x == 0) {
        dest[curDist] = ScalarConvert<int, T>::to(1);
      }

      continue;
    }

    int chunks = THCCeilDiv(categories, (int) blockDim.x);
    T prevHighProb = zero;

    for (int chunk = 0; chunk < chunks; ++chunk) {
      // All threads in bounds load a value
      int cat = chunk * blockDim.x + threadIdx.x;

      T val =
        cat < categories ? THCNumerics<T>::div(dist[curDist * categories + cat], sum) :
        zero;
      smem[threadIdx.x] = val;
      __syncthreads();

      // Perform an inclusive prefix sum of the shared memory contents
      for (int offset = 1; offset < blockDim.x; offset *= 2) {
        T val = zero;

        if (threadIdx.x >= offset) {
          val = THCNumerics<T>::add(smem[threadIdx.x - offset], smem[threadIdx.x]);
        }

        __syncthreads();
        if (threadIdx.x >= offset) {
          smem[threadIdx.x] = val;
        }
        __syncthreads();
      }

      // Each thread will check to see if the sample falls in its
      // bucket
      T curBucket = THCNumerics<T>::add(smem[threadIdx.x], prevHighProb);
      T prevBucket =
        threadIdx.x == 0 ? prevHighProb :
        THCNumerics<T>::add(smem[threadIdx.x - 1], prevHighProb);
      bool inBucket =
        (cat < categories) &&
        (!THCNumerics<T>::gt(sample, curBucket)) &&
        (THCNumerics<T>::gt(sample, prevBucket));

      if (inBucket) {
        // We're done; we have the sample
        // Torch indices are 1-based
        // FIXME: broadcast exit flag?
        dest[curDist] = ScalarConvert<int, T>::to(cat + TH_INDEX_BASE);
      }

      // Store the previous scan's high value for future use
      prevHighProb = THCNumerics<T>::add(prevHighProb, smem[blockDim.x - 1]);
      __syncthreads();
    }
  }
}
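
// Lower-bound search over an inclusive prefix sum: returns the smallest index
// i with dist[i] >= val, e.g. for dist = {0.1, 0.4, 1.0} and val = 0.5 it
// returns 2. If val exceeds every entry (start == size), index 0 is returned
// as a fallback.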
template <typename T>
__device__ int binarySearchForMultinomial(T* dist,
                                          int size,
                                          T val) {
  int start = 0;
  int end = size;

  while (end - start > 0) {
    int mid = start + (end - start) / 2;

    T midVal = dist[mid];
    if (THCNumerics<T>::lt(midVal, val)) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }

  if (start == size) {
    // No probability mass or precision problems; just return the
    // first element
    start = 0;
  }

  return start;
}
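
/*
 * Illustrative only: the two sampling kernels below expect a 2D block where
 * threadIdx.x spans a warp (every lane calls curand_uniform so the shared
 * MTGP32 state stays in sync, but only lane 0 performs the binary search) and
 * threadIdx.y indexes samples or distributions. The block/grid shapes below
 * are placeholder assumptions for the example, not values mandated by this
 * header:
 *
 *   dim3 block(32, 4);
 *   dim3 grid(min(distributions, (long) numGeneratorStates));
 *   sampleMultinomialWithReplacement<<<grid, block, 0, stream>>>(
 *       gen_states, totalSamples, dest, distributions, categories, prefixSum);
 */
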
template <typename T>
__global__ void
sampleMultinomialWithReplacement(curandStateMtgp32* state,
                                 int totalSamples,
                                 T* dest,
                                 long distributions,
                                 int categories,
                                 T* normDistPrefixSum) {
  // At the moment, each warp computes one sample value in the binary
  // search due to divergence. It may be possible to compute multiple
  // values per warp and limit divergence later on. However, no matter
  // what, all block threads must participate in the curand_uniform
  // call to update the generator state.

  // The block determines the distribution for which we generate a point
  for (long curDist = blockIdx.x;
       curDist < distributions;
       curDist += gridDim.x) {
    for (int sampleBase = 0;
         sampleBase < totalSamples; sampleBase += blockDim.y) {
      // The warp determines the sample
      int sample = sampleBase + threadIdx.y;

      // All threads participate in this
      T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));

      if (threadIdx.x == 0 && sample < totalSamples) {
        // Find the bucket that a uniform sample lies in
        int choice = binarySearchForMultinomial<T>(
          normDistPrefixSum + curDist * categories,
          categories,
          r);

        // Torch indices are 1-based
        dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE);
      }
    }
  }
}
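
// One invocation of the kernel below produces a single draw (index `sample`)
// for every distribution. Because the chosen category is zeroed in origDist,
// the caller is expected to renormalize (e.g. via renormRowsL1) and rebuild
// normDistPrefixSum between draws so the same category is not selected twice.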
template <typename T>
__global__ void
sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
                                    int totalSamples,
                                    int sample,
                                    T* dest,
                                    long distributions,
                                    int categories,
                                    T* origDist,
                                    T* normDistPrefixSum) {
  // At the moment, each warp computes one sample value in the binary
  // search due to divergence. It may be possible to compute multiple
  // values per warp and limit divergence later on. However, no matter
  // what, all block threads must participate in the curand_uniform
  // call to update the generator state.

  // The block and warp determine the distribution for which we
  // generate a point
  for (long curDistBase = blockIdx.x * blockDim.y;
       curDistBase < distributions;
       curDistBase += gridDim.x * blockDim.y) {
    // The warp determines the distribution
    long curDist = curDistBase + threadIdx.y;

    // All threads must participate in this
    T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));

    if (threadIdx.x == 0 && curDist < distributions) {
      // Find the bucket that a uniform sample lies in
      int choice = binarySearchForMultinomial<T>(
        normDistPrefixSum + curDist * categories,
        categories,
        r);

      // Torch indices are 1-based
      dest[curDist * totalSamples + sample] = ScalarConvert<int, T>::to(choice + TH_INDEX_BASE);

      // Without replacement, so update the original probability so it
      // is not considered a second time
      origDist[curDist * categories + choice] = ScalarConvert<int, T>::to(0);
    }
  }
}
#endif // THC_TENSOR_RANDOM_CUH