#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/IndexLinear.cu"
#else
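
// CUDA implementation of IndexLinear, a linear layer over sparse inputs.
// The batch is flattened into one (keys, values) pair; `sizes` holds the
// nonzero count of each sample and `cumSumSizes` the prefix sums used to
// locate each sample's slice.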
static bool THCUNN_checkKeysValues(THCState *state, THCudaLongTensor* keys,
                                   THCTensor* values)
{
  return THCudaLongTensor_size(state, keys, 0) == THCTensor_(nElement)(state, values)
      && THCTensor_(nDimension)(state, values) == 1
      && THCudaLongTensor_nDimension(state, keys) == 1;
}
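
// Forward pass: each output row accumulates values[i] times the weight row
// selected by keys[i] + keysOffset, plus the bias. When maxNormalize > 0,
// the weight matrix carries extra columns of per-key normalization state.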
void THNN_(IndexLinear_updateOutput)(
          THCState *state,
          THCudaLongTensor *keys,
          int64_t keysOffset,
          THCTensor *values,
          THCudaLongTensor *sizes,
          THCudaLongTensor *cumSumSizes,
          THCTensor *output,
          THCTensor *weight,
          THCTensor *bias,
          THCTensor *normalizedValues,
          int train)
{
  // Make sure these inputs are contiguous to accelerate computations
  THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
             "keys vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, values), 3,
             "values vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
             "sizes vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
             "cumSumSizes vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, output), 6,
             "output vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, weight), 7,
             "weight matrix must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, bias), 8,
             "bias vector must be contiguous");
  THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
             "Keys and values should have the same number of elements");
  int64_t batchSize = sizes->size[0];
  int64_t outDim = bias->size[0];
  int64_t wDim = weight->size[1];
  int64_t weightStride = weight->stride[0];
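  // wDim = outDim + maxNormalize: any columns beyond outDim hold the
  // per-key normalization state, so a positive difference enables it.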
  int maxNormalize = wDim - outDim;
  int64_t keysSize = keys->size[0];
  int64_t nnzPerRow = divup(keysSize, batchSize);
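  // nnzPerRow is the average nonzero count per sample; it only sizes the
  // launch grid, while the kernels read the exact counts from cumSumSizes.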
  THCTensor_(resize2d)(state, output, batchSize, outDim);
  int64_t *keysData = THCudaLongTensor_data (state, keys);
  real *valuesData = THCTensor_(data) (state, values);
  int64_t *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
  real *biasData = THCTensor_(data) (state, bias);
  real *weightData = THCTensor_(data) (state, weight);
  real *outData = THCTensor_(data) (state, output);
  cudaStream_t stream = THCState_getCurrentStream(state);
  dim3 threads(THREADS_X, THREADS_Y);
  int blocks_x = divup(outDim, threads.x);
  int blocks_y = batchSize;
  int nnzPerBlock = ((outDim == 1 || batchSize == 1) ? THREADS_X : NNZ_PER_BLOCK_MAX);
  int blocks_z = divup(nnzPerRow, nnzPerBlock);
  dim3 blocks(blocks_x, blocks_y, blocks_z);
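  // When a sample's nonzeros are split across several z-blocks, each block
  // accumulates a partial sum into output, so output must start zeroed.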
  if (blocks_z > 1) {
    THCudaCheck(cudaMemsetAsync(outData, 0, outDim * batchSize * sizeof(real), stream));
  }
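  // During training with normalization enabled, the normalized values are
  // stashed in `normalizedValues` so the backward pass can reuse them.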
  real *normalizedValuesData = NULL;
  if (maxNormalize && train) {
    THCTensor_(resize1d)(state, normalizedValues, keysSize);
    normalizedValuesData = THCTensor_(data)(state, normalizedValues);
    updateOutput<real, true><<<blocks, threads, 0, stream>>>
      (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
       batchSize, outDim, weightData, biasData, weightStride, keysOffset,
       maxNormalize, nnzPerBlock);
  } else {
    updateOutput<real, false><<<blocks, threads, 0, stream>>>
      (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
       batchSize, outDim, weightData, biasData, weightStride, keysOffset,
       maxNormalize, nnzPerBlock);
  }
}
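
// Gradient accumulation: gradBias is reduced over the batch directly, while
// gradWeight is laid out with one row per nonzero key; the rows are scattered
// into the weight matrix later by IndexLinear_updateParameters.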
void THNN_(IndexLinear_accGradParameters)(
          THCState *state,
          THCudaLongTensor *keys,
          int64_t keysOffset,
          THCTensor *values,
          THCudaLongTensor *sizes,
          THCudaLongTensor *cumSumSizes,
          THCTensor *gradOutput,
          THCTensor *gradWeight,
          THCTensor *gradBias,
          THCTensor *weight,
          THCTensor *bias,
          THCTensor* valuesBuffer,
          accreal weightDecay,
          accreal scale)
{
  int64_t keysSize = keys->size[0];
  int64_t batchSize = sizes->size[0];
  int64_t outDim = bias->size[0];
  int64_t wDim = weight->size[1];
  int maxNormalize = wDim - outDim;
  // Make sure these inputs are contiguous to accelerate computations
  THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
             "keys vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, values), 3,
             "values vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
             "sizes vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
             "cumSumSizes vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6,
             "gradOutput vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 7,
             "gradWeight matrix must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, gradBias), 8,
             "gradBias vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, weight), 9,
             "weight matrix must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, bias), 10,
             "bias vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, valuesBuffer), 11,
             "valuesBuffer vector must be contiguous");
  THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
             "Keys and values should have the same number of elements");
  THCTensor_(resize2d)(state, gradWeight, keysSize, outDim * (maxNormalize > 0 ? 2 : 1));
  real *valuesData = THCTensor_(data) (state, values);
  int64_t *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
  real *gradOutputData = THCTensor_(data) (state, gradOutput);
  real *gradBiasData = THCTensor_(data) (state, gradBias);
  real *gradWeightData = THCTensor_(data) (state, gradWeight);
  int64_t gradWeightStride = gradWeight->stride[0];
  cudaStream_t stream = THCState_getCurrentStream(state);
  dim3 threads(THREADS_X, THREADS_Y);
  int blocks_x = divup(outDim, threads.x);
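  // Reduce gradOutput over the batch into gradBias (scaled, with weight decay).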
  accGradBias<real, false><<<blocks_x, threads, 0, stream>>>
    (gradBiasData, gradOutputData, outDim, batchSize, scale, weightDecay);
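  // One block row per sample; each nonzero key has its own gradWeight row,
  // so samples can be processed concurrently without write conflicts.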
  dim3 blocks(blocks_x, batchSize);
  accGradWeight<real><<<blocks, threads, 0, stream>>>
    (gradWeightData, gradOutputData, valuesData, cumSumSizesData, outDim,
     gradWeightStride, scale, weightDecay, maxNormalize);
}
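
// Fused variant: applies the scaled gradients directly to weight and bias
// without materializing gradWeight/gradBias.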
void THNN_(IndexLinear_accUpdateGradParameters)(
          THCState *state,
          THCudaLongTensor *keys,
          int64_t keysOffset,
          THCTensor *values,
          THCudaLongTensor *sizes,
          THCudaLongTensor *cumSumSizes,
          THCTensor *gradOutput,
          THCTensor *weight,
          THCTensor *bias,
          accreal weightDecay,
          accreal scale)
{
  // Make sure these inputs are contiguous to accelerate computations
  THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
             "keys vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, values), 3,
             "values vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
             "sizes vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
             "cumSumSizes vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6,
             "gradOutput vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, weight), 7,
             "weight matrix must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, bias), 8,
             "bias vector must be contiguous");
  THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
             "Keys and values should have the same number of elements");
  int64_t batchSize = sizes->size[0];
  int64_t outDim = bias->size[0];
  int64_t keysSize = keys->size[0];
  int64_t wDim = weight->size[1];
  int maxNormalize = wDim - outDim;
  real *biasData = THCTensor_(data) (state, bias);
  real *weightData = THCTensor_(data) (state, weight);
  real *gradOutputData = THCTensor_(data) (state, gradOutput);
  real *valuesData = THCTensor_(data) (state, values);
  int64_t *keysData = THCudaLongTensor_data (state, keys);
  int64_t *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
  int64_t weightStride = weight->stride[0];
  cudaStream_t stream = THCState_getCurrentStream(state);
  dim3 threads(THREADS_X, THREADS_Y);
  int blocks_x = divup(outDim, threads.x);
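  // The `true` instantiation of accGradBias folds the step into bias itself
  // instead of accumulating into a separate gradBias tensor.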
  accGradBias<real, true><<<blocks_x, threads, 0, stream>>>
    (biasData, gradOutputData, outDim, batchSize, scale, weightDecay);
  int64_t nnzPerRow = divup(keysSize, batchSize);
  int blocks_y = divup(nnzPerRow, REPEAT * threads.y);
  dim3 blocks(blocks_x, blocks_y);
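  // One launch per sample: updates land directly in the shared weight matrix,
  // and serializing the samples on the stream presumably avoids write
  // conflicts on keys that appear in more than one sample.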
  for (int64_t batchId = 0; batchId < batchSize; batchId++) {
    accUpdateWeight<real><<<blocks, threads, 0, stream>>>
      (weightData, weightStride, gradOutputData, outDim, valuesData,
       cumSumSizesData, keysData, keysOffset, scale, weightDecay, maxNormalize,
       batchId);
  }
}
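
// SGD step: folds the gradients accumulated by accGradParameters back into
// the parameters. The bias update is a dense cadd; the weight update scatters
// the per-key gradWeight rows through runningKeys.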
void THNN_(IndexLinear_updateParameters)(
          THCState *state,
          THCTensor *gradWeight,
          THCTensor *gradBias,
          THCTensor *weight,
          THCTensor *bias,
          THCudaLongTensor *runningKeys,
          THCudaLongTensor *cumSumSizes,
          int64_t keysOffset,
          accreal weightDecay,
          accreal learningRate)
{
  // Make sure these inputs are contiguous to accelerate computations
  THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 1,
             "gradWeight matrix must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, gradBias), 2,
             "gradBias vector must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, weight), 3,
             "weight matrix must be contiguous");
  THArgCheck(THCTensor_(isContiguous)(state, bias), 4,
             "bias vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, runningKeys), 5,
             "runningKeys vector must be contiguous");
  THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 6,
             "cumSumSizes vector must be contiguous");
  int64_t outDim = bias->size[0];
  int64_t wDim = weight->size[1];
  int maxNormalize = wDim - outDim;
  int64_t keysSize = runningKeys->size[0];
  int64_t batchSize = cumSumSizes->size[0];
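  // bias := bias - learningRate * gradBias, as one dense vector operation.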
  THCTensor_(cadd)(state, bias, bias, -learningRate, gradBias);
  int64_t gradWeightStride = gradWeight->stride[0];
  int64_t weightStride = weight->stride[0];
  int64_t *keysData = THCudaLongTensor_data (state, runningKeys);
  int64_t *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
  real *gradWeightData = THCTensor_(data) (state, gradWeight);
  real *weightData = THCTensor_(data) (state, weight);
  dim3 threads(THREADS_X, THREADS_Y);
  int64_t nnzPerRow = divup(keysSize, batchSize);
  int blocks_x = divup(outDim, threads.x);
  int blocks_y = divup(nnzPerRow, REPEAT * threads.y);
  dim3 blocks(blocks_x, blocks_y);
  cudaStream_t stream = THCState_getCurrentStream(state);
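  // As in accUpdateGradParameters, launch per sample so that updates to keys
  // shared between samples stay serialized on the stream.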
  for (int64_t batchId = 0; batchId < batchSize; batchId++) {
    updateWeight<real><<<blocks, threads, 0, stream>>>
      (weightData, gradWeightData, keysData, cumSumSizesData, outDim,
       gradWeightStride, weightStride, keysOffset, learningRate, weightDecay,
       maxNormalize, batchId);
  }
}
#endif