Add table (COO, nnz x 3) input support for batched SparseLinear, implement gradInput correctly, and fix other bugs.
diff --git a/generic/SparseLinear.c b/generic/SparseLinear.c
index a84e030..2e24d91 100644
--- a/generic/SparseLinear.c
+++ b/generic/SparseLinear.c
@@ -5,15 +5,21 @@
#ifdef _OPENMP
#include <omp.h>
#endif
+#include <stdio.h>
#define ROW_PTR2(t, r) (THTensor_(data)(t) + (r) * (t)->stride[0])
#define COL_PTR2(t, c) (THTensor_(data)(t) + (c) * (t)->stride[1])
-static bool THNN_(checkInput)(THTensor* t)
+static bool THNN_(checkLegacyInput)(THTensor* t)
{
return t->nDimension == 3 && t->size[2] == 2;
}
+static bool THNN_(checkInput)(THTensor* t)
+{
+ return t->nDimension == 2 && t->size[1] == 3;
+}
+
static bool THNN_(checkSize2D)(THTensor* t, long size0, long size1)
{
return t->nDimension == 2 && t->size[0] == size0 && t->size[1] == size1;
@@ -41,15 +47,61 @@
THTensor *input,
THTensor *output,
THTensor *weight,
- THTensor *bias,
- THTensor *cudaBuffer,
- THTensor *shardBuffer)
+ THTensor *bias)
+{
+ long h, i;
+ long outDim = THTensor_(size)(weight, 0);
+ long inDim = THTensor_(size)(weight, 1);
+ long batchSize = THTensor_(size)(output, 0);
+
+ THArgCheck(THNN_(checkInput)(input), 2, "input must be in coo format, nnz x 3");
+ THArgCheck(THTensor_(isContiguous)(output), 3, "output must be contiguous");
+ THArgCheck(THNN_(checkSize1D)(bias, outDim), 5, "bias size wrong");
+
+ long nnz = THTensor_(size)(input, 0);
+
+ // output = weight * input + bias
+ THTensor_(zero)(output);
+#pragma omp parallel for private(i) schedule(static) if (nnz * outDim > 10000)
+ for (i = 0; i < nnz; i++) {
+ real val = THNN_(get2d)(input, i, 2);
+ if (val == 0) {
+ continue;
+ }
+
+ long offset = (long)(THNN_(get2d)(input, i, 1)) - 1;
+ long h = (long)(THNN_(get2d)(input, i, 0)) - 1;
+ if (offset >= 0 && offset < inDim) {
+ THBlas_(axpy)(outDim,
+ val,
+ COL_PTR2(weight, offset), weight->stride[0],
+ ROW_PTR2(output, h), output->stride[1]);
+ } else {
+ THError("index out of bound. updateOutput: %ld not between 1 and %ld",
+ offset + 1, inDim);
+ }
+ }
+
+ THTensor* output_row = THTensor_(new)();
+ for (h = 0; h < batchSize; h++) {
+ THTensor_(select)(output_row, output, 0, h);
+ THTensor_(cadd)(output_row, bias, 1.0, output_row);
+ }
+ THTensor_(free)(output_row);
+}
+
+void THNN_(SparseLinear_legacyUpdateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ THTensor *weight,
+ THTensor *bias)
{
long h, i;
long outDim = THTensor_(size)(weight, 0);
long inDim = THTensor_(size)(weight, 1);
- THArgCheck(THNN_(checkInput)(input), 2, "input size must be batchsize x nnz x 2");
+ THArgCheck(THNN_(checkLegacyInput)(input), 2, "input size must be batchsize x nnz x 2");
THArgCheck(THTensor_(isContiguous)(output), 3, "output must be contiguous");
THArgCheck(THNN_(checkSize1D)(bias, outDim), 5, "bias size wrong");
@@ -105,6 +157,65 @@
long inDim = THTensor_(size)(weight, 1);
THArgCheck(THNN_(checkInput)(input), 2,
+ "input must be in coo format, nnz x 3");
+ THArgCheck(THNN_(checkSize2D)(gradWeight, outDim, inDim), 4,
+ "gradWeight size wrong");
+ THArgCheck(THNN_(checkSize1D)(gradBias, outDim), 5,
+ "gradBias size wrong");
+ THArgCheck(THTensor_(isContiguous)(gradOutput), 3,
+ "gradOutput must be contiguous");
+
+ long nnz = THTensor_(size)(input, 0);
+ // THTensor_(resize2d)(gradOutput, batchSize, outDim);
+
+ // gradWeight += gradOutput * input
+#pragma omp parallel for private(h, i) schedule(static) if (\
+ nnz * outDim > 10000)
+ for (i = 0; i < nnz; i++) {
+ real val = scale * THNN_(get2d)(input, i, 2);
+
+ long offset = (long)(THNN_(get2d)(input, i, 1)) - 1;
+ long h = (long)(THNN_(get2d)(input, i, 0)) - 1;
+ if (offset >= 0 && offset < inDim) {
+ THBlas_(axpy)(outDim,
+ val,
+ ROW_PTR2(gradOutput, h), gradOutput->stride[1],
+ COL_PTR2(gradWeight, offset), gradWeight->stride[0]);
+ } else {
+ THError(
+ "index out of bound. accGradParameters: %ld not between 1 and %ld",
+ offset + 1,
+ inDim);
+ }
+ }
+
+ // gradBias += gradOutput
+ THTensor* buf = THTensor_(new)();
+ THTensor_(sum)(buf, gradOutput, 0);
+ THTensor_(cadd)(gradBias, gradBias, scale, buf);
+ THTensor_(free)(buf);
+
+ if (weightDecay != 0) {
+ THTensor_(cadd)(gradWeight, gradWeight, weightDecay, weight);
+ }
+}
+
+void THNN_(SparseLinear_legacyAccGradParameters)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradWeight,
+ THTensor *gradBias,
+ THTensor *weight,
+ THTensor *bias,
+ real weightDecay,
+ real scale)
+{
+ long h, i;
+ long outDim = THTensor_(size)(weight, 0);
+ long inDim = THTensor_(size)(weight, 1);
+
+ THArgCheck(THNN_(checkLegacyInput)(input), 2,
"input size must be batchsize x nnz x 2");
THArgCheck(THNN_(checkSize2D)(gradWeight, outDim, inDim), 4,
"gradWeight size wrong");
@@ -279,51 +390,7 @@
}
}
-void THNN_(SparseLinear_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- THTensor *weight)
-{
- long h, i;
- long outDim = weight->size[0];
- long inDim = weight->size[1];
-
- THArgCheck(THNN_(checkInput)(input), 2,
- "input must be a batchSize x nnz x 2 tensor");
- THArgCheck(THTensor_(isContiguous)(gradInput), 4,
- "gradInput must be contiguous");
- THArgCheck(THTensor_(isContiguous)(gradOutput), 3,
- "gradOutput must be contiguous");
-
- long batchSize = THTensor_(size)(input, 0);
- long nnz = THTensor_(size)(input, 1);
- THTensor_(resize2d)(gradOutput, batchSize, outDim);
- THTensor_(resize3d)(gradInput, batchSize, nnz, 2);
-
-#pragma omp parallel for private(h, i) schedule(static) if ( \
- batchSize > 1 && batchSize * nnz * outDim > 10000)
- for (h = 0; h < batchSize; h++) {
- for (i = 0; i < nnz; ++i) {
- long offset = (long)(THTensor_(get3d)(input, h, i, 0)) - 1;
- THTensor_(set3d)(gradInput, h, i, 0, offset + 1);
-
- if (offset >= 0 && offset < inDim) {
- real val = THBlas_(dot)(
- outDim,
- ROW_PTR2(gradOutput, h), gradOutput->stride[1],
- COL_PTR2(weight, offset), weight->stride[0]);
- THTensor_(set3d)(gradInput, h, i, 1, val);
- } else {
- THError(
- "index out of bound. updateGradInput: %d not between 1 and %d",
- offset + 1,
- inDim);
- }
- }
- }
-}
+void THNN_(SparseLinear_cudaClearState)(THNNState *state) {}
#undef ROW_PTR2
#undef COL_PTR2
diff --git a/generic/THNN.h b/generic/THNN.h
index 86c63da..544d317 100644
--- a/generic/THNN.h
+++ b/generic/THNN.h
@@ -342,16 +342,24 @@
THTensor *input,
THTensor *output,
THTensor *weight,
- THTensor *bias,
- THTensor *cudaBuffer,
- THTensor *shardBuffer);
-TH_API void THNN_(SparseLinear_updateGradInput)(
+ THTensor *bias);
+TH_API void THNN_(SparseLinear_accGradParameters)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
- THTensor *gradInput,
- THTensor *weight);
-TH_API void THNN_(SparseLinear_accGradParameters)(
+ THTensor *gradWeight,
+ THTensor *gradBias,
+ THTensor *weight,
+ THTensor *bias,
+ real weightDecay,
+ real scale);
+TH_API void THNN_(SparseLinear_legacyUpdateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ THTensor *weight,
+ THTensor *bias);
+TH_API void THNN_(SparseLinear_legacyAccGradParameters)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
@@ -374,6 +382,7 @@
THTensor *gradBias,
THTensor *lastInput,
real learningRate);
+TH_API void THNN_(SparseLinear_cudaClearState)(THNNState *state);
TH_API void THNN_(Sqrt_updateOutput)(
THNNState *state,