#include "THCTensorRandom.h"
#include "THCDeviceUtils.cuh"
#include "THCGeneral.h"
#include "THCTensorCopy.h"
#include "THCTensorMath.h"
#include "THCReduceApplyUtils.cuh"
#include <thrust/functional.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_mtgp32_host.h>
#include <curand_mtgp32dc_p_11213.h>
#define MAX_NUM_BLOCKS 64
#define BLOCK_SIZE 256
/* Sets up generator. Allocates but does not create the generator states. */
__host__ void initializeGenerator(THCState *state, Generator* gen)
{
THCudaCheck(THCudaMalloc(state, (void**)&gen->gen_states, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32)));
THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params)));
}
/* Frees memory allocated during setup. */
__host__ void destroyGenerator(THCState *state, Generator* gen)
{
if (gen->gen_states)
{
THCudaCheck(THCudaFree(state, gen->gen_states));
gen->gen_states = NULL;
}
if (gen->kernel_params)
{
THCudaCheck(THCudaFree(state, gen->kernel_params));
gen->kernel_params = NULL;
}
}
/* Creates a new generator state given the seed. */
__host__ void createGeneratorState(Generator* gen, unsigned long seed)
{
if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->kernel_params) != CURAND_STATUS_SUCCESS)
{
THError("Creating MTGP constants failed.");
}
if (curandMakeMTGP32KernelState(gen->gen_states, mtgp32dc_params_fast_11213,
gen->kernel_params, MAX_NUM_BLOCKS, seed) != CURAND_STATUS_SUCCESS)
{
THError("Creating MTGP kernel state failed.");
}
}
/* Initialize generator array (must be called before any other function) */
__host__ void THCRandom_init(THCState* state, int devices, int current_device)
{
THCRNGState* rng_state = state->rngState;
rng_state->num_devices = devices;
rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator));
for (int i = 0; i < rng_state->num_devices; ++i)
{
rng_state->gen[i].initf = 0;
rng_state->gen[i].initial_seed = 0;
rng_state->gen[i].gen_states = NULL;
rng_state->gen[i].kernel_params = NULL;
}
rng_state->current_gen = &rng_state->gen[current_device];
// Initialize the generator for the current device. Other generators will be
// initialized on-demand in THCRandom_setGenerator.
initializeGenerator(state, rng_state->current_gen);
THCRandom_seed(state);
}
/* Destroy generators and free memory */
__host__ void THCRandom_shutdown(THCState* state)
{
THCRNGState* rng_state = state->rngState;
if (rng_state->gen == NULL) return;
for (int i = 0; i < rng_state->num_devices; ++i)
{
destroyGenerator(state, &rng_state->gen[i]);
}
free(rng_state->gen);
rng_state->gen = NULL;
rng_state->current_gen = NULL;
}
/* Set the generator for the current device */
__host__ void THCRandom_setGenerator(THCState* state, int device)
{
THCRNGState* rng_state = state->rngState;
if (device >= rng_state->num_devices) THError("Invalid device index.");
rng_state->current_gen = &rng_state->gen[device];
if (rng_state->current_gen->initf == 0)
{
initializeGenerator(state, rng_state->current_gen);
THCRandom_seed(state);
}
}
/* Random seed */
__host__ unsigned long THCRandom_seed(THCState* state)
{
unsigned long s = (unsigned long)time(0);
THCRandom_manualSeed(state, s);
return s;
}
/* Random seed on all devices */
__host__ unsigned long THCRandom_seedAll(THCState* state)
{
unsigned long s = (unsigned long)time(0);
THCRandom_manualSeedAll(state, s);
return s;
}
/* Manually set the seed */
__host__ void THCRandom_manualSeed(THCState* state, unsigned long seed)
{
THCRNGState* rng_state = state->rngState;
if (rng_state->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
rng_state->current_gen->initial_seed = seed;
createGeneratorState(rng_state->current_gen, seed);
rng_state->current_gen->initf = 1;
}
/* Manually set the seed on all devices */
__host__ void THCRandom_manualSeedAll(THCState* state, unsigned long seed)
{
THCRNGState* rng_state = state->rngState;
int currentDevice;
THCudaCheck(cudaGetDevice(&currentDevice));
for (int i = 0; i < rng_state->num_devices; ++i) {
THCudaCheck(cudaSetDevice(i));
THCRandom_setGenerator(state, i);
THCRandom_manualSeed(state, seed);
}
THCudaCheck(cudaSetDevice(currentDevice));
THCRandom_setGenerator(state, currentDevice);
}
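/* Illustrative sketch (not part of this file): seeding for reproducible runs.
   Assumes `state` is an initialized THCState*; the literal seed is arbitrary.

     THCRandom_manualSeed(state, 12345UL);     // seeds the current device's generator only
     THCRandom_manualSeedAll(state, 12345UL);  // seeds every device's generator
     unsigned long s = THCRandom_initialSeed(state);  // s == 12345UL afterwards
*/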
/* Get the initial seed */
__host__ unsigned long THCRandom_initialSeed(THCState* state)
{
return state->rngState->current_gen->initial_seed;
}
/* Copy the current generator's state (MTGP32 states + seed) into a host byte tensor */
__host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state)
{
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
// The RNG state comprises the MTGP32 states and the seed.
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
static const size_t seed_size = sizeof(unsigned long);
static const size_t total_size = states_size + seed_size;
THByteTensor_resize1d(rng_state, total_size);
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
THCudaCheck(cudaMemcpy(THByteTensor_data(rng_state), state->rngState->current_gen->gen_states,
states_size, cudaMemcpyDeviceToHost));
memcpy(THByteTensor_data(rng_state) + states_size, &state->rngState->current_gen->initial_seed, seed_size);
}
/* Re-links each MTGP32 state to this generator's device-side kernel parameters
   after the raw states have been copied in from the host. */
__global__ void set_rngstate_kernel(curandStateMtgp32 *state, mtgp32_kernel_params *kernel)
{
state[threadIdx.x].k = kernel;
}
/* Restore the current generator's state from a host byte tensor produced by THCRandom_getRNGState */
__host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
{
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
static const size_t seed_size = sizeof(unsigned long);
static const size_t total_size = states_size + seed_size;
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
THCudaCheck(cudaMemcpy(state->rngState->current_gen->gen_states, THByteTensor_data(rng_state),
states_size, cudaMemcpyHostToDevice));
set_rngstate_kernel<<<1, MAX_NUM_BLOCKS, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, state->rngState->current_gen->kernel_params);
memcpy(&state->rngState->current_gen->initial_seed, THByteTensor_data(rng_state) + states_size, seed_size);
}
/* Expands to a kernel that fills `result` with `size` samples. Each block owns
   one MTGP32 state and strides over the output; the loop runs to rounded_size
   so that every thread in a block makes the same number of curand calls
   (the MTGP32 block-cooperative update requires this), and out-of-range
   draws are simply discarded. */
#define GENERATE_KERNEL1(NAME, ARG1, CURAND_FUNC, TRANSFORM) \
__global__ void NAME(curandStateMtgp32 *state, int size, float *result, ARG1) \
{ \
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \
int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \
float x = CURAND_FUNC(&state[blockIdx.x]); \
if (i < size) { \
x = TRANSFORM; \
result[i] = x; \
} \
} \
}
/* Same as GENERATE_KERNEL1, but for transforms that take two extra arguments. */
#define GENERATE_KERNEL2(NAME, ARG1, ARG2, CURAND_FUNC, TRANSFORM) \
__global__ void NAME(curandStateMtgp32 *state, int size, float *result, ARG1, ARG2) \
{ \
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \
int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \
for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) { \
float x = CURAND_FUNC(&state[blockIdx.x]); \
if (i < size) { \
x = TRANSFORM; \
result[i] = x; \
} \
} \
}
GENERATE_KERNEL2(generate_uniform, double a, double b, curand_uniform, x * (b-a) + a)
GENERATE_KERNEL1(generate_bernoulli, double p, curand_uniform, (float)x <= p)
GENERATE_KERNEL2(generate_normal, double mean, double stdv, curand_normal, (x * stdv) + mean)
GENERATE_KERNEL1(generate_geometric, double p, curand_uniform, (log(1-x) / log(p)) + 1)
GENERATE_KERNEL1(generate_exponential, double lambda, curand_uniform, (float)(-1. / lambda * log(1-x)))
GENERATE_KERNEL2(generate_cauchy, double median, double sigma, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
#undef GENERATE_KERNEL1
#undef GENERATE_KERNEL2
/* Separate kernel because curand_log_normal gets extra parameters. */
__global__ void generate_log_normal(curandStateMtgp32 *state, int size, float *result, float mean, float stddev)
{
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;
for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {
float x = curand_log_normal(&state[blockIdx.x], mean, stddev);
if (i < size) {
result[i] = x;
}
}
}
/* Launch configuration for the samplers below: one block per BLOCK_SIZE
   elements, capped at the MAX_NUM_BLOCKS MTGP32 states allocated per generator. */
#define NUM_BLOCKS min((int)THCCeilDiv(size, (long) BLOCK_SIZE), MAX_NUM_BLOCKS)
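/* Worked example of the launch configuration: with BLOCK_SIZE = 256 and
   MAX_NUM_BLOCKS = 64, a tensor of 10,000 elements gets
   min(ceil(10000 / 256), 64) = 40 blocks, while any size of 16,384 or more
   uses all 64 blocks (one per allocated MTGP32 state) and relies on the
   kernels' grid-stride loops to cover the remaining elements. */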
THC_API void THCudaTensor_uniform(THCState* state, THCudaTensor *self_, double a, double b)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
long size = THCudaTensor_nElement(state, self);
float *data = THCudaTensor_data(state, self);
generate_uniform<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, size, data, a, b);
THCudaTensor_freeCopyTo(state, self, self_);
}
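/* Illustrative sketch (not part of this file): filling a fresh tensor with
   uniform draws. Assumes `state` is an initialized, seeded THCState*.

     THCudaTensor *t = THCudaTensor_newWithSize1d(state, 1000);
     THCudaTensor_uniform(state, t, -1.0, 1.0);  // 1000 uniform draws between -1 and 1
     THCudaTensor_free(state, t);
*/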
THC_API void THCudaTensor_bernoulli(THCState* state, THCudaTensor *self_, double p)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
long size = THCudaTensor_nElement(state, self);
float *data = THCudaTensor_data(state, self);
generate_bernoulli<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, size, data, p);
THCudaTensor_freeCopyTo(state, self, self_);
}
THC_API void THCudaTensor_normal(THCState* state, THCudaTensor *self_, double mean, double stdv)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
long size = THCudaTensor_nElement(state, self);
float *data = THCudaTensor_data(state, self);
generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, size, data, mean, stdv);
THCudaTensor_freeCopyTo(state, self, self_);
}
THC_API void THCudaTensor_logNormal(THCState* state, THCudaTensor *self_, double mean, double stdv)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
long size = THCudaTensor_nElement(state, self);
float *data = THCudaTensor_data(state, self);
generate_log_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, size, data, mean, stdv);
THCudaTensor_freeCopyTo(state, self, self_);
}
THC_API void THCudaTensor_geometric(THCState* state, THCudaTensor *self_, double p)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
long size = THCudaTensor_nElement(state, self);
float *data = THCudaTensor_data(state, self);
generate_geometric<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, size, data, p);
THCudaTensor_freeCopyTo(state, self, self_);
}
THC_API void THCudaTensor_exponential(THCState* state, THCudaTensor *self_, double lambda)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
long size = THCudaTensor_nElement(state, self);
float *data = THCudaTensor_data(state, self);
generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, size, data, lambda);
THCudaTensor_freeCopyTo(state, self, self_);
}
THC_API void THCudaTensor_cauchy(THCState* state, THCudaTensor *self_, double median, double sigma)
{
THAssert(THCudaTensor_checkGPU(state, 1, self_));
if (state->rngState->current_gen == NULL)
{
THError("Random number generators have not been initialized.");
}
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
long size = THCudaTensor_nElement(state, self);
float *data = THCudaTensor_data(state, self);
generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states, size, data, median, sigma);
THCudaTensor_freeCopyTo(state, self, self_);
}
// Binary search over an inclusive prefix sum (CDF): returns the smallest
// index whose cumulative value is >= val.
__device__ int binarySearchForMultinomial(float* dist,
int size,
float val) {
int start = 0;
int end = size;
while (end - start > 0) {
int mid = start + (end - start) / 2;
float midVal = dist[mid];
if (midVal < val) {
start = mid + 1;
} else {
end = mid;
}
}
if (start == size) {
// No bucket found (zero probability mass or precision problems);
// just return the first element
start = 0;
}
return start;
}
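/* Worked example of the search above: for an inclusive prefix sum
   dist = {0.1f, 0.4f, 1.0f} (size = 3) and val = 0.35f, the interval narrows
   [0, 3) -> [0, 1) -> [1, 1) and the function returns 1, the first bucket
   whose cumulative probability is >= val. */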
// Normalizes the L1 norm of every row to 1; used by multinomial
__global__ void renormRowsL1(float* dist, long rows, long cols) {
extern __shared__ float smem[];
for (long row = blockIdx.x; row < rows; row += gridDim.x) {
float sum = 0.0f;
for (long col = threadIdx.x; col < cols; col += blockDim.x) {
sum += dist[row * cols + col];
}
sum = reduceBlock(smem, blockDim.x, sum, thrust::plus<float>(), 0.0f);
if (threadIdx.x == 0) {
smem[0] = sum;
}
__syncthreads();
sum = smem[0];
if (sum > 0.0f) {
for (long col = threadIdx.x; col < cols; col += blockDim.x) {
dist[row * cols + col] /= sum;
}
}
}
}
// Normalizes each row of a 2-d probability matrix to sum to 1
// (host-side wrapper around renormRowsL1)
void THCudaTensor_renormRows(struct THCState* state,
THCudaTensor* t) {
THAssert(THCudaTensor_nDimension(state, t) == 2);
long rows = THCudaTensor_size(state, t, 0);
long cols = THCudaTensor_size(state, t, 1);
cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
THAssert(props != NULL);
int numSM = props->multiProcessorCount;
int maxThreads = props->maxThreadsPerBlock;
dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
dim3 block(cols < maxThreads ? cols : maxThreads);
renormRowsL1
<<<grid, block, block.x * sizeof(float),
THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, t),
rows, cols);
}
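/* Worked example of the sizing above: on a device with 16 SMs and 1024 threads
   per block, a 100 x 4000 distribution matrix gets grid = min(100, 64) = 64
   blocks and block = min(4000, 1024) = 1024 threads; each block strides over
   rows and columns, and a row such as {2, 3, 5} is rescaled in place to
   {0.2, 0.3, 0.5}. */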
__global__ void
sampleMultinomialOnce(float* dest,
long distributions,
int categories,
float* dist) {
extern __shared__ float smem[];
for (long curDist = blockIdx.x;
curDist < distributions; curDist += gridDim.x) {
// Each block handles one distribution
// First pass, find the total sum of the distribution
float sum = 0.0f;
for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
sum += dist[curDist * categories + cat];
}
// After the reduction, threadIdx.x == 0 holds the block-wide sum
sum = reduceBlock(smem, blockDim.x, sum, thrust::plus<float>(), 0.0f);
// Broadcast sum and sample value
if (threadIdx.x == 0) {
smem[0] = sum;
smem[1] = dest[curDist];
}
__syncthreads();
sum = smem[0];
float sample = smem[1];
__syncthreads();
if (sum == 0.0f || sample == 0.0f) {
// Choose the first element
if (threadIdx.x == 0) {
dest[curDist] = 1;
}
continue;
}
int chunks = THCCeilDiv(categories, (int) blockDim.x);
float prevHighProb = 0.0f;
for (int chunk = 0; chunk < chunks; ++chunk) {
// All threads in bounds load a value
int cat = chunk * blockDim.x + threadIdx.x;
float val =
cat < categories ? dist[curDist * categories + cat] / sum : 0.0f;
smem[threadIdx.x] = val;
__syncthreads();
// Perform an inclusive prefix sum of the shared memory contents
for (int offset = 1; offset < blockDim.x; offset *= 2) {
float val = 0.0f;
if (threadIdx.x >= offset) {
val = smem[threadIdx.x - offset] + smem[threadIdx.x];
}
__syncthreads();
if (threadIdx.x >= offset) {
smem[threadIdx.x] = val;
}
__syncthreads();
}
// Each thread will check to see if the sample falls in its
// bucket
float curBucket =
smem[threadIdx.x] + prevHighProb;
float prevBucket =
threadIdx.x == 0 ? prevHighProb : smem[threadIdx.x - 1] + prevHighProb;
bool inBucket =
(cat < categories) && (sample <= curBucket) && (sample > prevBucket);
if (inBucket) {
// We're done; we have the sample
// Torch indices are 1-based
// FIXME: broadcast exit flag?
dest[curDist] = cat + 1;
}
// Store the previous scan's high value for future use
prevHighProb += smem[blockDim.x - 1];
__syncthreads();
}
}
}
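/* Worked example of the bucket test above: for a single distribution
   dist = {0.2f, 0.3f, 0.5f} (sum = 1) and a pre-drawn sample of 0.6, the
   inclusive prefix sum of the only chunk is {0.2, 0.5, 1.0}; thread 2 sees
   prevBucket = 0.5 < 0.6 <= curBucket = 1.0 and writes the 1-based index 3
   into dest. */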
__global__ void
sampleMultinomialWithReplacement(curandStateMtgp32* state,
int totalSamples,
float* dest,
long distributions,
int categories,
float* normDistPrefixSum) {
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values per warp and limit divergence later on. However, no matter
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
// The block determines the distribution for which we generate a point
for (long curDist = blockIdx.x;
curDist < distributions;
curDist += gridDim.x) {
for (int sampleBase = 0;
sampleBase < totalSamples; sampleBase += blockDim.y) {
// The warp determines the sample
int sample = sampleBase + threadIdx.y;
// All threads participate in this
float r = curand_uniform(&state[blockIdx.x]);
if (threadIdx.x == 0 && sample < totalSamples) {
// Find the bucket that a uniform sample lies in
int choice = binarySearchForMultinomial(
normDistPrefixSum + curDist * categories,
categories,
r);
// Torch indices are 1-based
dest[curDist * totalSamples + sample] = (float) choice + 1.0f;
}
}
}
}
__global__ void
sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
int totalSamples,
int sample,
float* dest,
long distributions,
int categories,
float* origDist,
float* normDistPrefixSum) {
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values per warp and limit divergence later on. However, no matter
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
// The block and warp determine the distribution for which we
// generate a point
for (long curDistBase = blockIdx.x * blockDim.y;
curDistBase < distributions;
curDistBase += gridDim.x * blockDim.y) {
// The warp determines the distribution
long curDist = curDistBase + threadIdx.y;
// All threads must participate in this
float r = curand_uniform(&state[blockIdx.x]);
if (threadIdx.x == 0 && curDist < distributions) {
// Find the bucket that a uniform sample lies in
int choice = binarySearchForMultinomial(
normDistPrefixSum + curDist * categories,
categories,
r);
// Torch indices are 1-based
dest[curDist * totalSamples + sample] = (float) choice + 1.0f;
// Without replacement, so update the original probability so it
// is not considered a second time
origDist[curDist * categories + choice] = 0.0f;
}
}
}
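/* Worked example of the without-replacement update: if a row of origDist starts
   as {0.2, 0.3, 0.5} and category 2 (0-based) is drawn first, that row becomes
   {0.2, 0.3, 0.0}; the host loop below then renormalizes it to {0.4, 0.6, 0.0},
   rebuilds the prefix sum, and launches this kernel again for the next sample. */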
THC_API void THCudaTensor_multinomial(struct THCState *state,
THCudaTensor *self,
THCudaTensor *prob_dist,
int n_sample,
int with_replacement)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, prob_dist));
if (state->rngState->current_gen == NULL) {
THError("Random number generators have not been initialized.");
}
int inputSize = THCudaTensor_nDimension(state, prob_dist);
THArgCheck(inputSize > 0 && inputSize <= 2, 2,
"prob_dist must be 1 or 2 dim");
// Categories are in the innermost dimension
long numDist =
inputSize == 1 ? 1 : THCudaTensor_size(state, prob_dist, 0);
long numCategoriesLong =
inputSize == 1 ? THCudaTensor_size(state, prob_dist, 0) :
THCudaTensor_size(state, prob_dist, 1);
// Since the index tensor is float, numCategories cannot exceed max
// float integer precision
THArgCheck(numCategoriesLong <= FLOAT32_MAX_CONSECUTIVE_INT, 2,
"number of categories cannot exceed 2^24");
int numCategories = (int) numCategoriesLong;
THArgCheck(n_sample > 0, 3, "cannot sample <= 0 samples");
if (!with_replacement) {
THArgCheck(n_sample <= numCategories, 2,
"cannot sample n_sample > prob_dist:size(1) samples without "
"replacement");
}
// It is possible that prob_dist is non-contiguous
THCudaTensor* probDistContig =
THCudaTensor_newContiguous(state, prob_dist);
// Restructure data for 2d
if (inputSize == 1) {
THCudaTensor_resize2d(state, probDistContig, 1, numCategories);
}
THCudaTensor_resize2d(state, self, numDist, n_sample);
if (n_sample == 1) {
// Optimized allocation-free implementation
// To exploit greater parallelism for the sampling, generate the
// uniform random samples in a separate kernel launch, into the
// result memory. The device RNG is thread-limited: only
// MAX_NUM_BLOCKS * BLOCK_SIZE threads can draw from it at once.
THCudaTensor_uniform(state, self, 0.0, 1.0);
cudaDeviceProp* props = THCState_getCurrentDeviceProperties(state);
THAssert(props != NULL);
int numSM = props->multiProcessorCount;
int maxThreads = props->maxThreadsPerBlock;
dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
sampleMultinomialOnce
<<<grid, block, block.x * sizeof(float),
THCState_getCurrentStream(state)>>>(
THCudaTensor_data(state, self),
numDist,
numCategories,
THCudaTensor_data(state, probDistContig));
} else {
// Generic, slow implementation with memory allocations
// For sampling without replacement, we modify the distribution
// for subsequent samples in this space
THCudaTensor* origDist = THCudaTensor_new(state);
THCudaTensor_resizeAs(state, origDist, probDistContig);
THCudaTensor_copy(state, origDist, probDistContig);
THCudaTensor* normDist = THCudaTensor_new(state);
THCudaTensor_resizeAs(state, normDist, probDistContig);
THCudaTensor* prefixSum = THCudaTensor_new(state);
// Renorm along rows
THCudaTensor_copy(state, normDist, origDist);
THCudaTensor_renormRows(state, normDist);
// Prefix sum along rows
THCudaTensor_cumsum(state, prefixSum, normDist, 1);
if (with_replacement) {
// Sample with replacement
// Binary search is warp divergent (so effectively we're running
// with just a single thread), but for better utilization,
// we need each block to have at least 4 warps.
dim3 block(32, 4);
// Each warp in a block will generate a sample from one
// distribution concurrently.
dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
sampleMultinomialWithReplacement
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states,
n_sample,
THCudaTensor_data(state, self),
numDist, numCategories,
THCudaTensor_data(state, prefixSum));
} else {
// Sample without replacement
// Binary search is warp divergent (so effectively we're running
// with just a single thread), but for better utilization,
// we need each block to have at least 4 warps.
dim3 block(32, 4);
// Each warp in a block will generate a sample from a different
// distribution concurrently.
long numBlocks = THCCeilDiv(numDist, 4L);
dim3 grid(numBlocks < MAX_NUM_BLOCKS ? numBlocks : MAX_NUM_BLOCKS);
for (int sample = 0; sample < n_sample; ++sample) {
if (sample > 0) {
// Update probabilities
// Renorm along rows
THCudaTensor_copy(state, normDist, origDist);
THCudaTensor_renormRows(state, normDist);
// Prefix sum along rows
THCudaTensor_cumsum(state, prefixSum, normDist, 1);
}
// The kernel can only draw one sample before we have to
// recalculate our distribution
sampleMultinomialWithoutReplacement
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
state->rngState->current_gen->gen_states,
n_sample,
sample,
THCudaTensor_data(state, self),
numDist, numCategories,
THCudaTensor_data(state, origDist),
THCudaTensor_data(state, prefixSum));
}
}
THCudaTensor_free(state, prefixSum);
THCudaTensor_free(state, normDist);
THCudaTensor_free(state, origDist);
}
// Revert data restructuring based on input sizes
if (inputSize == 1) {
THCudaTensor_resize1d(state, self, n_sample);
// Unfortunately, if prob_dist is already contiguous, newContiguous
// does not return a private copy, so we have to restore this tensor's
// shape as well so that prob_dist is unaffected
THCudaTensor_resize1d(state, probDistContig, numCategories);
}
THCudaTensor_free(state, probDistContig);
}
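/* Illustrative sketch (not part of this file): drawing 4 samples per row,
   without replacement, from a 2 x 8 probability matrix. Assumes `state` is an
   initialized THCState* and `probs` an existing 2 x 8 THCudaTensor* with
   non-negative rows; `out` ends up as a 2 x 4 tensor of 1-based category
   indices.

     THCudaTensor *out = THCudaTensor_new(state);
     THCudaTensor_multinomial(state, out, probs, 4, 0);
     THCudaTensor_free(state, out);
*/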
#undef NUM_BLOCKS