Adding table input support for batched SparseLinear, implementing gradInput correctly, fixing other bugs
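
The new forward path takes the sparse input as a single 2D tensor in COO format, nnz x 3, where each row is (sample index, feature index, value) with 1-based indices; the old batchSize x nnz x 2 layout stays available through the new legacyUpdateOutput/legacyAccGradParameters entry points added below. As a rough, non-authoritative sketch of driving the updated C API (normally it is called from the Lua wrapper), the snippet below exercises the float instantiation; the demo function name, sizes, and values are invented for illustration, and it assumes the usual THNN_(NAME) -> THNN_FloatNAME macro expansion and that a NULL THNNState is acceptable for the CPU path:

    #include "THNN.h"

    /* Illustrative only: forward pass of a 3-out / 5-in SparseLinear over a
     * 2-sample sparse batch.  Each input row is (sample, feature, value),
     * with 1-based indices to match the "- 1" adjustments in the kernel. */
    void sparse_linear_forward_demo(THFloatTensor *weight,  /* 3 x 5 */
                                    THFloatTensor *bias)    /* 3 */
    {
      long nnz = 3, batchSize = 2, outDim = 3;
      THFloatTensor *input  = THFloatTensor_newWithSize2d(nnz, 3);
      THFloatTensor *output = THFloatTensor_newWithSize2d(batchSize, outDim);

      /* sample 1, feature 2, value 1.5 */
      THFloatTensor_set2d(input, 0, 0, 1);
      THFloatTensor_set2d(input, 0, 1, 2);
      THFloatTensor_set2d(input, 0, 2, 1.5f);
      /* sample 1, feature 5, value -0.5 */
      THFloatTensor_set2d(input, 1, 0, 1);
      THFloatTensor_set2d(input, 1, 1, 5);
      THFloatTensor_set2d(input, 1, 2, -0.5f);
      /* sample 2, feature 1, value 2.0 */
      THFloatTensor_set2d(input, 2, 0, 2);
      THFloatTensor_set2d(input, 2, 1, 1);
      THFloatTensor_set2d(input, 2, 2, 2.0f);

      /* output[h] = weight * x_h + bias, accumulated from the nonzeros only */
      THNN_FloatSparseLinear_updateOutput(NULL, input, output, weight, bias);

      THFloatTensor_free(input);
      THFloatTensor_free(output);
    }

Since the kernel zeroes output first and then adds bias to every row, samples that contribute no nonzero entries still come out as the bias vector.
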
diff --git a/generic/SparseLinear.c b/generic/SparseLinear.c
index a84e030..2e24d91 100644
--- a/generic/SparseLinear.c
+++ b/generic/SparseLinear.c
@@ -5,15 +5,21 @@
 #ifdef _OPENMP
 #include <omp.h>
 #endif
+#include <stdio.h>
 
 #define ROW_PTR2(t, r) (THTensor_(data)(t) + (r) * (t)->stride[0])
 #define COL_PTR2(t, c) (THTensor_(data)(t) + (c) * (t)->stride[1])
 
-static bool THNN_(checkInput)(THTensor* t)
+static bool THNN_(checkLegacyInput)(THTensor* t)
 {
   return t->nDimension == 3 && t->size[2] == 2;
 }
 
+static bool THNN_(checkInput)(THTensor* t)
+{
+  return t->nDimension == 2 && t->size[1] == 3;
+}
+
 static bool THNN_(checkSize2D)(THTensor* t, long size0, long size1)
 {
   return t->nDimension == 2 && t->size[0] == size0 && t->size[1] == size1;
@@ -41,15 +47,61 @@
           THTensor *input,
           THTensor *output,
           THTensor *weight,
-          THTensor *bias,
-          THTensor *cudaBuffer,
-          THTensor *shardBuffer)
+          THTensor *bias)
+{
+  long h, i;
+  long outDim = THTensor_(size)(weight, 0);
+  long inDim = THTensor_(size)(weight, 1);
+  long batchSize = THTensor_(size)(output, 0);
+
+  THArgCheck(THNN_(checkInput)(input), 2, "input must be in COO format, nnz x 3");
+  THArgCheck(THTensor_(isContiguous)(output), 3, "output must be contiguous");
+  THArgCheck(THNN_(checkSize1D)(bias, outDim), 5, "bias size wrong");
+
+  long nnz = THTensor_(size)(input, 0);
+
+  // output = weight * input + bias
+  THTensor_(zero)(output);
+#pragma omp parallel for private(i) schedule(static) if (nnz * outDim > 10000)
+  for (i = 0; i < nnz; i++) {
+    real val = THNN_(get2d)(input, i, 2);
+    if (val == 0) {
+      continue;
+    }
+
+    long offset = (long)(THNN_(get2d)(input, i, 1)) - 1;
+    long h = (long)(THNN_(get2d)(input, i, 0)) - 1;
+    if (offset >= 0 && offset < inDim) {
+      THBlas_(axpy)(outDim,
+                    val,
+                    COL_PTR2(weight, offset), weight->stride[0],
+                    ROW_PTR2(output, h), output->stride[1]);
+    } else {
+      THError("index out of bound. updateOutput: %d not between 1 and %d",
+              offset + 1, inDim);
+    }
+  }
+
+  THTensor* output_row = THTensor_(new)();
+  for (h = 0; h < batchSize; h++) {
+    THTensor_(select)(output_row, output, 0, h);
+    THTensor_(cadd)(output_row, bias, 1.0, output_row);
+  }
+  THTensor_(free)(output_row);
+}
+
+void THNN_(SparseLinear_legacyUpdateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          THTensor *weight,
+          THTensor *bias)
 {
   long h, i;
   long outDim = THTensor_(size)(weight, 0);
   long inDim = THTensor_(size)(weight, 1);
 
-  THArgCheck(THNN_(checkInput)(input), 2, "input size must be batchsize x nnz x 2");
+  THArgCheck(THNN_(checkLegacyInput)(input), 2, "input size must be batchsize x nnz x 2");
   THArgCheck(THTensor_(isContiguous)(output), 3, "output must be contiguous");
   THArgCheck(THNN_(checkSize1D)(bias, outDim), 5, "bias size wrong");
 
@@ -105,6 +157,65 @@
   long inDim = THTensor_(size)(weight, 1);
 
   THArgCheck(THNN_(checkInput)(input), 2,
+             "input must be in coo format, nnz x 3");
+  THArgCheck(THNN_(checkSize2D)(gradWeight, outDim, inDim), 4,
+             "gradWeight size wrong");
+  THArgCheck(THNN_(checkSize1D)(gradBias, outDim), 5,
+             "gradBias size wrong");
+  THArgCheck(THTensor_(isContiguous)(gradOutput), 3,
+             "gradOutput must be contiguous");
+
+  long nnz = THTensor_(size)(input, 0);
+  // THTensor_(resize2d)(gradOutput, batchSize, outDim);
+
+  // gradWeight += gradOutput * input
+#pragma omp parallel for private(h, i) schedule(static) if (\
+  nnz * outDim > 10000)
+  for (i = 0; i < nnz; i++) {
+    real val = scale * THNN_(get2d)(input, i, 2);
+
+    long offset = (long)(THNN_(get2d)(input, i, 1)) - 1;
+    long h = (long)(THNN_(get2d)(input, i, 0)) - 1;
+    if (offset >= 0 && offset < inDim) {
+      THBlas_(axpy)(outDim,
+          val,
+          ROW_PTR2(gradOutput, h), gradOutput->stride[1],
+          COL_PTR2(gradWeight, offset), gradWeight->stride[0]);
+    } else {
+      THError(
+          "index out of bound. accGradParameters: %d not between 1 and %d",
+          offset + 1,
+          inDim);
+    }
+  }
+
+  // gradBias += gradOutput
+  THTensor* buf = THTensor_(new)();
+  THTensor_(sum)(buf, gradOutput, 0);
+  THTensor_(cadd)(gradBias, gradBias, scale, buf);
+  THTensor_(free)(buf);
+
+  if (weightDecay != 0) {
+    THTensor_(cadd)(gradWeight, gradWeight, weightDecay, weight);
+  }
+}
+
+void THNN_(SparseLinear_legacyAccGradParameters)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *gradOutput,
+          THTensor *gradWeight,
+          THTensor *gradBias,
+          THTensor *weight,
+          THTensor *bias,
+          real weightDecay,
+          real scale)
+{
+  long h, i;
+  long outDim = THTensor_(size)(weight, 0);
+  long inDim = THTensor_(size)(weight, 1);
+
+  THArgCheck(THNN_(checkLegacyInput)(input), 2,
              "input size must be batchsize x nnz x 2");
   THArgCheck(THNN_(checkSize2D)(gradWeight, outDim, inDim), 4,
              "gradWeight size wrong");
@@ -279,51 +390,7 @@
   }
 }
 
-void THNN_(SparseLinear_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          THTensor *weight)
-{
-  long h, i;
-  long outDim = weight->size[0];
-  long inDim = weight->size[1];
-
-  THArgCheck(THNN_(checkInput)(input), 2,
-             "input must be a batchSize x nnz x 2 tensor");
-  THArgCheck(THTensor_(isContiguous)(gradInput), 4,
-             "gradInput must be contiguous");
-  THArgCheck(THTensor_(isContiguous)(gradOutput), 3,
-             "gradOutput must be contiguous");
-
-  long batchSize = THTensor_(size)(input, 0);
-  long nnz = THTensor_(size)(input, 1);
-  THTensor_(resize2d)(gradOutput, batchSize, outDim);
-  THTensor_(resize3d)(gradInput, batchSize, nnz, 2);
-
-#pragma omp parallel for private(h, i) schedule(static) if (    \
-  batchSize > 1 && batchSize * nnz * outDim > 10000)
-  for (h = 0; h < batchSize; h++) {
-    for (i = 0; i < nnz; ++i) {
-      long offset = (long)(THTensor_(get3d)(input, h, i, 0)) - 1;
-      THTensor_(set3d)(gradInput, h, i, 0, offset + 1);
-
-      if (offset >= 0 && offset < inDim) {
-        real val = THBlas_(dot)(
-          outDim,
-          ROW_PTR2(gradOutput, h), gradOutput->stride[1],
-          COL_PTR2(weight, offset), weight->stride[0]);
-        THTensor_(set3d)(gradInput, h, i, 1, val);
-      } else {
-        THError(
-          "index out of bound. updateGradInput: %d not between 1 and %d",
-          offset + 1,
-          inDim);
-      }
-    }
-  }
-}
+void THNN_(SparseLinear_cudaClearState)(THNNState *state) {}
 
 #undef ROW_PTR2
 #undef COL_PTR2
diff --git a/generic/THNN.h b/generic/THNN.h
index 86c63da..544d317 100644
--- a/generic/THNN.h
+++ b/generic/THNN.h
@@ -342,16 +342,24 @@
           THTensor *input,
           THTensor *output,
           THTensor *weight,
-          THTensor *bias,
-          THTensor *cudaBuffer,
-          THTensor *shardBuffer);
-TH_API void THNN_(SparseLinear_updateGradInput)(
+          THTensor *bias);
+TH_API void THNN_(SparseLinear_accGradParameters)(
           THNNState *state,
           THTensor *input,
           THTensor *gradOutput,
-          THTensor *gradInput,
-          THTensor *weight);
-TH_API void THNN_(SparseLinear_accGradParameters)(
+          THTensor *gradWeight,
+          THTensor *gradBias,
+          THTensor *weight,
+          THTensor *bias,
+          real weightDecay,
+          real scale);
+TH_API void THNN_(SparseLinear_legacyUpdateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          THTensor *weight,
+          THTensor *bias);
+TH_API void THNN_(SparseLinear_legacyAccGradParameters)(
           THNNState *state,
           THTensor *input,
           THTensor *gradOutput,
@@ -374,6 +382,7 @@
           THTensor *gradBias,
           THTensor *lastInput,
           real learningRate);
+TH_API void THNN_(SparseLinear_cudaClearState)(THNNState *state);
 
 TH_API void THNN_(Sqrt_updateOutput)(
           THNNState *state,