Add THNN conversion of {RReLU, Sigmoid, SmoothL1Criterion, SoftMax, SoftPlus}
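
Each converted file follows the same pattern: the nn_(*) Lua C binding
that pulled its arguments off the lua_State becomes a plain void
THNN_(*) function taking state, tensors, and scalar parameters
explicitly, and the luaL_Reg table plus the *_init registration
function are dropped. A minimal before/after sketch of the calling
convention (abbreviated from the Sigmoid hunks below; not itself part
of the patch):

    /* Before: arguments arrive via the Lua stack. */
    static int nn_(Sigmoid_updateOutput)(lua_State *L)
    {
      THTensor *input  = luaT_checkudata(L, 2, torch_Tensor);
      THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
      /* ... elementwise computation ... */
      return 1; /* result handed back through Lua */
    }

    /* After: a plain C entry point; bindings live outside the kernel. */
    void THNN_(Sigmoid_updateOutput)(THNNState *state,
                                     THTensor *input,
                                     THTensor *output);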
diff --git a/generic/MultiMarginCriterion.c b/generic/MultiMarginCriterion.c
index 9cb1686..6445bb0 100644
--- a/generic/MultiMarginCriterion.c
+++ b/generic/MultiMarginCriterion.c
@@ -2,7 +2,7 @@
 #define TH_GENERIC_FILE "generic/MultiMarginCriterion.c"
 #else
 
-void THNN_(MultiMarginCriterion_updateOutput)(THNNState *state, THTensor *input, THTensor *target, THTensor* output, bool sizeAverage, int p)
+void THNN_(MultiMarginCriterion_updateOutput)(THNNState *state, THTensor *input, THTensor *target, THTensor *output, bool sizeAverage, int p)
 {
   real *input_data, *target_data;
   long nframe, dim;
diff --git a/generic/RReLU.c b/generic/RReLU.c
index 19b92b6..74c5df5 100644
--- a/generic/RReLU.c
+++ b/generic/RReLU.c
@@ -2,34 +2,24 @@
 #define TH_GENERIC_FILE "generic/RReLU.c"
 #else
 
-static int nn_(RReLU_updateOutput)(lua_State *L)
+void THNN_(RReLU_updateOutput)(THNNState *state, THTensor *input, THTensor *output, THTensor *noise, real lower, real upper, bool train, bool inplace, THGenerator *generator)
 {
-  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-  THTensor *noise = luaT_getfieldcheckudata(L, 1, "noise", torch_Tensor);
-  real lower = luaT_getfieldchecknumber(L, 1, "lower");
-  real upper = luaT_getfieldchecknumber(L, 1, "upper");
-  int train = luaT_getfieldcheckboolean(L, 1, "train");
-  int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
-  
   if (train)
   {
     // get default random generator
-    lua_getglobal(L, "torch");
-    THGenerator *generator = luaT_getfieldcheckudata(L, -1, "_gen", torch_Generator);
-    lua_pop(L, 2);
-
     THTensor_(resizeAs)(noise, input);
     if (inplace)
     {
-      TH_TENSOR_APPLY2(real, input, real, noise, \
-        if (*input_data <= 0) { \
-          const real r = (real)THRandom_uniform(generator, lower, upper); \
-          *input_data = (*input_data) * r; \
-          *noise_data = r; \
-        } \
-        else { \
-          *noise_data = 1; \
+      TH_TENSOR_APPLY2(real, input, real, noise,
+        if (*input_data <= 0)
+        {
+          const real r = (real)THRandom_uniform(generator, lower, upper);
+          *input_data = (*input_data) * r;
+          *noise_data = r;
+        }
+        else
+        {
+          *noise_data = 1;
         }
       );
       THTensor_(set)(output, input);
@@ -37,15 +27,17 @@
     else
     {
       THTensor_(resizeAs)(output, input);
-      TH_TENSOR_APPLY3(real, input, real, output, real, noise, \
-        if (*input_data <= 0) { \
-          const real r = (real)THRandom_uniform(generator, lower, upper); \
-          *output_data = (*input_data) * r; \
-          *noise_data = r; \
-        } \
-        else { \
+      TH_TENSOR_APPLY3(real, input, real, output, real, noise,
+        if (*input_data <= 0)
+        {
+          const real r = (real)THRandom_uniform(generator, lower, upper);
+          *output_data = (*input_data) * r;
+          *noise_data = r;
+        }
+        else
+        {
           *output_data = *input_data;
-          *noise_data = 1; \
+          *noise_data = 1;
         }
       );
     }
@@ -55,9 +47,10 @@
     const real negSlope = (lower + upper) / 2;
     if (inplace)
     {
-      TH_TENSOR_APPLY(real, input, \
-        if (*input_data <= 0) { \
-          *input_data = *input_data * negSlope; \
+      TH_TENSOR_APPLY(real, input,
+        if (*input_data <= 0)
+        {
+          *input_data = *input_data * negSlope;
         }
       );
       THTensor_(set)(output, input);
@@ -65,26 +58,16 @@
     else
     {
       THTensor_(resizeAs)(output, input);
-      TH_TENSOR_APPLY2(real, input, real, output, \
-        const real r = (*input_data) <= 0 ? negSlope : 1; \
-        *output_data = *input_data * r; \
+      TH_TENSOR_APPLY2(real, input, real, output,
+        const real r = (*input_data) <= 0 ? negSlope : 1;
+        *output_data = *input_data * r;
       );
     }
   }  
-  return 1;
 }
 
-static int nn_(RReLU_updateGradInput)(lua_State *L)
+void THNN_(RReLU_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *noise, real lower, real upper, bool train, bool inplace)
 {
-  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
-  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
-  THTensor *noise = luaT_getfieldcheckudata(L, 1, "noise", torch_Tensor);
-  real lower = luaT_getfieldchecknumber(L, 1, "lower");
-  real upper = luaT_getfieldchecknumber(L, 1, "upper");
-  int train = luaT_getfieldcheckboolean(L, 1, "train");
-  int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
-  
   if (train && upper - lower > 1E-6)    // e.g. if upper == lower, RReLU behaves like LeakyReLU
   {
     // multiply the gradient by the noise tensor
@@ -105,35 +88,22 @@
     const real negSlope = (lower + upper) / 2;
     if (inplace)
     {
-      TH_TENSOR_APPLY2(real, gradOutput, real, input, \
-        if (*input_data <= 0) { \
-         *gradOutput_data = (*gradOutput_data) * negSlope; \
-        } \
+      TH_TENSOR_APPLY2(real, gradOutput, real, input,
+        if (*input_data <= 0)
+        {
+          *gradOutput_data = (*gradOutput_data) * negSlope;
+        }
       );
       THTensor_(set)(gradInput, gradOutput);
     }
     else
     {
       THTensor_(resizeAs)(gradInput, input);
-      TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \
-        *gradInput_data = (*input_data) <= 0 ? (*gradOutput_data) * negSlope : (*gradOutput_data); \
+      TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
+        *gradInput_data = (*input_data) <= 0 ? (*gradOutput_data) * negSlope : (*gradOutput_data);
       );
     }
   }
-  return 1;
-}
-
-static const struct luaL_Reg nn_(RReLU__) [] = {
-  { "RReLU_updateOutput", nn_(RReLU_updateOutput) },
-  { "RReLU_updateGradInput", nn_(RReLU_updateGradInput) },
-  { NULL, NULL }
-};
-
-static void nn_(RReLU_init)(lua_State *L)
-{
-  luaT_pushmetatable(L, torch_Tensor);
-  luaT_registeratname(L, nn_(RReLU__), "nn");
-  lua_pop(L, 1);
 }
 
 #endif
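
Note on the RReLU kernels above: in training mode every non-positive
element is scaled by a factor r drawn uniformly from [lower, upper],
and r is stored in the noise tensor so updateGradInput can apply the
same factor; in evaluation mode the deterministic slope
(lower + upper) / 2 is used, so upper == lower reduces to LeakyReLU.
A standalone scalar sketch of the forward rule (plain double and
rand() stand in for real and THRandom_uniform; illustration only):

    #include <stdlib.h>

    /* Uniform sample in [lo, hi] -- a stand-in for THRandom_uniform. */
    static double uniform_sample(double lo, double hi)
    {
      return lo + (hi - lo) * ((double)rand() / (double)RAND_MAX);
    }

    /* One element of the RReLU forward pass; *noise records the slope
     * applied to this element so the backward pass can reuse it. */
    static double rrelu_forward(double x, double lower, double upper,
                                int train, double *noise)
    {
      if (x > 0)
      {
        *noise = 1;
        return x;
      }
      *noise = train ? uniform_sample(lower, upper)
                     : (lower + upper) / 2;
      return x * (*noise);
    }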
diff --git a/generic/Sigmoid.c b/generic/Sigmoid.c
index 057ebc4..f58d33b 100644
--- a/generic/Sigmoid.c
+++ b/generic/Sigmoid.c
@@ -2,43 +2,22 @@
 #define TH_GENERIC_FILE "generic/Sigmoid.c"
 #else
 
-static int nn_(Sigmoid_updateOutput)(lua_State *L)
+void THNN_(Sigmoid_updateOutput)(THNNState *state, THTensor *input, THTensor *output)
 {
-  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-
   THTensor_(resizeAs)(output, input);
 
-  TH_TENSOR_APPLY2(real, output, real, input, \
-                   *output_data = 1./(1.+ exp(- *input_data));)
-
-  return 1;
+  TH_TENSOR_APPLY2(real, output, real, input,
+    *output_data = 1./(1.+ exp(- *input_data));
+  );
 }
 
-static int nn_(Sigmoid_updateGradInput)(lua_State *L)
+void THNN_(Sigmoid_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *output)
 {
-  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
-  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
-
   THTensor_(resizeAs)(gradInput, output);
-  TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, \
-                   real z = *output_data; \
-                   *gradInput_data = *gradOutput_data * (1. - z) * z;)
-  return 1;
-}
-
-static const struct luaL_Reg nn_(Sigmoid__) [] = {
-  {"Sigmoid_updateOutput", nn_(Sigmoid_updateOutput)},
-  {"Sigmoid_updateGradInput", nn_(Sigmoid_updateGradInput)},
-  {NULL, NULL}
-};
-
-static void nn_(Sigmoid_init)(lua_State *L)
-{
-  luaT_pushmetatable(L, torch_Tensor);
-  luaT_registeratname(L, nn_(Sigmoid__), "nn");
-  lua_pop(L,1);
+  TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,
+    real z = *output_data;
+    *gradInput_data = *gradOutput_data * (1. - z) * z;
+  );
 }
 
 #endif
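
The Sigmoid backward above exploits the identity s'(x) = y * (1 - y)
with y = s(x): it reads the saved output tensor rather than paying for
another exp(). A scalar sketch of the pair (illustration only):

    #include <math.h>

    static double sigmoid(double x)
    {
      return 1. / (1. + exp(-x));
    }

    /* Gradient w.r.t. the input, expressed through the saved output y. */
    static double sigmoid_grad_input(double grad_output, double y)
    {
      return grad_output * (1. - y) * y;
    }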
diff --git a/generic/SmoothL1Criterion.c b/generic/SmoothL1Criterion.c
index 51cab0c..3111b3d 100644
--- a/generic/SmoothL1Criterion.c
+++ b/generic/SmoothL1Criterion.c
@@ -2,59 +2,34 @@
 #define TH_GENERIC_FILE "generic/SmoothL1Criterion.c"
 #else
 
-static int nn_(SmoothL1Criterion_updateOutput)(lua_State *L)
+void THNN_(SmoothL1Criterion_updateOutput)(THNNState *state, THTensor *input, THTensor *target, THTensor *output, bool sizeAverage)
 {
-  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
-  int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
-  real sum;
-
-  sum = 0;
+  real sum = 0;
   TH_TENSOR_APPLY2(real, input, real, target,
-                   real z = fabs(*input_data - *target_data);
-                   sum += z < 1 ? 0.5*z*z : z - 0.5;)
+    real z = fabs(*input_data - *target_data);
+    sum += z < 1 ? 0.5*z*z : z - 0.5;
+  );
 
-  if(sizeAverage)
+  if (sizeAverage)
     sum /= THTensor_(nElement)(input);
 
-  lua_pushnumber(L, sum);
-  lua_setfield(L, 1, "output");
-
-  lua_pushnumber(L, sum);
-  return 1;
+  THTensor_(set1d)(output, 0, sum);
 }
 
-static int nn_(SmoothL1Criterion_updateGradInput)(lua_State *L)
+void THNN_(SmoothL1Criterion_updateGradInput)(THNNState *state, THTensor *input, THTensor *target, THTensor *gradInput, bool sizeAverage)
 {
-  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
-  int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
-  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
   real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
 
   THTensor_(resizeAs)(gradInput, input);
   TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
-                   real x = *input_data - *target_data;
-                   if(x < -1.)
-                     *gradInput_data = - norm;
-                   else if(x > 1.)
-                     *gradInput_data = norm;
-                   else
-                     *gradInput_data = norm * x;)
-  return 1;
-}
-
-static const struct luaL_Reg nn_(SmoothL1Criterion__) [] = {
-  {"SmoothL1Criterion_updateOutput", nn_(SmoothL1Criterion_updateOutput)},
-  {"SmoothL1Criterion_updateGradInput", nn_(SmoothL1Criterion_updateGradInput)},
-  {NULL, NULL}
-};
-
-static void nn_(SmoothL1Criterion_init)(lua_State *L)
-{
-  luaT_pushmetatable(L, torch_Tensor);
-  luaT_registeratname(L, nn_(SmoothL1Criterion__), "nn");
-  lua_pop(L,1);
+    real x = *input_data - *target_data;
+    if (x < -1.)
+      *gradInput_data = - norm;
+    else if (x > 1.)
+      *gradInput_data = norm;
+    else
+      *gradInput_data = norm * x;
+  );
 }
 
 #endif
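
SmoothL1 is the Huber loss with delta = 1: quadratic for residuals
below 1, linear above, matched in value and slope at the joint, which
is why the gradient is simply the residual clamped to [-1, 1] (scaled
by 1/nElement when sizeAverage is set). Per-element sketch
(illustration only):

    #include <math.h>

    /* 0.5 * z^2 for small residuals, z - 0.5 for large ones. */
    static double smooth_l1(double input, double target)
    {
      double z = fabs(input - target);
      return z < 1 ? 0.5 * z * z : z - 0.5;
    }

    /* d/dinput: the residual clamped to [-1, 1]. */
    static double smooth_l1_grad(double input, double target)
    {
      double x = input - target;
      if (x < -1.) return -1.;
      if (x >  1.) return  1.;
      return x;
    }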
diff --git a/generic/SoftMax.c b/generic/SoftMax.c
index 0201aaf..598d35e 100644
--- a/generic/SoftMax.c
+++ b/generic/SoftMax.c
@@ -2,40 +2,40 @@
 #define TH_GENERIC_FILE "generic/SoftMax.c"
 #else
 
-static int nn_(SoftMax_updateOutput)(lua_State *L)
+void THNN_(SoftMax_updateOutput)(THNNState *state, THTensor *input, THTensor *output)
 {
-  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
   real *input_data, *output_data;
   long nframe = 0, dim = 0, stride = 0;
   long t;
 
-  if(input->nDimension == 1)
+  if (input->nDimension == 1)
   {
     nframe = 1;
     dim = input->size[0];
     stride = 1;
   }
-  else if(input->nDimension == 2)
+  else if (input->nDimension == 2)
   {
     nframe = input->size[0];
     dim = input->size[1];
     stride = 1;
   }
-  else if(input->nDimension == 3)
+  else if (input->nDimension == 3)
   {
     nframe = 1;
     dim = input->size[0];
     stride = input->size[1]*input->size[2];
   }
-  else if(input->nDimension == 4)
+  else if (input->nDimension == 4)
   {
     nframe = input->size[0];
     dim = input->size[1];
     stride = input->size[2]*input->size[3];
   }
   else
+  {
     THArgCheck(0, 2, "1D, 2D, 3D or 4D tensor expected");
+  }
 
   input = THTensor_(newContiguous)(input);
   THTensor_(resizeAs)(output, input);
@@ -44,7 +44,7 @@
   output_data = THTensor_(data)(output);
 
 #pragma omp parallel for private(t)
-  for(t = 0; t < stride*nframe; t++)
+  for (t = 0; t < stride*nframe; t++)
   {
     real *input_ptr = input_data + (t/stride)*dim*stride + t % stride;
     real *output_ptr = output_data + (t/stride)*dim*stride + t % stride;
@@ -53,62 +53,62 @@
     accreal sum;
 
     long d;
-    for(d = 0; d < dim; d++) {
+    for (d = 0; d < dim; d++)
+    {
       if (input_ptr[d*stride] >= inputMax) inputMax = input_ptr[d*stride];
     }
 
     sum = 0;
-    for(d = 0; d < dim; d++) {
+    for (d = 0; d < dim; d++)
+    {
       real z = THExpMinusApprox(inputMax - input_ptr[d*stride]);
       output_ptr[d*stride] = z;
       sum += z;
     }
 
-    for(d = 0; d < dim; d++) {
+    for (d = 0; d < dim; d++)
+    {
       output_ptr[d*stride] *= 1/sum;
     }
   }
 
   THTensor_(free)(input);
-
-  return 1;
 }
 
-static int nn_(SoftMax_updateGradInput)(lua_State *L)
+void THNN_(SoftMax_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *output)
 {
-  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
-  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
   real *gradInput_data, *gradOutput_data, *output_data;
   long nframe = 0, dim = 0, stride = 0;
   long t;
 
-  if(output->nDimension == 1)
+  if (output->nDimension == 1)
   {
     nframe = 1;
     dim = output->size[0];
     stride = 1;
   }
-  else if(output->nDimension == 2)
+  else if (output->nDimension == 2)
   {
     nframe = output->size[0];
     dim = output->size[1];
     stride = 1;
   }
-  else if(output->nDimension == 3)
+  else if (output->nDimension == 3)
   {
     nframe = 1;
     dim = output->size[0];
     stride = output->size[1]*output->size[2];
   }
-  else if(output->nDimension == 4)
+  else if (output->nDimension == 4)
   {
     nframe = output->size[0];
     dim = output->size[1];
     stride = output->size[2]*output->size[3];
   }
   else
+  {
     THError("1D, 2D, 3D or 4D tensor expected");
+  }
 
   gradOutput = THTensor_(newContiguous)(gradOutput);
   output = THTensor_(newContiguous)(output);
@@ -119,7 +119,7 @@
   gradOutput_data = THTensor_(data)(gradOutput);
 
 #pragma omp parallel for private(t)
-  for(t = 0; t < stride*nframe; t++)
+  for (t = 0; t < stride*nframe; t++)
   {
     real *gradInput_ptr = gradInput_data + (t/stride)*dim*stride + t % stride;
     real *output_ptr = output_data + (t/stride)*dim*stride + t % stride;
@@ -127,30 +127,15 @@
 
     long d;
     accreal sum = 0;
-    for(d = 0; d < dim; d++)
+    for (d = 0; d < dim; d++)
       sum += (accreal)gradOutput_ptr[d*stride] * output_ptr[d*stride];
 
-    for(d = 0; d < dim; d++)
+    for (d = 0; d < dim; d++)
       gradInput_ptr[d*stride] = output_ptr[d*stride] * (gradOutput_ptr[d*stride] - sum);
   }
 
   THTensor_(free)(gradOutput);
   THTensor_(free)(output);
-
-  return 1;
-}
-
-static const struct luaL_Reg nn_(SoftMax__) [] = {
-  {"SoftMax_updateOutput", nn_(SoftMax_updateOutput)},
-  {"SoftMax_updateGradInput", nn_(SoftMax_updateGradInput)},
-  {NULL, NULL}
-};
-
-static void nn_(SoftMax_init)(lua_State *L)
-{
-  luaT_pushmetatable(L, torch_Tensor);
-  luaT_registeratname(L, nn_(SoftMax__), "nn");
-  lua_pop(L,1);
 }
 
 #endif
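
The SoftMax kernels flatten any 1D-4D input into (nframe, dim, stride)
slices and normalize along dim. Subtracting the per-slice maximum
before exponentiating keeps exp() in range without changing the
result, since the shift cancels in the ratio (THExpMinusApprox(m - x)
approximates exp(x - m)). A 1-D sketch of the same scheme
(illustration only):

    #include <math.h>

    static void softmax_1d(const double *x, double *y, long n)
    {
      double max = x[0], sum = 0;
      long i;
      for (i = 1; i < n; i++)
        if (x[i] > max) max = x[i];   /* shift for stability */
      for (i = 0; i < n; i++)
      {
        y[i] = exp(x[i] - max);
        sum += y[i];
      }
      for (i = 0; i < n; i++)
        y[i] /= sum;                  /* normalize */
    }

The backward pass uses dL/dx_i = y_i * (g_i - sum_j g_j * y_j), which
is exactly the two inner loops of updateGradInput above.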
diff --git a/generic/SoftPlus.c b/generic/SoftPlus.c
index 81f2a7c..76c9c1c 100644
--- a/generic/SoftPlus.c
+++ b/generic/SoftPlus.c
@@ -2,55 +2,29 @@
 #define TH_GENERIC_FILE "generic/SoftPlus.c"
 #else
 
-static int nn_(SoftPlus_updateOutput)(lua_State *L)
+void THNN_(SoftPlus_updateOutput)(THNNState *state, THTensor *input, THTensor *output, real beta, real threshold)
 {
-  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
-  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-  real beta = luaT_getfieldchecknumber(L, 1, "beta");
-  real threshold = luaT_getfieldchecknumber(L, 1, "threshold");
-
   THTensor_(resizeAs)(output, input);
 
-  /* f(x) = 1/beta * log(1 + exp(beta * x)) */
-
+  // f(x) = 1/beta * log(1 + exp(beta * x))
   TH_TENSOR_APPLY2(real, output, real, input,               \
-    *output_data = (*input_data * beta) > threshold ? *input_data : THLog1p(exp(*input_data * beta)) / beta;)
-    
-    return 1;
+    *output_data = (*input_data * beta) > threshold ? *input_data : THLog1p(exp(*input_data * beta)) / beta;
+  );
 }
 
-static int nn_(SoftPlus_updateGradInput)(lua_State *L)
+void THNN_(SoftPlus_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *output, real beta, real threshold)
 {
-  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
-  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
-  real beta = luaT_getfieldchecknumber(L, 1, "beta");
-  real threshold = luaT_getfieldchecknumber(L, 1, "threshold");
-
-  /* d/dx[log(1+exp(k*x))/k] = exp(kx) / (exp(kx) + 1)
-     SINCE
-     y = (1/k)*log(1+exp(k*x)) --> x = (1/k)*log(exp(k*y)-1)
-     THEREFORE:
-     d/dx(f(x)) = (exp(k*y) - 1) / exp(k*y) */
-
   THTensor_(resizeAs)(gradInput, output);
-  TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,    \
-                   real z = exp(*output_data * beta);                  \
-                   *gradInput_data = (*output_data * beta) > threshold ? *gradOutput_data : *gradOutput_data * (z - 1.)/z;)
-    return 1;
-}
-
-static const struct luaL_Reg nn_(SoftPlus__) [] = {
-  {"SoftPlus_updateOutput", nn_(SoftPlus_updateOutput)},
-  {"SoftPlus_updateGradInput", nn_(SoftPlus_updateGradInput)},
-  {NULL, NULL}
-};
-
-static void nn_(SoftPlus_init)(lua_State *L)
-{
-  luaT_pushmetatable(L, torch_Tensor);
-  luaT_registeratname(L, nn_(SoftPlus__), "nn");
-  lua_pop(L,1);
+
+  // d/dx[log(1+exp(k*x))/k] = exp(kx) / (exp(kx) + 1)
+  // SINCE
+  // y = (1/k)*log(1+exp(k*x)) --> x = (1/k)*log(exp(k*y)-1)
+  // THEREFORE:
+  // d/dx(f(x)) = (exp(k*y) - 1) / exp(k*y)
+  TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,
+    real z = exp(*output_data * beta);
+    *gradInput_data = (*output_data * beta) > threshold ? *gradOutput_data : *gradOutput_data * (z - 1.)/z;
+  );
 }
 
 #endif
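
The threshold parameter in SoftPlus is an overflow guard: once
beta * x exceeds it, log1p(exp(beta * x)) / beta is numerically equal
to x, so the kernel passes x through and never evaluates exp() on a
large argument. The backward likewise rewrites the derivative
sigmoid(beta * x) in terms of the saved output y as
(exp(beta * y) - 1) / exp(beta * y). Scalar sketch of the forward
cutoff (illustration only):

    #include <math.h>

    /* f(x) = log(1 + exp(beta * x)) / beta, linear past the cutoff. */
    static double softplus(double x, double beta, double threshold)
    {
      return (x * beta) > threshold ? x
                                    : log1p(exp(x * beta)) / beta;
    }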
diff --git a/generic/THNN.h b/generic/THNN.h
index 871ca97..f3d2ce8 100644
--- a/generic/THNN.h
+++ b/generic/THNN.h
@@ -234,6 +234,77 @@
           THIndex_t nOutputPlane,
           real scale);
 
+TH_API void THNN_(RReLU_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          THTensor *noise,
+          real lower,
+          real upper,
+          bool train,
+          bool inplace,
+          THGenerator *generator);
+TH_API void THNN_(RReLU_updateGradInput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          THTensor *noise,
+          real lower,
+          real upper,
+          bool train,
+          bool inplace);
+
+TH_API void THNN_(Sigmoid_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output);
+TH_API void THNN_(Sigmoid_updateGradInput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          THTensor *output);
+
+TH_API void THNN_(SmoothL1Criterion_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *target,
+          THTensor *output,
+          bool sizeAverage);
+TH_API void THNN_(SmoothL1Criterion_updateGradInput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *target,
+          THTensor *gradInput,
+          bool sizeAverage);
+
+TH_API void THNN_(SoftMax_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output);
+TH_API void THNN_(SoftMax_updateGradInput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          THTensor *output);
+
+TH_API void THNN_(SoftPlus_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          real beta,
+          real threshold);
+TH_API void THNN_(SoftPlus_updateGradInput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          THTensor *output,
+          real beta,
+          real threshold);
+
 TH_API void THNN_(SpatialConvolutionMM_updateOutput)(
           THNNState *state,
           THTensor *input,
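
For completeness, a hypothetical caller of one of the new entry points,
using the float instantiation (this assumes the usual THNN_(NAME) ->
THNN_Float##NAME expansion and that these CPU kernels ignore the
THNNState argument, so NULL is passed; not part of the patch):

    #include <stdio.h>
    #include "THNN.h"

    int main(void)
    {
      THFloatTensor *input  = THFloatTensor_newWithSize1d(3);
      THFloatTensor *output = THFloatTensor_new();
      THFloatTensor_set1d(input, 0, -1.f);
      THFloatTensor_set1d(input, 1,  0.f);
      THFloatTensor_set1d(input, 2,  1.f);

      /* updateOutput resizes output to match input. */
      THNN_FloatSigmoid_updateOutput(NULL, input, output);

      printf("%g %g %g\n",
             THFloatTensor_get1d(output, 0),
             THFloatTensor_get1d(output, 1),
             THFloatTensor_get1d(output, 2));

      THFloatTensor_free(input);
      THFloatTensor_free(output);
      return 0;
    }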