Add THNN conversion of {RReLU, Sigmoid, SmoothL1Criterion,SoftMax, SoftPlus}
diff --git a/generic/MultiMarginCriterion.c b/generic/MultiMarginCriterion.c
index 9cb1686..6445bb0 100644
--- a/generic/MultiMarginCriterion.c
+++ b/generic/MultiMarginCriterion.c
@@ -2,7 +2,7 @@
#define TH_GENERIC_FILE "generic/MultiMarginCriterion.c"
#else
-void THNN_(MultiMarginCriterion_updateOutput)(THNNState *state, THTensor *input, THTensor *target, THTensor* output, bool sizeAverage, int p)
+void THNN_(MultiMarginCriterion_updateOutput)(THNNState *state, THTensor *input, THTensor *target, THTensor *output, bool sizeAverage, int p)
{
real *input_data, *target_data;
long nframe, dim;
diff --git a/generic/RReLU.c b/generic/RReLU.c
index 19b92b6..74c5df5 100644
--- a/generic/RReLU.c
+++ b/generic/RReLU.c
@@ -2,34 +2,24 @@
#define TH_GENERIC_FILE "generic/RReLU.c"
#else
-static int nn_(RReLU_updateOutput)(lua_State *L)
+void THNN_(RReLU_updateOutput)(THNNState *state, THTensor *input, THTensor *output, THTensor *noise, real lower, real upper, bool train, bool inplace, THGenerator *generator)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- THTensor *noise = luaT_getfieldcheckudata(L, 1, "noise", torch_Tensor);
- real lower = luaT_getfieldchecknumber(L, 1, "lower");
- real upper = luaT_getfieldchecknumber(L, 1, "upper");
- int train = luaT_getfieldcheckboolean(L, 1, "train");
- int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
-
if (train)
{
// get default random generator
- lua_getglobal(L, "torch");
- THGenerator *generator = luaT_getfieldcheckudata(L, -1, "_gen", torch_Generator);
- lua_pop(L, 2);
-
THTensor_(resizeAs)(noise, input);
if (inplace)
{
- TH_TENSOR_APPLY2(real, input, real, noise, \
- if (*input_data <= 0) { \
- const real r = (real)THRandom_uniform(generator, lower, upper); \
- *input_data = (*input_data) * r; \
- *noise_data = r; \
- } \
- else { \
- *noise_data = 1; \
+ TH_TENSOR_APPLY2(real, input, real, noise,
+ if (*input_data <= 0)
+ {
+ const real r = (real)THRandom_uniform(generator, lower, upper);
+ *input_data = (*input_data) * r;
+ *noise_data = r;
+ }
+ else
+ {
+ *noise_data = 1;
}
);
THTensor_(set)(output, input);
@@ -37,15 +27,17 @@
else
{
THTensor_(resizeAs)(output, input);
- TH_TENSOR_APPLY3(real, input, real, output, real, noise, \
- if (*input_data <= 0) { \
- const real r = (real)THRandom_uniform(generator, lower, upper); \
- *output_data = (*input_data) * r; \
- *noise_data = r; \
- } \
- else { \
+ TH_TENSOR_APPLY3(real, input, real, output, real, noise,
+ if (*input_data <= 0)
+ {
+ const real r = (real)THRandom_uniform(generator, lower, upper);
+ *output_data = (*input_data) * r;
+ *noise_data = r;
+ }
+ else
+ {
*output_data = *input_data;
- *noise_data = 1; \
+ *noise_data = 1;
}
);
}
@@ -55,9 +47,10 @@
const real negSlope = (lower + upper) / 2;
if (inplace)
{
- TH_TENSOR_APPLY(real, input, \
- if (*input_data <= 0) { \
- *input_data = *input_data * negSlope; \
+ TH_TENSOR_APPLY(real, input,
+ if (*input_data <= 0)
+ {
+ *input_data = *input_data * negSlope;
}
);
THTensor_(set)(output, input);
@@ -65,26 +58,16 @@
else
{
THTensor_(resizeAs)(output, input);
- TH_TENSOR_APPLY2(real, input, real, output, \
- const real r = (*input_data) <= 0 ? negSlope : 1; \
- *output_data = *input_data * r; \
+ TH_TENSOR_APPLY2(real, input, real, output,
+ const real r = (*input_data) <= 0 ? negSlope : 1;
+ *output_data = *input_data * r;
);
}
}
- return 1;
}
-static int nn_(RReLU_updateGradInput)(lua_State *L)
+void THNN_(RReLU_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *noise, real lower, real upper, bool train, bool inplace)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
- THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
- THTensor *noise = luaT_getfieldcheckudata(L, 1, "noise", torch_Tensor);
- real lower = luaT_getfieldchecknumber(L, 1, "lower");
- real upper = luaT_getfieldchecknumber(L, 1, "upper");
- int train = luaT_getfieldcheckboolean(L, 1, "train");
- int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
-
if (train && upper - lower > 1E-6) // e.g. if upper == lower, RReLU behaves like LeakyReLU
{
// multiply the gradient by the noise tensor
@@ -105,35 +88,22 @@
const real negSlope = (lower + upper) / 2;
if (inplace)
{
- TH_TENSOR_APPLY2(real, gradOutput, real, input, \
- if (*input_data <= 0) { \
- *gradOutput_data = (*gradOutput_data) * negSlope; \
- } \
+ TH_TENSOR_APPLY2(real, gradOutput, real, input,
+ if (*input_data <= 0)
+ {
+ *gradOutput_data = (*gradOutput_data) * negSlope;
+ }
);
THTensor_(set)(gradInput, gradOutput);
}
else
{
THTensor_(resizeAs)(gradInput, input);
- TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \
- *gradInput_data = (*input_data) <= 0 ? (*gradOutput_data) * negSlope : (*gradOutput_data); \
+ TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
+ *gradInput_data = (*input_data) <= 0 ? (*gradOutput_data) * negSlope : (*gradOutput_data);
);
}
}
- return 1;
-}
-
-static const struct luaL_Reg nn_(RReLU__) [] = {
- { "RReLU_updateOutput", nn_(RReLU_updateOutput) },
- { "RReLU_updateGradInput", nn_(RReLU_updateGradInput) },
- { NULL, NULL }
-};
-
-static void nn_(RReLU_init)(lua_State *L)
-{
- luaT_pushmetatable(L, torch_Tensor);
- luaT_registeratname(L, nn_(RReLU__), "nn");
- lua_pop(L, 1);
}
#endif
diff --git a/generic/Sigmoid.c b/generic/Sigmoid.c
index 057ebc4..f58d33b 100644
--- a/generic/Sigmoid.c
+++ b/generic/Sigmoid.c
@@ -2,43 +2,22 @@
#define TH_GENERIC_FILE "generic/Sigmoid.c"
#else
-static int nn_(Sigmoid_updateOutput)(lua_State *L)
+void THNN_(Sigmoid_updateOutput)(THNNState *state, THTensor *input, THTensor *output)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-
THTensor_(resizeAs)(output, input);
- TH_TENSOR_APPLY2(real, output, real, input, \
- *output_data = 1./(1.+ exp(- *input_data));)
-
- return 1;
+ TH_TENSOR_APPLY2(real, output, real, input,
+ *output_data = 1./(1.+ exp(- *input_data));
+ );
}
-static int nn_(Sigmoid_updateGradInput)(lua_State *L)
+void THNN_(Sigmoid_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *output)
{
- THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
-
THTensor_(resizeAs)(gradInput, output);
- TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, \
- real z = *output_data; \
- *gradInput_data = *gradOutput_data * (1. - z) * z;)
- return 1;
-}
-
-static const struct luaL_Reg nn_(Sigmoid__) [] = {
- {"Sigmoid_updateOutput", nn_(Sigmoid_updateOutput)},
- {"Sigmoid_updateGradInput", nn_(Sigmoid_updateGradInput)},
- {NULL, NULL}
-};
-
-static void nn_(Sigmoid_init)(lua_State *L)
-{
- luaT_pushmetatable(L, torch_Tensor);
- luaT_registeratname(L, nn_(Sigmoid__), "nn");
- lua_pop(L,1);
+ TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,
+ real z = *output_data;
+ *gradInput_data = *gradOutput_data * (1. - z) * z;
+ );
}
#endif
diff --git a/generic/SmoothL1Criterion.c b/generic/SmoothL1Criterion.c
index 51cab0c..3111b3d 100644
--- a/generic/SmoothL1Criterion.c
+++ b/generic/SmoothL1Criterion.c
@@ -2,59 +2,34 @@
#define TH_GENERIC_FILE "generic/SmoothL1Criterion.c"
#else
-static int nn_(SmoothL1Criterion_updateOutput)(lua_State *L)
+void THNN_(SmoothL1Criterion_updateOutput)(THNNState *state, THTensor *input, THTensor *target, THTensor *output, bool sizeAverage)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
- int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
- real sum;
-
- sum = 0;
+ real sum = 0;
TH_TENSOR_APPLY2(real, input, real, target,
- real z = fabs(*input_data - *target_data);
- sum += z < 1 ? 0.5*z*z : z - 0.5;)
+ real z = fabs(*input_data - *target_data);
+ sum += z < 1 ? 0.5*z*z : z - 0.5;
+ );
- if(sizeAverage)
+ if (sizeAverage)
sum /= THTensor_(nElement)(input);
- lua_pushnumber(L, sum);
- lua_setfield(L, 1, "output");
-
- lua_pushnumber(L, sum);
- return 1;
+ THTensor_(set1d)(output, 0, sum);
}
-static int nn_(SmoothL1Criterion_updateGradInput)(lua_State *L)
+void THNN_(SmoothL1Criterion_updateGradInput)(THNNState *state, THTensor *input, THTensor *target, THTensor *gradInput, bool sizeAverage)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
- int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
- THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
THTensor_(resizeAs)(gradInput, input);
TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
- real x = *input_data - *target_data;
- if(x < -1.)
- *gradInput_data = - norm;
- else if(x > 1.)
- *gradInput_data = norm;
- else
- *gradInput_data = norm * x;)
- return 1;
-}
-
-static const struct luaL_Reg nn_(SmoothL1Criterion__) [] = {
- {"SmoothL1Criterion_updateOutput", nn_(SmoothL1Criterion_updateOutput)},
- {"SmoothL1Criterion_updateGradInput", nn_(SmoothL1Criterion_updateGradInput)},
- {NULL, NULL}
-};
-
-static void nn_(SmoothL1Criterion_init)(lua_State *L)
-{
- luaT_pushmetatable(L, torch_Tensor);
- luaT_registeratname(L, nn_(SmoothL1Criterion__), "nn");
- lua_pop(L,1);
+ real x = *input_data - *target_data;
+ if (x < -1.)
+ *gradInput_data = - norm;
+ else if (x > 1.)
+ *gradInput_data = norm;
+ else
+ *gradInput_data = norm * x;
+ );
}
#endif
diff --git a/generic/SoftMax.c b/generic/SoftMax.c
index 0201aaf..598d35e 100644
--- a/generic/SoftMax.c
+++ b/generic/SoftMax.c
@@ -2,40 +2,40 @@
#define TH_GENERIC_FILE "generic/SoftMax.c"
#else
-static int nn_(SoftMax_updateOutput)(lua_State *L)
+void THNN_(SoftMax_updateOutput)(THNNState *state, THTensor *input, THTensor *output)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
real *input_data, *output_data;
long nframe = 0, dim = 0, stride = 0;
long t;
- if(input->nDimension == 1)
+ if (input->nDimension == 1)
{
nframe = 1;
dim = input->size[0];
stride = 1;
}
- else if(input->nDimension == 2)
+ else if (input->nDimension == 2)
{
nframe = input->size[0];
dim = input->size[1];
stride = 1;
}
- else if(input->nDimension == 3)
+ else if (input->nDimension == 3)
{
nframe = 1;
dim = input->size[0];
stride = input->size[1]*input->size[2];
}
- else if(input->nDimension == 4)
+ else if (input->nDimension == 4)
{
nframe = input->size[0];
dim = input->size[1];
stride = input->size[2]*input->size[3];
}
else
+ {
THArgCheck(0, 2, "1D, 2D, 3D or 4D tensor expected");
+ }
input = THTensor_(newContiguous)(input);
THTensor_(resizeAs)(output, input);
@@ -44,7 +44,7 @@
output_data = THTensor_(data)(output);
#pragma omp parallel for private(t)
- for(t = 0; t < stride*nframe; t++)
+ for (t = 0; t < stride*nframe; t++)
{
real *input_ptr = input_data + (t/stride)*dim*stride + t % stride;
real *output_ptr = output_data + (t/stride)*dim*stride + t % stride;
@@ -53,62 +53,62 @@
accreal sum;
long d;
- for(d = 0; d < dim; d++) {
+ for (d = 0; d < dim; d++)
+ {
if (input_ptr[d*stride] >= inputMax) inputMax = input_ptr[d*stride];
}
sum = 0;
- for(d = 0; d < dim; d++) {
+ for (d = 0; d < dim; d++)
+ {
real z = THExpMinusApprox(inputMax - input_ptr[d*stride]);
output_ptr[d*stride] = z;
sum += z;
}
- for(d = 0; d < dim; d++) {
+ for (d = 0; d < dim; d++)
+ {
output_ptr[d*stride] *= 1/sum;
}
}
THTensor_(free)(input);
-
- return 1;
}
-static int nn_(SoftMax_updateGradInput)(lua_State *L)
+void THNN_(SoftMax_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *output)
{
- THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
real *gradInput_data, *gradOutput_data, *output_data;
long nframe = 0, dim = 0, stride = 0;
long t;
- if(output->nDimension == 1)
+ if (output->nDimension == 1)
{
nframe = 1;
dim = output->size[0];
stride = 1;
}
- else if(output->nDimension == 2)
+ else if (output->nDimension == 2)
{
nframe = output->size[0];
dim = output->size[1];
stride = 1;
}
- else if(output->nDimension == 3)
+ else if (output->nDimension == 3)
{
nframe = 1;
dim = output->size[0];
stride = output->size[1]*output->size[2];
}
- else if(output->nDimension == 4)
+ else if (output->nDimension == 4)
{
nframe = output->size[0];
dim = output->size[1];
stride = output->size[2]*output->size[3];
}
else
+ {
THError("1D, 2D, 3D or 4D tensor expected");
+ }
gradOutput = THTensor_(newContiguous)(gradOutput);
output = THTensor_(newContiguous)(output);
@@ -119,7 +119,7 @@
gradOutput_data = THTensor_(data)(gradOutput);
#pragma omp parallel for private(t)
- for(t = 0; t < stride*nframe; t++)
+ for (t = 0; t < stride*nframe; t++)
{
real *gradInput_ptr = gradInput_data + (t/stride)*dim*stride + t % stride;
real *output_ptr = output_data + (t/stride)*dim*stride + t % stride;
@@ -127,30 +127,15 @@
long d;
accreal sum = 0;
- for(d = 0; d < dim; d++)
+ for (d = 0; d < dim; d++)
sum += (accreal)gradOutput_ptr[d*stride] * output_ptr[d*stride];
- for(d = 0; d < dim; d++)
+ for (d = 0; d < dim; d++)
gradInput_ptr[d*stride] = output_ptr[d*stride] * (gradOutput_ptr[d*stride] - sum);
}
THTensor_(free)(gradOutput);
THTensor_(free)(output);
-
- return 1;
-}
-
-static const struct luaL_Reg nn_(SoftMax__) [] = {
- {"SoftMax_updateOutput", nn_(SoftMax_updateOutput)},
- {"SoftMax_updateGradInput", nn_(SoftMax_updateGradInput)},
- {NULL, NULL}
-};
-
-static void nn_(SoftMax_init)(lua_State *L)
-{
- luaT_pushmetatable(L, torch_Tensor);
- luaT_registeratname(L, nn_(SoftMax__), "nn");
- lua_pop(L,1);
}
#endif
diff --git a/generic/SoftPlus.c b/generic/SoftPlus.c
index 81f2a7c..76c9c1c 100644
--- a/generic/SoftPlus.c
+++ b/generic/SoftPlus.c
@@ -2,55 +2,29 @@
#define TH_GENERIC_FILE "generic/SoftPlus.c"
#else
-static int nn_(SoftPlus_updateOutput)(lua_State *L)
+void THNN_(SoftPlus_updateOutput)(THNNState *state, THTensor *input, THTensor *output, real beta, real threshold)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- real beta = luaT_getfieldchecknumber(L, 1, "beta");
- real threshold = luaT_getfieldchecknumber(L, 1, "threshold");
-
THTensor_(resizeAs)(output, input);
- /* f(x) = 1/beta * log(1 + exp(beta * x)) */
-
+ // f(x) = 1/beta * log(1 + exp(beta * x))
TH_TENSOR_APPLY2(real, output, real, input, \
- *output_data = (*input_data * beta) > threshold ? *input_data : THLog1p(exp(*input_data * beta)) / beta;)
-
- return 1;
+ *output_data = (*input_data * beta) > threshold ? *input_data : THLog1p(exp(*input_data * beta)) / beta;
+ );
}
-static int nn_(SoftPlus_updateGradInput)(lua_State *L)
+void THNN_(SoftPlus_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *output, real beta, real threshold)
{
- THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
- THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
- real beta = luaT_getfieldchecknumber(L, 1, "beta");
- real threshold = luaT_getfieldchecknumber(L, 1, "threshold");
-
- /* d/dx[log(1+exp(k*x))/k] = exp(kx) / (exp(kx) + 1)
- SINCE
- y = (1/k)*log(1+exp(k*x)) --> x = (1/k)*log(exp(k*y)-1)
- THEREFORE:
- d/dx(f(x)) = (exp(k*y) - 1) / exp(k*y) */
-
THTensor_(resizeAs)(gradInput, output);
- TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, \
- real z = exp(*output_data * beta); \
- *gradInput_data = (*output_data * beta) > threshold ? *gradOutput_data : *gradOutput_data * (z - 1.)/z;)
- return 1;
-}
-
-static const struct luaL_Reg nn_(SoftPlus__) [] = {
- {"SoftPlus_updateOutput", nn_(SoftPlus_updateOutput)},
- {"SoftPlus_updateGradInput", nn_(SoftPlus_updateGradInput)},
- {NULL, NULL}
-};
-
-static void nn_(SoftPlus_init)(lua_State *L)
-{
- luaT_pushmetatable(L, torch_Tensor);
- luaT_registeratname(L, nn_(SoftPlus__), "nn");
- lua_pop(L,1);
+
+ // d/dx[log(1+exp(k*x))/k] = exp(kx) / (exp(kx) + 1)
+ // SINCE
+ // y = (1/k)*log(1+exp(k*x)) --> x = (1/k)*log(exp(k*y)-1)
+ // THEREFORE:
+ // d/dx(f(x)) = (exp(k*y) - 1) / exp(k*y)
+ TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,
+ real z = exp(*output_data * beta);
+ *gradInput_data = (*output_data * beta) > threshold ? *gradOutput_data : *gradOutput_data * (z - 1.)/z;
+ );
}
#endif
diff --git a/generic/THNN.h b/generic/THNN.h
index 871ca97..f3d2ce8 100644
--- a/generic/THNN.h
+++ b/generic/THNN.h
@@ -234,6 +234,77 @@
THIndex_t nOutputPlane,
real scale);
+TH_API void THNN_(RReLU_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ THTensor *noise,
+ real lower,
+ real upper,
+ bool train,
+ bool inplace,
+ THGenerator *generator);
+TH_API void THNN_(RReLU_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *noise,
+ real lower,
+ real upper,
+ bool train,
+ bool inplace);
+
+TH_API void THNN_(Sigmoid_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output);
+TH_API void THNN_(Sigmoid_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *output);
+
+TH_API void THNN_(SmoothL1Criterion_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *output,
+ bool sizeAverage);
+TH_API void THNN_(SmoothL1Criterion_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *gradInput,
+ bool sizeAverage);
+
+TH_API void THNN_(SoftMax_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output);
+TH_API void THNN_(SoftMax_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *output);
+
+TH_API void THNN_(SoftPlus_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ real beta,
+ real threshold);
+TH_API void THNN_(SoftPlus_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *output,
+ real beta,
+ real threshold);
+
TH_API void THNN_(SpatialConvolutionMM_updateOutput)(
THNNState *state,
THTensor *input,