Move {RReLU.cu, Sigmoid.cu, SmoothL1Criterion.cu, ...} to lib/THCUNN

Files moved:
RReLU.cu
Sigmoid.cu
SmoothL1Criterion.cu
SoftMax.cu
SoftPlus.cu
SoftShrink.cu
Sqrt.cu
Square.cu
Tanh.cu
Threshold.cu
diff --git a/RReLU.cu b/RReLU.cu
new file mode 100644
index 0000000..3c96103
--- /dev/null
+++ b/RReLU.cu
@@ -0,0 +1,208 @@
+#include "THCApply.cuh"
+#include "utils.h"
+#include "common.h"
+#include <curand.h>
+#include <curand_kernel.h>
+
+// copied from cutorch/lib/THC/THCTensorRandom.cu
+#define MAX_NUM_BLOCKS 64
+#define BLOCK_SIZE 256
+#define NUM_BLOCKS(n) min((int)THCCeilDiv(n, (long) BLOCK_SIZE), MAX_NUM_BLOCKS)
+
+__global__ void rreluUpdateOutputTrain(int n, curandStateMtgp32 *state,
+  float *input, float* noise, float *output, double a, double b)
+{
+  CUDA_KERNEL_LOOP(i, n)
+  {
+    if (input[i] <= 0)
+    {
+      float r = curand_uniform(&state[blockIdx.x]);
+      r = r * (b-a) + a;
+      output[i] = input[i] * r;
+      noise[i] = r;
+    }
+    else
+    {
+      output[i] = input[i];
+      noise[i] = 1;
+    }
+  }
+}
+
+struct RReLUUpdateOutputEval_functor
+{
+  const float negSlope_;
+
+  RReLUUpdateOutputEval_functor(float negSlope) : negSlope_(negSlope) {}
+
+  __device__ __forceinline__ void operator()(float* out, float* in)
+  {
+    const float x = *in;
+    const float r = x <= 0 ? negSlope_ : 1;
+    *out = x * r;
+  }
+};
+
+struct RReLUUpdateOutputEvalIP_functor
+{
+  const float negSlope_;
+
+  RReLUUpdateOutputEvalIP_functor(float negSlope) : negSlope_(negSlope) {}
+
+  __device__ __forceinline__ void operator()(float* x)
+  {
+    if (*x <= 0)
+    {
+      *x = *x * negSlope_;
+    }
+  }
+};
+
+static int cunn_RReLU_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THCudaTensor *noise = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "noise", "torch.CudaTensor");
+  double lower = luaT_getfieldchecknumber(L, 1, "lower");
+  double upper = luaT_getfieldchecknumber(L, 1, "upper");
+  int train = luaT_getfieldcheckboolean(L, 1, "train");
+  int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+  THAssert(THCudaTensor_checkGPU(state, 3, input, output, noise));
+  if (state->rngState->current_gen == NULL)
+  {
+    THError("Random number generators have not been initialized.");
+  }
+
+  if (train)
+  {
+    input = THCudaTensor_newContiguous(state, input);
+    THCudaTensor_resizeAs(state, noise, input);
+    float *input_data = THCudaTensor_data(state, input);
+    float *noise_data = THCudaTensor_data(state, noise);
+    long n = THCudaTensor_nElement(state, input);
+    if (inplace)
+    {
+      rreluUpdateOutputTrain<<<NUM_BLOCKS(n), BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+        n, state->rngState->current_gen->gen_states,
+        input_data, noise_data, input_data, lower, upper);
+      THCudaTensor_set(state, output, input);
+    }
+    else
+    {
+      THCudaTensor_resizeAs(state, output, input);
+      float *output_data = THCudaTensor_data(state, output);
+      rreluUpdateOutputTrain<<<NUM_BLOCKS(n), BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+        n, state->rngState->current_gen->gen_states,
+        input_data, noise_data, output_data, lower, upper);
+    }
+    THCudaTensor_free(state, input);
+  }
+  else
+  {
+    const double negSlope = (lower + upper) / 2;
+    if (inplace)
+    {
+      THCudaTensor_pointwiseApply1(state, input, RReLUUpdateOutputEvalIP_functor(negSlope));
+      THCudaTensor_set(state, output, input);
+    }
+    else
+    {
+      THCudaTensor_resizeAs(state, output, input);
+      THCudaTensor_pointwiseApply2(state, output, input, RReLUUpdateOutputEval_functor(negSlope));
+    }
+  }
+
+  return 1;
+}
+
+struct RReLUupdateGradInputEval_functor
+{
+  const float negSlope_;
+
+  RReLUupdateGradInputEval_functor(float negSlope) : negSlope_(negSlope) {}
+
+  __device__ __forceinline__ void operator()(float *gradIn, float *gradOut, float* in)
+  {
+    *gradIn = (*in) <= 0 ? (*gradOut) * negSlope_ : (*gradOut);
+  }
+};
+
+struct RReLUupdateGradInputEvalIP_functor
+{
+  const float negSlope_;
+
+  RReLUupdateGradInputEvalIP_functor(float negSlope) : negSlope_(negSlope) {}
+
+  __device__ __forceinline__ void operator()(float *gradOut, float *in)
+  {
+    if (*in <= 0)
+    {
+      *gradOut = (*gradOut) * negSlope_;
+    }
+  }
+};
+
+static int cunn_RReLU_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  THCudaTensor *noise = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "noise", "torch.CudaTensor");
+  double lower = luaT_getfieldchecknumber(L, 1, "lower");
+  double upper = luaT_getfieldchecknumber(L, 1, "upper");
+  int train = luaT_getfieldcheckboolean(L, 1, "train");
+  int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+  THAssert(THCudaTensor_checkGPU(state, 4, input, gradOutput, gradInput, noise));
+
+  gradOutput = THCudaTensor_newContiguous(state, gradOutput);
+
+  if (train && upper - lower > 1E-6)    // when upper == lower, RReLU behaves like LeakyReLU; handled in the else branch
+  {
+    // multiply the gradient by the noise tensor
+    if (inplace)
+    {
+      THCudaTensor_cmul(state, gradOutput, gradOutput, noise);
+      THCudaTensor_set(state, gradInput, gradOutput);
+    }
+    else
+    {
+      THCudaTensor_resizeAs(state, gradInput, input);
+      THCudaTensor_cmul(state, gradInput, gradOutput, noise);
+    }
+  }
+  else
+  {
+    // use constant factor for negative input values
+    const double negSlope = (lower + upper) / 2;
+    if (inplace)
+    {
+      THCudaTensor_pointwiseApply2(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor(negSlope));
+      THCudaTensor_set(state, gradInput, gradOutput);
+    }
+    else
+    {
+      THCudaTensor_resizeAs(state, gradInput, input);
+      THCudaTensor_pointwiseApply3(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor(negSlope));
+    }
+  }
+
+  THCudaTensor_free(state, gradOutput);
+  return 1;
+}
+
+static const struct luaL_Reg cunn_RReLU__ [] = {
+  {"RReLU_updateOutput", cunn_RReLU_updateOutput},
+  {"RReLU_updateGradInput", cunn_RReLU_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_RReLU_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_RReLU__, "nn");
+  lua_pop(L,1);
+}
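
For reference, the semantics implemented above: in training mode the negative
slope r is sampled per element from U(lower, upper) and saved in `noise` for
the backward pass; in evaluation mode the fixed slope (lower + upper) / 2 is
used. A minimal CPU sketch of this behavior (illustration only, not part of
the patch; std::mt19937 stands in for the MTGP32 generator used on the GPU):

  #include <random>

  float rrelu_forward(float x, float lower, float upper, bool train,
                      float *noise, std::mt19937 &gen) {
    if (train) {
      std::uniform_real_distribution<float> uniform(lower, upper);
      const float r = x <= 0 ? uniform(gen) : 1.0f;  // sampled slope
      *noise = r;                                    // saved for updateGradInput
      return x * r;
    }
    const float negSlope = (lower + upper) / 2;      // fixed slope in eval mode
    return x <= 0 ? x * negSlope : x;
  }
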
diff --git a/Sigmoid.cu b/Sigmoid.cu
new file mode 100644
index 0000000..0dac7a6
--- /dev/null
+++ b/Sigmoid.cu
@@ -0,0 +1,54 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct sigmoidupdateOutput_functor
+{
+  __device__ void operator()(float* output, const float* input) const
+  {
+    *output = 1./(1.+ exp(-*input));
+  }
+};
+
+static int cunn_Sigmoid_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+  THCudaTensor_resizeAs(state, output, input);
+  THCudaTensor_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor());
+  return 1;
+}
+
+struct sigmoidupdateGradInput_functor
+{
+  __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+  {
+    *gradInput = *gradOutput * (1.-*output) * *output;
+  }
+};
+
+static int cunn_Sigmoid_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+  THCudaTensor_resizeAs(state, gradInput, output);
+  THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor());
+  return 1;
+}
+
+static const struct luaL_Reg cunn_Sigmoid__ [] = {
+  {"Sigmoid_updateOutput", cunn_Sigmoid_updateOutput},
+  {"Sigmoid_updateGradInput", cunn_Sigmoid_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_Sigmoid_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_Sigmoid__, "nn");
+  lua_pop(L,1);
+}
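
The backward functor relies on the identity d/dx sigmoid(x) = sigmoid(x) *
(1 - sigmoid(x)), which is why it reads only the saved output and never
re-evaluates exp. A quick numerical check (illustration only, not part of
the patch):

  #include <cmath>
  #include <cstdio>

  int main() {
    const double x = 0.3, eps = 1e-6;
    const double s = 1.0 / (1.0 + std::exp(-x));
    const double analytic = s * (1.0 - s);
    // central finite difference of sigmoid at x
    const double numeric = ((1.0 / (1.0 + std::exp(-(x + eps)))) -
                            (1.0 / (1.0 + std::exp(-(x - eps))))) / (2 * eps);
    std::printf("analytic=%.8f numeric=%.8f\n", analytic, numeric);
    return 0;
  }
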
diff --git a/SmoothL1Criterion.cu b/SmoothL1Criterion.cu
new file mode 100644
index 0000000..a0162f6
--- /dev/null
+++ b/SmoothL1Criterion.cu
@@ -0,0 +1,129 @@
+#include "utils.h"
+
+#include <thrust/fill.h>
+#include <thrust/functional.h>
+#include <thrust/device_ptr.h>
+#include <thrust/reduce.h>
+#include <thrust/inner_product.h>
+#if CUDA_VERSION >= 7000
+#include <thrust/system/cuda/execution_policy.h>
+#endif
+
+struct smoothl1_functor
+{
+  smoothl1_functor() {}
+
+  __host__ __device__ float operator()(const float& x, const float& y) const
+  {
+    float z = fabsf(x-y);
+    return z < 1.f ? 0.5f*z*z : z - 0.5f;
+  }
+};
+
+
+static int cunn_SmoothL1Criterion_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *target = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, target));
+
+  int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+  luaL_argcheck(L, THCudaTensor_nElement(state, input) == THCudaTensor_nElement(state, target), 2,
+                "input and target need to have the same number of elements");
+
+  float sum;
+
+  long size = THCudaTensor_nElement(state, input);
+
+  input = THCudaTensor_newContiguous(state, input);
+  target = THCudaTensor_newContiguous(state, target);
+
+  thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
+  thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
+  sum = thrust::inner_product(
+#if CUDA_VERSION >= 7000
+    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+    input_data, input_data+size, target_data, (float) 0,
+    thrust::plus<float>(), smoothl1_functor());
+
+  if(sizeAverage)
+    sum /= size;
+
+  THCudaTensor_free(state, input);
+  THCudaTensor_free(state, target);
+
+  lua_pushnumber(L, sum);
+  lua_setfield(L, 1, "output");
+
+  lua_pushnumber(L, sum);
+  return 1;
+}
+
+
+struct smoothl1_updateGradInput_functor
+{
+  const float norm;
+
+  smoothl1_updateGradInput_functor(float norm_) : norm(norm_) {}
+
+  __host__ __device__ float operator()(const float& x, const float& y) const
+  {
+    float z = x - y;
+    if(z < -1.f)
+      return -norm;
+    else if(z > 1.f)
+      return norm;
+    else
+      return norm * z;
+  }
+};
+
+static int cunn_SmoothL1Criterion_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *target = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  luaL_argcheck(L, THCudaTensor_nElement(state, input) == THCudaTensor_nElement(state, target), 2,
+                "input and target need to have the same number of elements");
+  THAssert(THCudaTensor_checkGPU(state, 3, input, target, gradInput));
+
+  long size = THCudaTensor_nElement(state, input);
+  float norm = (sizeAverage ? 1./size : 1.);
+
+  input = THCudaTensor_newContiguous(state, input);
+  target = THCudaTensor_newContiguous(state, target);
+
+  THCudaTensor_resizeAs(state, gradInput, input);
+
+  thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
+  thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
+  thrust::device_ptr<float> gradInput_data(THCudaTensor_data(state, gradInput));
+
+  thrust::transform(
+#if CUDA_VERSION >= 7000
+    thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+    input_data, input_data+size, target_data, gradInput_data,
+    smoothl1_updateGradInput_functor(norm));
+
+  THCudaTensor_free(state, input);
+  THCudaTensor_free(state, target);
+  return 1;
+}
+
+static const struct luaL_Reg cunn_SmoothL1Criterion__ [] = {
+  {"SmoothL1Criterion_updateOutput", cunn_SmoothL1Criterion_updateOutput},
+  {"SmoothL1Criterion_updateGradInput", cunn_SmoothL1Criterion_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_SmoothL1Criterion_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_SmoothL1Criterion__, "nn");
+  lua_pop(L,1);
+}
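
The per-element loss above is the Huber loss with delta = 1: for z = x - y,
loss(z) = 0.5 * z^2 when |z| < 1 and |z| - 0.5 otherwise, so the gradient is
z in the quadratic region and saturates at +/-1 (scaled by 1/size when
sizeAverage is set). A CPU reference matching the two functors (illustration
only, not part of the patch):

  #include <cmath>

  float smooth_l1(float x, float y) {
    const float z = std::fabs(x - y);
    return z < 1.f ? 0.5f * z * z : z - 0.5f;
  }

  float smooth_l1_grad(float x, float y, float norm) {
    const float z = x - y;
    if (z < -1.f) return -norm;  // saturated region
    if (z >  1.f) return  norm;  // saturated region
    return norm * z;             // quadratic region
  }
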
diff --git a/SoftMax.cu b/SoftMax.cu
new file mode 100644
index 0000000..364b6fb
--- /dev/null
+++ b/SoftMax.cu
@@ -0,0 +1,228 @@
+#include "utils.h"
+
+#define MINUS_LOG_THRESHOLD -18.42
+#define SOFTMAX_THREADS 128
+
+__global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input,
+                                                 int nframe, int dim, int stride)
+{
+  __shared__ float buffer[SOFTMAX_THREADS+1];
+  float *input_k = input + blockIdx.x*dim*stride + blockIdx.y;
+  float *output_k = output + blockIdx.x*dim*stride + blockIdx.y;
+
+  int i_start = threadIdx.x;
+  int i_end = dim;
+  int i_step = blockDim.x;
+
+  // per-thread max over a strided slice of the row
+  buffer[threadIdx.x] = -FLT_MAX;
+  for (int i=i_start; i<i_end; i+=i_step)
+  {
+    float z = input_k[i*stride];
+    if(buffer[threadIdx.x] < z)
+      buffer[threadIdx.x] = z;
+  }
+
+  __syncthreads();
+
+  // reduce
+  if (threadIdx.x == 0)
+  {
+    float max_k = -FLT_MAX;
+    for (int i=0; i<blockDim.x; i++)
+    {
+      if(max_k < buffer[i])
+        max_k = buffer[i];
+    }
+    buffer[SOFTMAX_THREADS] = max_k;
+  }
+
+  __syncthreads();
+
+  // compute exp(input - max) and accumulate the per-thread sum
+  float max_k = buffer[SOFTMAX_THREADS];
+  buffer[threadIdx.x] = 0;
+  for (int i=i_start; i<i_end; i+=i_step) {
+    float z = __expf(input_k[i*stride]-max_k);
+    buffer[threadIdx.x] += z;
+    output_k[i*stride] = z;
+  }
+
+  __syncthreads();
+
+  // reduce
+  if (threadIdx.x == 0)
+  {
+    float sum_k = 0;
+    for (int i=0; i<blockDim.x; i++)
+      sum_k += buffer[i];
+    buffer[SOFTMAX_THREADS] = sum_k;
+  }
+
+  __syncthreads();
+
+  // softmax
+  float sum_k = buffer[SOFTMAX_THREADS];
+  for (int i=i_start; i<i_end; i+=i_step)
+    output_k[i*stride] = output_k[i*stride] / sum_k;
+}
+
+
+__global__ void cunn_SoftMax_updateGradInput_kernel(float *gradInput, float *output, float *gradOutput,
+                                                    int nframe, int dim, int stride)
+{
+  __shared__ float buffer[SOFTMAX_THREADS];
+  float *gradInput_k = gradInput + blockIdx.x*dim*stride + blockIdx.y;
+  float *output_k = output + blockIdx.x*dim*stride + blockIdx.y;
+  float *gradOutput_k = gradOutput + blockIdx.x*dim*stride + blockIdx.y;
+
+  int i_start = threadIdx.x;
+  int i_end = dim;
+  int i_step = blockDim.x;
+
+  // per-thread partial sum of dot(gradOutput, output)
+  buffer[threadIdx.x] = 0;
+  for (int i=i_start; i<i_end; i+=i_step)
+    buffer[threadIdx.x] += gradOutput_k[i*stride] * output_k[i*stride];
+
+  __syncthreads();
+
+  // reduce
+  if (threadIdx.x == 0)
+  {
+    float sum_k = 0;
+    for (int i=0; i<blockDim.x; i++)
+      sum_k += buffer[i];
+    buffer[0] = sum_k;
+  }
+
+  __syncthreads();
+
+  float sum_k = buffer[0];
+  for (int i=i_start; i<i_end; i+=i_step)
+    gradInput_k[i*stride] = output_k[i*stride] * (gradOutput_k[i*stride] - sum_k);
+}
+
+static int cunn_SoftMax_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+
+  input = THCudaTensor_newContiguous(state, input);
+  THCudaTensor_resizeAs(state, output, input);
+  long batchSize, dim, stride;
+
+  if(input->nDimension == 1)
+  {
+    batchSize = 1;
+    dim = input->size[0];
+    stride = 1;
+  }
+  else if(input->nDimension == 2)
+  {
+    batchSize = input->size[0];
+    dim = input->size[1];
+    stride = 1;
+  }
+  else if(input->nDimension == 3)
+  {
+    batchSize = 1;
+    dim = input->size[0];
+    stride = input->size[1]*input->size[2];
+  }
+  else if(input->nDimension == 4)
+  {
+    batchSize = input->size[0];
+    dim = input->size[1];
+    stride = input->size[2]*input->size[3];
+  }
+  else
+    THError("1D, 2D, 3D or 4D tensor expected");
+
+  dim3 blocks(batchSize, stride);
+  dim3 threads(SOFTMAX_THREADS);
+  cunn_SoftMax_updateOutput_kernel<<<blocks,threads,
+    0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, output),
+                                           THCudaTensor_data(state, input),
+                                           batchSize, dim, stride);
+
+  cudaError errcode = cudaGetLastError();
+  if(errcode != cudaSuccess)
+    THError(cudaGetErrorString(errcode));
+
+  THCudaTensor_free(state, input);
+  return 1;
+}
+
+static int cunn_SoftMax_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+
+  output = THCudaTensor_newContiguous(state, output);
+  gradOutput = THCudaTensor_newContiguous(state, gradOutput);
+
+  THCudaTensor_resizeAs(state, gradInput, output);
+  long batchSize, dim, stride;
+
+  if(gradInput->nDimension == 1)
+  {
+    batchSize = 1;
+    dim = gradInput->size[0];
+    stride = 1;
+  }
+  else if(gradInput->nDimension == 2)
+  {
+    batchSize = gradInput->size[0];
+    dim = gradInput->size[1];
+    stride = 1;
+  }
+  else if(gradInput->nDimension == 3)
+  {
+    batchSize = 1;
+    dim = gradInput->size[0];
+    stride = gradInput->size[1]*gradInput->size[2];
+  }
+  else if(gradInput->nDimension == 4)
+  {
+    batchSize = gradInput->size[0];
+    dim = gradInput->size[1];
+    stride = gradInput->size[2]*gradInput->size[3];
+  }
+  else
+    THError("1D, 2D, 3D or 4D tensor expected");
+
+  dim3 blocks(batchSize, stride);
+  dim3 threads(SOFTMAX_THREADS);
+  cunn_SoftMax_updateGradInput_kernel<<<blocks,threads,
+    0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, gradInput),
+                                           THCudaTensor_data(state, output),
+                                           THCudaTensor_data(state, gradOutput),
+                                           batchSize, dim, stride);
+
+  cudaError errcode = cudaGetLastError();
+  if(errcode != cudaSuccess)
+    THError(cudaGetErrorString(errcode));
+
+  THCudaTensor_free(state, gradOutput);
+  THCudaTensor_free(state, output);
+  return 1;
+}
+
+static const struct luaL_Reg cunn_SoftMax__ [] = {
+  {"SoftMax_updateOutput", cunn_SoftMax_updateOutput},
+  {"SoftMax_updateGradInput", cunn_SoftMax_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_SoftMax_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_SoftMax__, "nn");
+  lua_pop(L,1);
+}
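
The forward kernel uses the standard three-pass, numerically stable softmax:
find the row max, exponentiate the shifted inputs while accumulating their
sum, then normalize; subtracting the max first keeps __expf from overflowing.
The backward kernel computes gradInput_i = output_i * (gradOutput_i - s) with
s = sum_j gradOutput_j * output_j. A CPU reference for the forward pass
(illustration only, not part of the patch):

  #include <algorithm>
  #include <cmath>
  #include <cstddef>
  #include <vector>

  std::vector<float> softmax(const std::vector<float> &x) {
    const float mx = *std::max_element(x.begin(), x.end());  // pass 1: max
    std::vector<float> out(x.size());
    float sum = 0.f;
    for (std::size_t i = 0; i < x.size(); ++i) {  // pass 2: shifted exps + sum
      out[i] = std::exp(x[i] - mx);
      sum += out[i];
    }
    for (float &v : out) v /= sum;                // pass 3: normalize
    return out;
  }
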
diff --git a/SoftPlus.cu b/SoftPlus.cu
new file mode 100644
index 0000000..d522dfc
--- /dev/null
+++ b/SoftPlus.cu
@@ -0,0 +1,73 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct softPlusupdateOutput_functor
+{
+  const float threshold;
+  const float beta;
+
+  softPlusupdateOutput_functor(float threshold_, float beta_) : threshold(threshold_), beta(beta_) {}
+
+  __device__ void operator()(float* output, const float* input) const
+  {
+    float betain = beta * *input;
+    *output = ((betain) > threshold) ? *input : (1/beta) * log1p(exp(betain));
+  }
+};
+
+static int cunn_SoftPlus_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  float beta = luaT_getfieldchecknumber(L, 1, "beta");
+  float threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+  THCudaTensor_resizeAs(state, output, input);
+  THCudaTensor_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
+  return 1;
+}
+
+struct softPlusupdateGradInput_functor
+{
+  const float threshold;
+  const float beta;
+
+  softPlusupdateGradInput_functor(float threshold_, float beta_) : threshold(threshold_), beta(beta_) {}
+
+  __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+  {
+    float betaout = beta * *output;
+    float exp_bo = exp(betaout);
+    *gradInput = ((betaout) > threshold) ? *gradOutput : *gradOutput * (exp_bo - 1) / exp_bo;
+  }
+};
+
+static int cunn_SoftPlus_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 4, input, output, gradOutput, gradInput));
+  float beta = luaT_getfieldchecknumber(L, 1, "beta");
+  float threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+
+  THCudaTensor_resizeAs(state, gradInput, output);
+  THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
+  return 1;
+}
+
+static const struct luaL_Reg cunn_SoftPlus__ [] = {
+  {"SoftPlus_updateOutput", cunn_SoftPlus_updateOutput},
+  {"SoftPlus_updateGradInput", cunn_SoftPlus_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_SoftPlus_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_SoftPlus__, "nn");
+  lua_pop(L,1);
+}
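
Two numerical details above are worth spelling out. First, once beta * input
exceeds `threshold`, softplus(x) is within floating-point noise of x, so the
forward functor passes the input through and avoids overflowing exp. Second,
the gradient is recovered from the saved output: with
y = (1/beta) * log1p(exp(beta*x)) we have exp(beta*y) = 1 + exp(beta*x),
hence dy/dx = sigmoid(beta*x) = (exp(beta*y) - 1) / exp(beta*y), which is
exactly the expression in the backward functor. A quick check (illustration
only, not part of the patch):

  #include <cmath>
  #include <cstdio>

  int main() {
    const double beta = 2.0, x = 0.7;
    const double y = (1.0 / beta) * std::log1p(std::exp(beta * x));
    const double via_output = (std::exp(beta * y) - 1.0) / std::exp(beta * y);
    const double direct = 1.0 / (1.0 + std::exp(-beta * x));  // sigmoid(beta*x)
    std::printf("via output: %.8f  sigmoid(beta*x): %.8f\n", via_output, direct);
    return 0;
  }
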
diff --git a/SoftShrink.cu b/SoftShrink.cu
new file mode 100644
index 0000000..ddb35a8
--- /dev/null
+++ b/SoftShrink.cu
@@ -0,0 +1,72 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct SoftShrinkUpdateOutput {
+  const float lambda_;
+
+  SoftShrinkUpdateOutput(float lambda): lambda_(lambda){}
+
+  __device__ __forceinline__ void operator()(float* out, float* in) {
+    float x = *in;
+    if (x > lambda_) *out = x - lambda_;
+    else if (x < -lambda_) *out = x + lambda_;
+    else *out = 0;
+  }
+};
+
+static int cunn_SoftShrink_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  double lambda = luaT_getfieldchecknumber(L, 1, "lambda");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+  THCudaTensor_resizeAs(state, output, input);
+  THCudaTensor_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput(lambda));
+  THCudaCheck(cudaGetLastError());
+  return 1;
+}
+
+struct SoftShrinkUpdateGradInput
+{
+  const float lambda_;
+
+  SoftShrinkUpdateGradInput(float lambda) : lambda_(lambda) {}
+
+  __device__ __forceinline__ void operator()(float* gradInput, float* input,
+      float* gradOutput) const {
+    float x = *input;
+    if (x > lambda_ || x < -lambda_)
+      *gradInput = *gradOutput;
+    else
+      *gradInput = 0;
+  }
+};
+
+
+static int cunn_SoftShrink_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  double lambda = luaT_getfieldchecknumber(L, 1, "lambda");
+  THAssert(THCudaTensor_checkGPU(state, 3, input, gradOutput, gradInput));
+  THCudaTensor_resizeAs(state, gradInput, input);
+  THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, SoftShrinkUpdateGradInput(lambda));
+  THCudaCheck(cudaGetLastError());
+  return 1;
+}
+
+static const struct luaL_Reg cunn_SoftShrink__ [] = {
+  {"SoftShrink_updateOutput", cunn_SoftShrink_updateOutput},
+  {"SoftShrink_updateGradInput", cunn_SoftShrink_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_SoftShrink_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_SoftShrink__, "nn");
+  lua_pop(L,1);
+}
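
For reference, the soft shrinkage function applied above (lambda >= 0) zeroes
values inside [-lambda, lambda] and shrinks the rest toward zero, which is
why the gradient is passed through only outside that band. CPU equivalent
(illustration only, not part of the patch):

  float soft_shrink(float x, float lambda) {
    if (x >  lambda) return x - lambda;
    if (x < -lambda) return x + lambda;
    return 0.f;  // inside the dead band
  }
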
diff --git a/Sqrt.cu b/Sqrt.cu
new file mode 100644
index 0000000..ff716c5
--- /dev/null
+++ b/Sqrt.cu
@@ -0,0 +1,61 @@
+#include "THCApply.cuh"
+#include "utils.h"
+
+struct sqrtupdateOutput_functor
+{
+  const float bias;
+
+  sqrtupdateOutput_functor(float bias_) : bias(bias_) {}
+
+  __device__ void operator()(float* output, const float* input) const
+  {
+    *output = sqrt(*input + bias);
+  }
+};
+
+static int cunn_Sqrt_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  float bias = (float) luaT_getfieldchecknumber(L,1,"eps");
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+  THCudaTensor_resizeAs(state, output, input);
+  THCudaTensor_pointwiseApply2(state, output, input, sqrtupdateOutput_functor(bias));
+  return 1;
+}
+
+struct sqrtupdateGradInput_functor
+{
+  sqrtupdateGradInput_functor() {}
+
+  __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+  {
+    *gradInput = (*output == 0.0f) ? 0.0f : ((0.5f * *gradOutput) / *output);
+  }
+};
+
+static int cunn_Sqrt_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+  THCudaTensor_resizeAs(state, gradInput, output);
+  THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sqrtupdateGradInput_functor());
+  return 1;
+}
+
+static const struct luaL_Reg cunn_Sqrt__ [] = {
+  {"Sqrt_updateOutput", cunn_Sqrt_updateOutput},
+  {"Sqrt_updateGradInput", cunn_Sqrt_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_Sqrt_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_Sqrt__, "nn");
+  lua_pop(L,1);
+}
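
The backward functor uses the saved output: since output = sqrt(input + eps),
d(output)/d(input) = 1 / (2 * output), with a guard so an exact zero output
produces a zero gradient instead of a division by zero. CPU equivalent
(illustration only, not part of the patch):

  float sqrt_grad(float output, float gradOutput) {
    return output == 0.0f ? 0.0f : 0.5f * gradOutput / output;
  }
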
diff --git a/Square.cu b/Square.cu
new file mode 100644
index 0000000..565af8e
--- /dev/null
+++ b/Square.cu
@@ -0,0 +1,54 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct squareupdateOutput_functor
+{
+  __device__ void operator()(float* output, const float* input) const
+  {
+    *output = *input * *input;
+  }
+};
+
+static int cunn_Square_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+  THCudaTensor_resizeAs(state, output, input);
+  THCudaTensor_pointwiseApply2(state, output, input, squareupdateOutput_functor());
+  return 1;
+}
+
+struct squareupdateGradInput_functor
+{
+  __device__ void operator()(float* gradInput, const float* input, const float* gradOutput) const
+  {
+    *gradInput = 2.0 * *gradOutput * *input;
+  }
+};
+
+static int cunn_Square_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 3, input, gradOutput, gradInput));
+  THCudaTensor_resizeAs(state, gradInput, input);
+  THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, squareupdateGradInput_functor());
+  return 1;
+}
+
+static const struct luaL_Reg cunn_Square__ [] = {
+  {"Square_updateOutput", cunn_Square_updateOutput},
+  {"Square_updateGradInput", cunn_Square_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_Square_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_Square__, "nn");
+  lua_pop(L,1);
+}
diff --git a/Tanh.cu b/Tanh.cu
new file mode 100644
index 0000000..cf526ad
--- /dev/null
+++ b/Tanh.cu
@@ -0,0 +1,54 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct tanhupdateOutput_functor
+{
+  __device__ void operator()(float* output, const float* input) const
+  {
+    *output = tanh(*input);
+  }
+};
+
+static int cunn_Tanh_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+  THCudaTensor_resizeAs(state, output, input);
+  THCudaTensor_pointwiseApply2(state, output, input, tanhupdateOutput_functor());
+  return 1;
+}
+
+struct tanhupdateGradInput_functor
+{
+  __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+  {
+    *gradInput = *gradOutput * (1 - *output * *output);
+  }
+};
+
+static int cunn_Tanh_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+  THCudaTensor_resizeAs(state, gradInput, output);
+  THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor());
+  return 1;
+}
+
+static const struct luaL_Reg cunn_Tanh__ [] = {
+  {"Tanh_updateOutput", cunn_Tanh_updateOutput},
+  {"Tanh_updateGradInput", cunn_Tanh_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_Tanh_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_Tanh__, "nn");
+  lua_pop(L,1);
+}
diff --git a/Threshold.cu b/Threshold.cu
new file mode 100644
index 0000000..764dd9f
--- /dev/null
+++ b/Threshold.cu
@@ -0,0 +1,117 @@
+#include "THCApply.cuh"
+#include "utils.h"
+
+struct ThresholdUpdateOutput {
+  const float threshold_;
+  const float val_;
+
+  ThresholdUpdateOutput(float threshold, float val): threshold_(threshold),
+                                                     val_(val) {}
+
+  __device__ __forceinline__ void operator()(float* out, float* in) {
+    float x = *in;
+    *out = (x > threshold_) ? x : val_;
+  }
+};
+
+// in-place variant
+struct ThresholdUpdateOutputIP {
+  const float threshold_;
+  const float val_;
+
+  ThresholdUpdateOutputIP(float threshold, float val): threshold_(threshold),
+                                                       val_(val) {}
+
+  __device__ __forceinline__ void operator()(float* x) {
+    *x = (*x > threshold_) ? *x : val_;
+  }
+};
+
+static int cunn_Threshold_updateOutput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  double val = luaT_getfieldchecknumber(L, 1, "val");
+  double threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+  bool   inPlace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+  THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+
+  if (inPlace) {
+    THCudaTensor_pointwiseApply1(state, input,
+                                 ThresholdUpdateOutputIP(threshold, val));
+    THCudaTensor_set(state, output, input);
+  } else {
+    THCudaTensor_resizeAs(state, output, input);
+    THCudaTensor_pointwiseApply2(state, output, input,
+                                 ThresholdUpdateOutput(threshold, val));
+  }
+
+  THCudaCheck(cudaGetLastError());
+  return 1;
+}
+
+struct ThresholdUpdateGradInput
+{
+  const float threshold_;
+
+  ThresholdUpdateGradInput(float threshold) : threshold_(threshold) {}
+
+  __device__ __forceinline__ void operator()(float* gradInput, float* input,
+                                             float* gradOutput) const {
+    *gradInput = (*input > threshold_) ? *gradOutput : 0;
+  }
+};
+
+struct ThresholdUpdateGradInputIP
+{
+  const float threshold_;
+
+  ThresholdUpdateGradInputIP(float threshold) : threshold_(threshold) {}
+
+  __device__ __forceinline__ void operator()(float* gradOutput,
+                                             float* input) const {
+    *gradOutput = (*input > threshold_) ? *gradOutput : 0;
+  }
+};
+
+static int cunn_Threshold_updateGradInput(lua_State *L)
+{
+  THCState *state = getCutorchState(L);
+  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+  double val = luaT_getfieldchecknumber(L, 1, "val");
+  double threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+  bool   inPlace   = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+  THAssert(THCudaTensor_checkGPU(state, 4, input, output, gradInput, gradOutput));
+
+  if (inPlace) {
+    THCudaTensor_pointwiseApply2(state, gradOutput, input,
+                                 ThresholdUpdateGradInputIP(threshold));
+    THCudaTensor_set(state, gradInput, gradOutput);
+  } else {
+    THCudaTensor_resizeAs(state, gradInput, output);
+    THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput,
+                                 ThresholdUpdateGradInput(threshold));
+  }
+
+  THCudaCheck(cudaGetLastError());
+  return 1;
+}
+
+static const struct luaL_Reg cunn_Threshold__ [] = {
+  {"Threshold_updateOutput", cunn_Threshold_updateOutput},
+  {"Threshold_updateGradInput", cunn_Threshold_updateGradInput},
+  {NULL, NULL}
+};
+
+void cunn_Threshold_init(lua_State *L)
+{
+  luaT_pushmetatable(L, "torch.CudaTensor");
+  luaT_registeratname(L, cunn_Threshold__, "nn");
+  lua_pop(L,1);
+}
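
Threshold generalizes ReLU: values at or below `threshold` are replaced by
`val` in the forward pass and receive zero gradient in the backward pass
(nn.ReLU is the threshold = 0, val = 0 case). In the in-place paths the
result is written over `input` (or `gradOutput`) and the module tensor is
then `set` to share that storage. CPU reference for the per-element
semantics (illustration only, not part of the patch):

  float threshold_forward(float x, float threshold, float val) {
    return x > threshold ? x : val;
  }

  float threshold_backward(float x, float gradOutput, float threshold) {
    return x > threshold ? gradOutput : 0.f;  // no gradient below threshold
  }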