#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/LookupTable.c"
#else
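
/* Reset and recompute, for every embedding row referenced by `input`, the
   number of times that row occurs. Indices in `input` are 1-based, so the
   counter for input_data[i] lives at count_data[input_data[i] - 1]. Counters
   of rows never referenced by `input` are left untouched. */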
static void THNN_(LookupTable_resetCount)(THInteger_t *count_data, THIndexTensor *input)
{
  long i;
  THIndex_t *input_data = THIndexTensor_(data)(input);
  long numel = THIndexTensor_(nElement)(input);

  for (i = 0; i < numel; i++)
  {
    long k = input_data[i] - 1;
    count_data[k] = 0;
  }
  for (i = 0; i < numel; i++)
  {
    long k = input_data[i] - 1;
    count_data[k]++;
  }
}
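
/* Accumulate the gradients of an embedding lookup into gradWeight: for each
   index input[i], gradWeight[input[i] - 1] += scale * gradOutput[i]. With
   scaleGradByFreq the update is further divided by the number of occurrences
   of that index in `input` (e.g. for input = {1, 2, 1} the two updates to
   row 0 are each scaled by scale/2), and entries equal to paddingValue are
   skipped. `sorted` and `indices` are unused in this CPU implementation. */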
void THNN_(LookupTable_accGradParameters)(
          THNNState *state,
          THIndexTensor *input,
          THTensor *gradOutput,
          THTensor *gradWeight,
          THIntegerTensor *count,
          THTensor *sorted,
          THTensor *indices,
          bool scaleGradByFreq,
          int paddingValue,
          real scale)
{
  long i;
  THInteger_t *count_data = NULL;

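  // With frequency scaling enabled, keep one occurrence counter per
  // vocabulary row (gradWeight has one row per vocabulary entry).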
  if (scaleGradByFreq)
  {
    THIntegerTensor_(resize1d)(count, gradWeight->size[0]);
    count_data = THIntegerTensor_(data)(count);
  }
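
  // The update loops below index raw data pointers directly, so gradWeight
  // and input must be contiguous (gradOutput is made contiguous below), and
  // input must be a 1-D or 2-D tensor of indices.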
  if (!THTensor_(isContiguous)(gradWeight))
    THError("gradWeight must be contiguous");
  if (!THIndexTensor_(isContiguous)(input))
    THError("input must be contiguous");
  if (THIndexTensor_(nDimension)(input) != 1 && THIndexTensor_(nDimension)(input) != 2)
    THError("input must be a vector or matrix");

  THIndex_t *input_data = THIndexTensor_(data)(input);
  long numel = THIndexTensor_(nElement)(input);
  long numw = THTensor_(size)(gradWeight, 0);
  // Check that every index lies in the valid 1-based range [1, numw].
  for (i = 0; i < numel; i++)
    if (input_data[i] < 1 || input_data[i] > numw)
      THError("input out of range");
  gradOutput = THTensor_(newContiguous)(gradOutput);

  real *gw = THTensor_(data)(gradWeight);
  real *go = THTensor_(data)(gradOutput);
  long stride = THTensor_(stride)(gradWeight, 0);
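
  // Recompute per-row occurrence counts for this input; the updates below
  // divide by these counts when scaleGradByFreq is set.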
  if (count_data)
    THNN_(LookupTable_resetCount)(count_data, input);

#ifdef _OPENMP
  if (numel > 1000)
  {
    // The strategy is to parallelize over sections of the vocabulary, so that
    // each thread handles updates to its own disjoint slice of gradWeight rows
    // and no locking is needed. Every thread has to traverse the entire input,
    // but the dominating factor is the axpy BLAS call.
#pragma omp parallel private(i)
    {
      int tid = omp_get_thread_num();
      int nthreads = omp_get_num_threads();
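      // Fixed chunk of numw/nthreads + 1 rows per thread: the chunks jointly
      // cover all rows, and trailing threads may get an empty range.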
      long start = tid * (numw/nthreads + 1);
      long end = start + (numw/nthreads + 1);
      for (i = 0; i < numel; i++)
      {
        if (input_data[i] != paddingValue)
        {
          long k = input_data[i] - 1;
          if (k >= start && k < end)
          {
            real scale_ = scale;
            if (count_data) scale_ /= count_data[k];
            THBlas_(axpy)(stride, scale_, go + i*stride, 1, gw + k*stride, 1);
          }
        }
      }
    }
    THTensor_(free)(gradOutput);
    return;
  }
#endif
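
  // Serial path (compiled without OpenMP, or taken when the input is small):
  // the same per-index axpy update, without partitioning the vocabulary.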
  for (i = 0; i < numel; i++)
  {
    if (input_data[i] != paddingValue)
    {
      long k = input_data[i] - 1;
      real scale_ = scale;
      if (count_data) scale_ /= count_data[k];
      THBlas_(axpy)(stride, scale_, go + i*stride, 1, gw + k*stride, 1);
    }
  }

  THTensor_(free)(gradOutput);
}
#endif