adding paddingValue to LookupTable
diff --git a/generic/LookupTable.c b/generic/LookupTable.c
index ed9656e..47fbe08 100644
--- a/generic/LookupTable.c
+++ b/generic/LookupTable.c
@@ -29,11 +29,12 @@
THTensor *sorted,
THTensor *indices,
bool scaleGradByFreq,
+ int paddingValue,
real scale)
{
long i;
THInteger_t *count_data = NULL;
-
+
if (scaleGradByFreq)
{
THIntegerTensor_(resize1d)(count, gradWeight->size[0]);
@@ -81,13 +82,15 @@
long end = start + (numw/nthreads + 1);
for (i=0; i<numel; i++)
{
- long k = input_data[i] - 1;
- if (k >= start && k < end)
+ if (input_data[i] != paddingValue)
{
- real lr = scale;
- if (count_data)
- lr /= count_data[k];
- THBlas_(axpy)(stride, lr, go + i*stride, 1, gw + k*stride, 1);
+ long k = input_data[i] - 1;
+ if (k >= start && k < end)
+ {
+ real scale_ = scale;
+ if (count_data) scale_ /= count_data[k];
+ THBlas_(axpy)(stride, scale_, go + i*stride, 1, gw + k*stride, 1);
+ }
}
}
}
@@ -99,11 +102,13 @@
for (i=0; i<numel; i++)
{
- long k = input_data[i] - 1;
- real lr = scale;
- if (count_data)
- lr /= count_data[k];
- THBlas_(axpy)(stride, lr, go + i*stride, 1, gw + k*stride, 1);
+ if (input_data[i] != paddingValue)
+ {
+ long k = input_data[i] - 1;
+ real scale_ = scale;
+ if (count_data) scale_ /= count_data[k];
+ THBlas_(axpy)(stride, scale_, go + i*stride, 1, gw + k*stride, 1);
+ }
}
THTensor_(free)(gradOutput);
diff --git a/generic/THNN.h b/generic/THNN.h
index 811b39c..e758235 100644
--- a/generic/THNN.h
+++ b/generic/THNN.h
@@ -150,6 +150,7 @@
THTensor *sorted,
THTensor *indices,
bool scaleGradByFreq,
+ int paddingValue,
real scale);
TH_API void THNN_(MarginCriterion_updateOutput)(