Add THCUNN/ffi conversion of Abs and AbsCriterion
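
This converts the Abs and AbsCriterion CUDA kernels from Lua-stack entry
points registered through luaT into plain C functions exported from the new
THCUNN.h header, so the Lua side can bind them through FFI instead of the
compiled cunn wrappers. The update functions now take the THCState and the
tensors explicitly, and AbsCriterion writes its scalar loss through a
float* output argument instead of pushing it onto the Lua stack.

A minimal sketch of a forward call against the new interface, assuming the
caller (e.g. the FFI wrapper) already holds a THCState and the input
tensors, as the removed getCutorchState() path did; abs_forward_example is
a hypothetical helper, not part of this patch:

    #include "THCUNN.h"

    /* Sketch only: `state`, `input` and `target` are assumed to be
       provided by the caller. */
    static float abs_forward_example(THCState *state,
                                     THCudaTensor *input,
                                     THCudaTensor *target)
    {
      /* Element-wise |input|; updateOutput resizes `output` itself. */
      THCudaTensor *output = THCudaTensor_new(state);
      THNN_CudaAbs_updateOutput(state, input, output);

      /* Absolute-error criterion between input and target, averaged
         over elements when sizeAverage is true. */
      float loss = 0.f;
      THNN_CudaAbsCriterion_updateOutput(state, input, target, &loss, true);

      THCudaTensor_free(state, output);
      return loss;
    }
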
diff --git a/Abs.cu b/Abs.cu
index b21ccb6..7514d44 100644
--- a/Abs.cu
+++ b/Abs.cu
@@ -1,5 +1,4 @@
-#include "utils.h"
-#include "THCApply.cuh"
+#include "THCUNN.h"
 
 struct absupdateOutput_functor
 {
@@ -9,15 +8,11 @@
   }
 };
 
-static int cunn_Abs_updateOutput(lua_State *L)
+void THNN_CudaAbs_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output)
 {
-  THCState *state = getCutorchState(L);
-  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
-  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
   THAssert(THCudaTensor_checkGPU(state, 2, input, output));
   THCudaTensor_resizeAs(state, output, input);
   THCudaTensor_pointwiseApply2(state, output, input, absupdateOutput_functor());
-  return 1;
 }
 
 struct absupdateGradInput_functor
@@ -28,27 +23,9 @@
   }
 };
 
-static int cunn_Abs_updateGradInput(lua_State *L)
+void THNN_CudaAbs_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput)
 {
-  THCState *state = getCutorchState(L);
-  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
-  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
-  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
   THAssert(THCudaTensor_checkGPU(state, 3, input, gradOutput, gradInput));
   THCudaTensor_resizeAs(state, gradInput, input);
   THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, absupdateGradInput_functor());
-  return 1;
-}
-
-static const struct luaL_Reg cunn_Abs__ [] = {
-  {"Abs_updateOutput", cunn_Abs_updateOutput},
-  {"Abs_updateGradInput", cunn_Abs_updateGradInput},
-  {NULL, NULL}
-};
-
-void cunn_Abs_init(lua_State *L)
-{
-  luaT_pushmetatable(L, "torch.CudaTensor");
-  luaT_registeratname(L, cunn_Abs__, "nn");
-  lua_pop(L,1);
 }
diff --git a/AbsCriterion.cu b/AbsCriterion.cu
index de46f8c..2f856c7 100644
--- a/AbsCriterion.cu
+++ b/AbsCriterion.cu
@@ -1,4 +1,4 @@
-#include "utils.h"
+#include "THCUNN.h"
 
 #include <thrust/fill.h>
 #include <thrust/functional.h>
@@ -17,15 +17,9 @@
     }
 };
 
-
-static int cunn_AbsCriterion_updateOutput(lua_State *L)
+void THNN_CudaAbsCriterion_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *target, float *output, bool sizeAverage)
 {
-  THCState *state = getCutorchState(L);
-  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
-  THCudaTensor *target = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
-  int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
   THAssert(THCudaTensor_checkGPU(state, 2, input, target));
-  float sum;
 
   long size = THCudaTensor_nElement(state, input);
 
@@ -34,19 +28,15 @@
 
   thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
   thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
-  sum = thrust::inner_product(input_data, input_data+size, target_data, (float) 0, thrust::plus<float>(), abs_functor());
+  float sum = thrust::inner_product(input_data, input_data+size, target_data, (float) 0, thrust::plus<float>(), abs_functor());
 
-  if(sizeAverage)
+  if (sizeAverage)
     sum /= size;
 
   THCudaTensor_free(state, input);
   THCudaTensor_free(state, target);
 
-  lua_pushnumber(L, sum);
-  lua_setfield(L, 1, "output");
-
-  lua_pushnumber(L, sum);
-  return 1;
+  *output = sum;
 }
 
 
@@ -62,13 +52,8 @@
     }
 };
 
-static int cunn_AbsCriterion_updateGradInput(lua_State *L)
+void THNN_CudaAbsCriterion_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *target, THCudaTensor *gradInput, bool sizeAverage)
 {
-  THCState *state = getCutorchState(L);
-  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
-  THCudaTensor *target = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
-  int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
-  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
   THAssert(THCudaTensor_checkGPU(state, 3, input, target, gradInput));
 
   long size = THCudaTensor_nElement(state, input);
@@ -87,18 +72,4 @@
 
   THCudaTensor_free(state, input);
   THCudaTensor_free(state, target);
-  return 1;
-}
-
-static const struct luaL_Reg cunn_AbsCriterion__ [] = {
-  {"AbsCriterion_updateOutput", cunn_AbsCriterion_updateOutput},
-  {"AbsCriterion_updateGradInput", cunn_AbsCriterion_updateGradInput},
-  {NULL, NULL}
-};
-
-void cunn_AbsCriterion_init(lua_State *L)
-{
-  luaT_pushmetatable(L, "torch.CudaTensor");
-  luaT_registeratname(L, cunn_AbsCriterion__, "nn");
-  lua_pop(L,1);
 }
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..dd58697
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,13 @@
+CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR)
+CMAKE_POLICY(VERSION 2.8)
+
+FILE(GLOB src-cuda *.cu)
+
+CUDA_ADD_LIBRARY(THCUNN MODULE ${src-cuda})
+
+INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
+TARGET_LINK_LIBRARIES(THCUNN THC TH)
+
+INSTALL(TARGETS THCUNN
+  RUNTIME DESTINATION ${Torch_INSTALL_LIB_SUBDIR}
+  LIBRARY DESTINATION ${Torch_INSTALL_LIB_SUBDIR})
diff --git a/THCUNN.h b/THCUNN.h
new file mode 100644
index 0000000..f737396
--- /dev/null
+++ b/THCUNN.h
@@ -0,0 +1,25 @@
+#include <THC/THC.h>
+#include "THCApply.cuh"
+
+TH_API void THNN_CudaAbs_updateOutput(
+          THCState *state,
+          THCudaTensor *input,
+          THCudaTensor *output);
+TH_API void THNN_CudaAbs_updateGradInput(
+          THCState *state,
+          THCudaTensor *input,
+          THCudaTensor *gradOutput,
+          THCudaTensor *gradInput);
+
+TH_API void THNN_CudaAbsCriterion_updateOutput(
+          THCState *state,
+          THCudaTensor *input,
+          THCudaTensor *target,
+          float *output,
+          bool sizeAverage);
+TH_API void THNN_CudaAbsCriterion_updateGradInput(
+          THCState *state,
+          THCudaTensor *input,
+          THCudaTensor *target,
+          THCudaTensor *gradInput,
+          bool sizeAverage);
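
For completeness, a sketch of the backward calls against the declarations
above; as in the forward example, the state and tensors are assumed to come
from the caller, the gradInput tensor is resized by the functions
themselves, and abs_backward_example is a hypothetical helper:

    #include "THCUNN.h"

    /* Sketch only: demonstrates the gradient entry points. */
    static void abs_backward_example(THCState *state,
                                     THCudaTensor *input,
                                     THCudaTensor *gradOutput,
                                     THCudaTensor *target)
    {
      THCudaTensor *gradInput = THCudaTensor_new(state);

      /* Gradient of |x|, multiplied by the incoming gradient. */
      THNN_CudaAbs_updateGradInput(state, input, gradOutput, gradInput);

      /* Gradient of the size-averaged absolute-error criterion. */
      THNN_CudaAbsCriterion_updateGradInput(state, input, target,
                                            gradInput, true);

      THCudaTensor_free(state, gradInput);
    }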