Add functional version of AbsCriterion using metatable call
THNN state is now passed implicitely.
diff --git a/generic/AbsCriterion.c b/generic/AbsCriterion.c
index 397e9dd..f14181c 100644
--- a/generic/AbsCriterion.c
+++ b/generic/AbsCriterion.c
@@ -2,53 +2,28 @@
#define TH_GENERIC_FILE "generic/AbsCriterion.c"
#else
-static int nn_(AbsCriterion_updateOutput)(lua_State *L)
+void THNN_(AbsCriterion_updateOutput)(THNNState *state, THTensor *input, THTensor *target, real *output, bool sizeAverage)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
- int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
- real sum;
+ real sum = 0;
- sum = 0;
TH_TENSOR_APPLY2(real, input, real, target,
- sum += fabs(*input_data - *target_data);)
+ sum += fabs(*input_data - *target_data);
+ );
- if(sizeAverage)
+ if (sizeAverage)
sum /= THTensor_(nElement)(input);
- lua_pushnumber(L, sum);
- lua_setfield(L, 1, "output");
-
- lua_pushnumber(L, sum);
- return 1;
+ *output = sum;
}
-static int nn_(AbsCriterion_updateGradInput)(lua_State *L)
+void THNN_(AbsCriterion_updateGradInput)(THNNState *state, THTensor *input, THTensor *target, THTensor *gradInput, bool sizeAverage)
{
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
- THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
- int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
- THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
THTensor_(resizeAs)(gradInput, input);
TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
- *gradInput_data = ( (*input_data - *target_data) >= 0 ? norm : -norm);)
-
- return 1;
-}
-
-static const struct luaL_Reg nn_(AbsCriterion__) [] = {
- {"AbsCriterion_updateOutput", nn_(AbsCriterion_updateOutput)},
- {"AbsCriterion_updateGradInput", nn_(AbsCriterion_updateGradInput)},
- {NULL, NULL}
-};
-
-static void nn_(AbsCriterion_init)(lua_State *L)
-{
- luaT_pushmetatable(L, torch_Tensor);
- luaT_registeratname(L, nn_(AbsCriterion__), "nn");
- lua_pop(L,1);
+ *gradInput_data = (*input_data - *target_data) >= 0 ? norm : -norm;
+ );
}
#endif
diff --git a/generic/THNN.h b/generic/THNN.h
index 8d74ae1..00f5fed 100644
--- a/generic/THNN.h
+++ b/generic/THNN.h
@@ -16,4 +16,17 @@
THTensor *gradOutput,
THTensor *gradInput);
+TH_API void THNN_(AbsCriterion_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ real *output,
+ bool sizeAverage);
+TH_API void THNN_(AbsCriterion_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *gradInput,
+ bool sizeAverage);
+
#endif
diff --git a/init.c b/init.c
index 4488afc..fbd9c50 100644
--- a/init.c
+++ b/init.c
@@ -6,3 +6,6 @@
#include "generic/Abs.c"
#include "THGenerateFloatTypes.h"
+
+#include "generic/AbsCriterion.c"
+#include "THGenerateFloatTypes.h"