Move {RReLU.cu, Sigmoid.cu, SmoothL1Criterion.cu, ...} to lib/THCUNN
Files moved:
RReLU.cu
Sigmoid.cu
SmoothL1Criterion.cu
SoftMax.cu
SoftPlus.cu
SoftShrink.cu
Sqrt.cu
Square.cu
Tanh.cu
Threshold.cu
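
Each file binds one module's updateOutput/updateGradInput pair to Lua and
registers it on the torch.CudaTensor metatable under the "nn" namespace
(see the cunn_*_init functions), so Lua-side usage is unchanged. An
illustrative sketch:

    require 'cunn'
    local m = nn.Tanh():cuda()  -- CUDA backend now in lib/THCUNN/Tanh.cu
    local y = m:forward(torch.CudaTensor(4, 5):uniform())
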
diff --git a/RReLU.cu b/RReLU.cu
new file mode 100644
index 0000000..3c96103
--- /dev/null
+++ b/RReLU.cu
@@ -0,0 +1,213 @@
+#include "THCApply.cuh"
+#include "utils.h"
+#include "common.h"
+#include <curand.h>
+#include <curand_kernel.h>
+
+// copied from cutorch/lib/THC/THCTensorRandom.cu
+#define MAX_NUM_BLOCKS 64
+#define BLOCK_SIZE 256
+#define NUM_BLOCKS(n) min((int)THCCeilDiv(n, (long) BLOCK_SIZE), MAX_NUM_BLOCKS)
+
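+// Training mode: for every non-positive input element, sample a slope
+// r ~ U(a, b) from the per-block MTGP32 generator, scale the input by r,
+// and record r in `noise` so updateGradInput can reuse the same slope.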
+__global__ void rreluUpdateOutputTrain(int n, curandStateMtgp32 *state,
+ float *input, float* noise, float *output, double a, double b)
+{
+ CUDA_KERNEL_LOOP(i, n)
+ {
+ if (input[i] <= 0)
+ {
+ float r = curand_uniform(&state[blockIdx.x]);
+ r = r * (b-a) + a;
+ output[i] = input[i] * r;
+ noise[i] = r;
+ }
+ else
+ {
+ output[i] = input[i];
+ noise[i] = 1;
+ }
+ }
+}
+
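+// Evaluation mode: scale non-positive inputs by the deterministic expected
+// slope negSlope = (lower + upper) / 2 instead of sampling.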
+struct RReLUUpdateOutputEval_functor
+{
+ const float negSlope_;
+
+ RReLUUpdateOutputEval_functor(float negSlope) : negSlope_(negSlope) {}
+
+ __device__ __forceinline__ void operator()(float* out, float* in)
+ {
+ const float x = *in;
+ const float r = x <= 0 ? negSlope_ : 1;
+ *out = x * r;
+ }
+};
+
+struct RReLUUpdateOutputEvalIP_functor
+{
+ const float negSlope_;
+
+ RReLUUpdateOutputEvalIP_functor(float negSlope) : negSlope_(negSlope) {}
+
+ __device__ __forceinline__ void operator()(float* x)
+ {
+ if (*x <= 0)
+ {
+ *x = *x * negSlope_;
+ }
+ }
+};
+
+static int cunn_RReLU_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THCudaTensor *noise = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "noise", "torch.CudaTensor");
+ double lower = luaT_getfieldchecknumber(L, 1, "lower");
+ double upper = luaT_getfieldchecknumber(L, 1, "upper");
+ int train = luaT_getfieldcheckboolean(L, 1, "train");
+ int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+ THAssert(THCudaTensor_checkGPU(state, 3, input, output, noise));
+ if (state->rngState->current_gen == NULL)
+ {
+ THError("Random number generators have not been initialized.");
+ }
+
+ if (train)
+ {
+ input = THCudaTensor_newContiguous(state, input);
+ THCudaTensor_resizeAs(state, noise, input);
+ float *input_data = THCudaTensor_data(state, input);
+ float *noise_data = THCudaTensor_data(state, noise);
+ long n = THCudaTensor_nElement(state, input);
+ if (inplace)
+ {
+ rreluUpdateOutputTrain<<<NUM_BLOCKS(n), BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ n, state->rngState->current_gen->gen_states,
+ input_data, noise_data, input_data, lower, upper);
+ THCudaTensor_set(state, output, input);
+ }
+ else
+ {
+ THCudaTensor_resizeAs(state, output, input);
+ float *output_data = THCudaTensor_data(state, output);
+ rreluUpdateOutputTrain<<<NUM_BLOCKS(n), BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
+ n, state->rngState->current_gen->gen_states,
+ input_data, noise_data, output_data, lower, upper);
+ }
+ THCudaTensor_free(state, input);
+ }
+ else
+ {
+ const double negSlope = (lower + upper) / 2;
+ if (inplace)
+ {
+ THCudaTensor_pointwiseApply1(state, input, RReLUUpdateOutputEvalIP_functor(negSlope));
+ THCudaTensor_set(state, output, input);
+ }
+ else
+ {
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input, RReLUUpdateOutputEval_functor(negSlope));
+ }
+ }
+
+ return 1;
+}
+
+struct RReLUupdateGradInputEval_functor
+{
+ const float negSlope_;
+
+ RReLUupdateGradInputEval_functor(float negSlope) : negSlope_(negSlope) {}
+
+ __device__ __forceinline__ void operator()(float *gradIn, float *gradOut, float* in)
+ {
+ *gradIn = (*in) <= 0 ? (*gradOut) * negSlope_ : (*gradOut);
+ }
+};
+
+struct RReLUupdateGradInputEvalIP_functor
+{
+ const float negSlope_;
+
+ RReLUupdateGradInputEvalIP_functor(float negSlope) : negSlope_(negSlope) {}
+
+ __device__ __forceinline__ void operator()(float *gradOut, float *in)
+ {
+ if (*in <= 0)
+ {
+ *gradOut = (*gradOut) * negSlope_;
+ }
+ }
+};
+
+static int cunn_RReLU_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ THCudaTensor *noise = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "noise", "torch.CudaTensor");
+ double lower = luaT_getfieldchecknumber(L, 1, "lower");
+ double upper = luaT_getfieldchecknumber(L, 1, "upper");
+ int train = luaT_getfieldcheckboolean(L, 1, "train");
+ int inplace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+ THAssert(THCudaTensor_checkGPU(state, 4, input, gradOutput, gradInput, noise));
+
+ gradOutput = THCudaTensor_newContiguous(state, gradOutput);
+
+  if (train && upper - lower > 1E-6) // otherwise (e.g. upper == lower) RReLU degenerates to LeakyReLU
+ {
+ // multiply the gradient by the noise tensor
+ if (inplace)
+ {
+ THCudaTensor_cmul(state, gradOutput, gradOutput, noise);
+ THCudaTensor_set(state, gradInput, gradOutput);
+ }
+ else
+ {
+ THCudaTensor_resizeAs(state, gradInput, input);
+ THCudaTensor_cmul(state, gradInput, gradOutput, noise);
+ }
+ }
+ else
+ {
+ // use constant factor for negative input values
+ const double negSlope = (lower + upper) / 2;
+ if (inplace)
+ {
+ THCudaTensor_pointwiseApply2(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor(negSlope));
+ THCudaTensor_set(state, gradInput, gradOutput);
+ }
+ else
+ {
+ THCudaTensor_resizeAs(state, gradInput, input);
+ THCudaTensor_pointwiseApply3(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor(negSlope));
+ }
+ }
+
+ THCudaTensor_free(state, gradOutput);
+ return 1;
+}
+
+static const struct luaL_Reg cunn_RReLU__ [] = {
+ {"RReLU_updateOutput", cunn_RReLU_updateOutput},
+ {"RReLU_updateGradInput", cunn_RReLU_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_RReLU_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_RReLU__, "nn");
+ lua_pop(L,1);
+}
diff --git a/Sigmoid.cu b/Sigmoid.cu
new file mode 100644
index 0000000..0dac7a6
--- /dev/null
+++ b/Sigmoid.cu
@@ -0,0 +1,56 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct sigmoidupdateOutput_functor
+{
+ __device__ void operator()(float* output, const float* input) const
+ {
+ *output = 1./(1.+ exp(-*input));
+ }
+};
+
+static int cunn_Sigmoid_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor());
+ return 1;
+}
+
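+// The derivative is expressed through the saved output:
+// d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).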
+struct sigmoidupdateGradInput_functor
+{
+ __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+ {
+ *gradInput = *gradOutput * (1.-*output) * *output;
+ }
+};
+
+static int cunn_Sigmoid_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+ THCudaTensor_resizeAs(state, gradInput, output);
+ THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor());
+ return 1;
+}
+
+static const struct luaL_Reg cunn_Sigmoid__ [] = {
+ {"Sigmoid_updateOutput", cunn_Sigmoid_updateOutput},
+ {"Sigmoid_updateGradInput", cunn_Sigmoid_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_Sigmoid_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_Sigmoid__, "nn");
+ lua_pop(L,1);
+}
diff --git a/SmoothL1Criterion.cu b/SmoothL1Criterion.cu
new file mode 100644
index 0000000..a0162f6
--- /dev/null
+++ b/SmoothL1Criterion.cu
@@ -0,0 +1,135 @@
+#include "utils.h"
+
+#include <thrust/fill.h>
+#include <thrust/functional.h>
+#include <thrust/device_ptr.h>
+#include <thrust/reduce.h>
+#include <thrust/inner_product.h>
+#include <thrust/transform.h>
+#if CUDA_VERSION >= 7000
+#include <thrust/system/cuda/execution_policy.h>
+#endif
+
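+// Per-element smooth L1 loss, with z = |x - y|:
+//   0.5 * z^2  if z < 1
+//   z - 0.5    otherwise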
+struct smoothl1_functor
+{
+ smoothl1_functor() {}
+
+ __host__ __device__ float operator()(const float& x, const float& y) const
+ {
+ float z = fabsf(x-y);
+ return z < 1.f ? 0.5f*z*z : z - 0.5f;
+ }
+};
+
+
+static int cunn_SmoothL1Criterion_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *target = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, target));
+
+ int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+ luaL_argcheck(L, THCudaTensor_nElement(state, input) == THCudaTensor_nElement(state, target), 2,
+ "input and target need to have the same number of elements");
+
+ float sum;
+
+ long size = THCudaTensor_nElement(state, input);
+
+ input = THCudaTensor_newContiguous(state, input);
+ target = THCudaTensor_newContiguous(state, target);
+
+ thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
+ thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
+ sum = thrust::inner_product(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ input_data, input_data+size, target_data, (float) 0,
+ thrust::plus<float>(), smoothl1_functor());
+
+ if(sizeAverage)
+ sum /= size;
+
+ THCudaTensor_free(state, input);
+ THCudaTensor_free(state, target);
+
+ lua_pushnumber(L, sum);
+ lua_setfield(L, 1, "output");
+
+ lua_pushnumber(L, sum);
+ return 1;
+}
+
+
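+// Gradient of the smooth L1 loss w.r.t. the input: norm * (x - y) inside
+// |x - y| <= 1, clamped to +/- norm outside; norm is 1/n under sizeAverage.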
+struct smoothl1_updateGradInput_functor
+{
+ const float norm;
+
+ smoothl1_updateGradInput_functor(float norm_) : norm(norm_) {}
+
+ __host__ __device__ float operator()(const float& x, const float& y) const
+ {
+ float z = x - y;
+ if(z < -1.f)
+ return -norm;
+ else if(z > 1.f)
+ return norm;
+ else
+ return norm * z;
+ }
+};
+
+static int cunn_SmoothL1Criterion_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *target = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ luaL_argcheck(L, THCudaTensor_nElement(state, input) == THCudaTensor_nElement(state, target), 2,
+ "input and target need to have the same number of elements");
+ THAssert(THCudaTensor_checkGPU(state, 3, input, target, gradInput));
+
+ long size = THCudaTensor_nElement(state, input);
+ float norm = (sizeAverage ? 1./size : 1.);
+
+ input = THCudaTensor_newContiguous(state, input);
+ target = THCudaTensor_newContiguous(state, target);
+
+ THCudaTensor_resizeAs(state, gradInput, input);
+
+ thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
+ thrust::device_ptr<float> target_data(THCudaTensor_data(state, target));
+ thrust::device_ptr<float> gradInput_data(THCudaTensor_data(state, gradInput));
+
+ thrust::transform(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ input_data, input_data+size, target_data, gradInput_data,
+ smoothl1_updateGradInput_functor(norm));
+
+ THCudaTensor_free(state, input);
+ THCudaTensor_free(state, target);
+ return 1;
+}
+
+static const struct luaL_Reg cunn_SmoothL1Criterion__ [] = {
+ {"SmoothL1Criterion_updateOutput", cunn_SmoothL1Criterion_updateOutput},
+ {"SmoothL1Criterion_updateGradInput", cunn_SmoothL1Criterion_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_SmoothL1Criterion_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_SmoothL1Criterion__, "nn");
+ lua_pop(L,1);
+}
diff --git a/SoftMax.cu b/SoftMax.cu
new file mode 100644
index 0000000..364b6fb
--- /dev/null
+++ b/SoftMax.cu
@@ -0,0 +1,236 @@
+#include "utils.h"
+
+#define MINUS_LOG_THRESHOLD -18.42
+#define SOFTMAX_THREADS 128
+
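+// One thread block normalizes one softmax slice of `dim` elements, addressed
+// by (frame blockIdx.x, spatial offset blockIdx.y). Three passes, each ending
+// in a block-wide reduction into buffer[SOFTMAX_THREADS]: find the max (for
+// numerical stability), accumulate the sum of exp(x - max), then normalize.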
+__global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input,
+ int nframe, int dim, int stride)
+{
+ __shared__ float buffer[SOFTMAX_THREADS+1];
+ float *input_k = input + blockIdx.x*dim*stride + blockIdx.y;
+ float *output_k = output + blockIdx.x*dim*stride + blockIdx.y;
+
+ int i_start = threadIdx.x;
+ int i_end = dim;
+ int i_step = blockDim.x;
+
+ // max?
+ buffer[threadIdx.x] = -FLT_MAX;
+ for (int i=i_start; i<i_end; i+=i_step)
+ {
+ float z = input_k[i*stride];
+ if(buffer[threadIdx.x] < z)
+ buffer[threadIdx.x] = z;
+ }
+
+ __syncthreads();
+
+ // reduce
+ if (threadIdx.x == 0)
+ {
+ float max_k = -FLT_MAX;
+ for (int i=0; i<blockDim.x; i++)
+ {
+ if(max_k < buffer[i])
+ max_k = buffer[i];
+ }
+ buffer[SOFTMAX_THREADS] = max_k;
+ }
+
+ __syncthreads();
+
+ // sum?
+ float max_k = buffer[SOFTMAX_THREADS];
+ buffer[threadIdx.x] = 0;
+ for (int i=i_start; i<i_end; i+=i_step) {
+ float z = __expf(input_k[i*stride]-max_k);
+ buffer[threadIdx.x] += z;
+ output_k[i*stride] = z;
+ }
+
+ __syncthreads();
+
+ // reduce
+ if (threadIdx.x == 0)
+ {
+ float sum_k = 0;
+ for (int i=0; i<blockDim.x; i++)
+ sum_k += buffer[i];
+ buffer[SOFTMAX_THREADS] = sum_k;
+ }
+
+ __syncthreads();
+
+ // softmax
+ float sum_k = buffer[SOFTMAX_THREADS];
+ for (int i=i_start; i<i_end; i+=i_step)
+ output_k[i*stride] = output_k[i*stride] / sum_k;
+}
+
+
+__global__ void cunn_SoftMax_updateGradInput_kernel(float *gradInput, float *output, float *gradOutput,
+ int nframe, int dim, int stride)
+{
+ __shared__ float buffer[SOFTMAX_THREADS];
+ float *gradInput_k = gradInput + blockIdx.x*dim*stride + blockIdx.y;
+ float *output_k = output + blockIdx.x*dim*stride + blockIdx.y;
+ float *gradOutput_k = gradOutput + blockIdx.x*dim*stride + blockIdx.y;
+
+ int i_start = threadIdx.x;
+ int i_end = dim;
+ int i_step = blockDim.x;
+
+ // sum?
+ buffer[threadIdx.x] = 0;
+ for (int i=i_start; i<i_end; i+=i_step)
+ buffer[threadIdx.x] += gradOutput_k[i*stride] * output_k[i*stride];
+
+ __syncthreads();
+
+ // reduce
+ if (threadIdx.x == 0)
+ {
+ float sum_k = 0;
+ for (int i=0; i<blockDim.x; i++)
+ sum_k += buffer[i];
+ buffer[0] = sum_k;
+ }
+
+ __syncthreads();
+
+ float sum_k = buffer[0];
+ for (int i=i_start; i<i_end; i+=i_step)
+ gradInput_k[i*stride] = output_k[i*stride] * (gradOutput_k[i*stride] - sum_k);
+}
+
+static int cunn_SoftMax_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+
+ input = THCudaTensor_newContiguous(state, input);
+ THCudaTensor_resizeAs(state, output, input);
+ long batchSize, dim, stride;
+
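+  // Map the input shape onto (batchSize, dim, stride): softmax runs over
+  // `dim` (the last dimension for 1D/2D input, the feature dimension for
+  // 3D/4D input), and `stride` counts spatial positions per feature plane.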
+ if(input->nDimension == 1)
+ {
+ batchSize = 1;
+ dim = input->size[0];
+ stride = 1;
+ }
+ else if(input->nDimension == 2)
+ {
+ batchSize = input->size[0];
+ dim = input->size[1];
+ stride = 1;
+ }
+ else if(input->nDimension == 3)
+ {
+ batchSize = 1;
+ dim = input->size[0];
+ stride = input->size[1]*input->size[2];
+ }
+ else if(input->nDimension == 4)
+ {
+ batchSize = input->size[0];
+ dim = input->size[1];
+ stride = input->size[2]*input->size[3];
+ }
+ else
+ THError("1D, 2D, 3D or 4D tensor expected");
+
+ dim3 blocks(batchSize, stride);
+ dim3 threads(SOFTMAX_THREADS);
+ cunn_SoftMax_updateOutput_kernel<<<blocks,threads,
+ 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, output),
+ THCudaTensor_data(state, input),
+ batchSize, dim, stride);
+
+ cudaError errcode = cudaGetLastError();
+ if(errcode != cudaSuccess)
+ THError(cudaGetErrorString(errcode));
+
+ THCudaTensor_free(state, input);
+ return 1;
+}
+
+static int cunn_SoftMax_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+
+ output = THCudaTensor_newContiguous(state, output);
+ gradOutput = THCudaTensor_newContiguous(state, gradOutput);
+
+ THCudaTensor_resizeAs(state, gradInput, output);
+ long batchSize, dim, stride;
+
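+  // Same shape-to-(batchSize, dim, stride) mapping as in updateOutput.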
+ if(gradInput->nDimension == 1)
+ {
+ batchSize = 1;
+ dim = gradInput->size[0];
+ stride = 1;
+ }
+ else if(gradInput->nDimension == 2)
+ {
+ batchSize = gradInput->size[0];
+ dim = gradInput->size[1];
+ stride = 1;
+ }
+ else if(gradInput->nDimension == 3)
+ {
+ batchSize = 1;
+ dim = gradInput->size[0];
+ stride = gradInput->size[1]*gradInput->size[2];
+ }
+ else if(gradInput->nDimension == 4)
+ {
+ batchSize = gradInput->size[0];
+ dim = gradInput->size[1];
+ stride = gradInput->size[2]*gradInput->size[3];
+ }
+ else
+ THError("1D, 2D, 3D or 4D tensor expected");
+
+ dim3 blocks(batchSize, stride);
+ dim3 threads(SOFTMAX_THREADS);
+ cunn_SoftMax_updateGradInput_kernel<<<blocks,threads,
+ 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, gradInput),
+ THCudaTensor_data(state, output),
+ THCudaTensor_data(state, gradOutput),
+ batchSize, dim, stride);
+
+ cudaError errcode = cudaGetLastError();
+ if(errcode != cudaSuccess)
+ THError(cudaGetErrorString(errcode));
+
+ THCudaTensor_free(state, gradOutput);
+ THCudaTensor_free(state, output);
+ return 1;
+}
+
+static const struct luaL_Reg cunn_SoftMax__ [] = {
+ {"SoftMax_updateOutput", cunn_SoftMax_updateOutput},
+ {"SoftMax_updateGradInput", cunn_SoftMax_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_SoftMax_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_SoftMax__, "nn");
+ lua_pop(L,1);
+}
diff --git a/SoftPlus.cu b/SoftPlus.cu
new file mode 100644
index 0000000..d522dfc
--- /dev/null
+++ b/SoftPlus.cu
@@ -0,0 +1,77 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
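+// softplus(x) = (1/beta) * log(1 + exp(beta * x)); once beta * x exceeds
+// `threshold`, fall back to the identity so exp() cannot overflow.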
+struct softPlusupdateOutput_functor
+{
+ const float threshold;
+ const float beta;
+
+ softPlusupdateOutput_functor(float threshold_, float beta_) : threshold(threshold_), beta(beta_) {}
+
+ __device__ void operator()(float* output, const float* input) const
+ {
+ float betain = beta * *input;
+ *output = ((betain) > threshold) ? *input : (1/beta) * log1p(exp(betain));
+ }
+};
+
+static int cunn_SoftPlus_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ float beta = luaT_getfieldchecknumber(L, 1, "beta");
+ float threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
+ return 1;
+}
+
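+// With y = softplus(x), exp(beta * y) = 1 + exp(beta * x), so the derivative
+// sigmoid(beta * x) can be recovered as (exp(beta * y) - 1) / exp(beta * y).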
+struct softPlusupdateGradInput_functor
+{
+ const float threshold;
+ const float beta;
+
+ softPlusupdateGradInput_functor(float threshold_, float beta_) : threshold(threshold_), beta(beta_) {}
+
+ __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+ {
+ float betaout = beta * *output;
+ float exp_bo = exp(betaout);
+ *gradInput = ((betaout) > threshold) ? *gradOutput : *gradOutput * (exp_bo - 1) / exp_bo;
+ }
+};
+
+static int cunn_SoftPlus_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 4, input, output, gradOutput, gradInput));
+ float beta = luaT_getfieldchecknumber(L, 1, "beta");
+ float threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+
+ THCudaTensor_resizeAs(state, gradInput, output);
+ THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
+ return 1;
+}
+
+static const struct luaL_Reg cunn_SoftPlus__ [] = {
+ {"SoftPlus_updateOutput", cunn_SoftPlus_updateOutput},
+ {"SoftPlus_updateGradInput", cunn_SoftPlus_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_SoftPlus_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_SoftPlus__, "nn");
+ lua_pop(L,1);
+}
diff --git a/SoftShrink.cu b/SoftShrink.cu
new file mode 100644
index 0000000..ddb35a8
--- /dev/null
+++ b/SoftShrink.cu
@@ -0,0 +1,74 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
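+// Soft shrinkage: shift inputs with |x| > lambda toward zero by lambda and
+// zero out everything inside [-lambda, lambda].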
+struct SoftShrinkUpdateOutput {
+ const float lambda_;
+
+ SoftShrinkUpdateOutput(float lambda): lambda_(lambda){}
+
+ __device__ __forceinline__ void operator()(float* out, float* in) {
+ float x = *in;
+ if (x > lambda_) *out = x - lambda_;
+ else if (x < -lambda_) *out = x + lambda_;
+ else *out = 0;
+ }
+};
+
+static int cunn_SoftShrink_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ double lambda = luaT_getfieldchecknumber(L, 1, "lambda");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput(lambda));
+ THCudaCheck(cudaGetLastError());
+ return 1;
+}
+
+struct SoftShrinkUpdateGradInput
+{
+ const float lambda_;
+
+ SoftShrinkUpdateGradInput(float lambda) : lambda_(lambda) {}
+
+ __device__ __forceinline__ void operator()(float* gradInput, float* input,
+ float* gradOutput) const {
+ float x = *input;
+ if (x > lambda_ || x < -lambda_)
+ *gradInput = *gradOutput;
+ else
+ *gradInput = 0;
+ }
+};
+
+
+static int cunn_SoftShrink_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ double lambda = luaT_getfieldchecknumber(L, 1, "lambda");
+ THAssert(THCudaTensor_checkGPU(state, 3, input, gradOutput, gradInput));
+ THCudaTensor_resizeAs(state, gradInput, input);
+ THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, SoftShrinkUpdateGradInput(lambda));
+ THCudaCheck(cudaGetLastError());
+ return 1;
+}
+
+static const struct luaL_Reg cunn_SoftShrink__ [] = {
+ {"SoftShrink_updateOutput", cunn_SoftShrink_updateOutput},
+ {"SoftShrink_updateGradInput", cunn_SoftShrink_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_SoftShrink_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_SoftShrink__, "nn");
+ lua_pop(L,1);
+}
diff --git a/Sqrt.cu b/Sqrt.cu
new file mode 100644
index 0000000..ff716c5
--- /dev/null
+++ b/Sqrt.cu
@@ -0,0 +1,63 @@
+#include "THCApply.cuh"
+#include "utils.h"
+
+struct sqrtupdateOutput_functor
+{
+ const float bias;
+
+ sqrtupdateOutput_functor(float bias_) : bias(bias_) {}
+
+ __device__ void operator()(float* output, const float* input) const
+ {
+ *output = sqrt(*input + bias);
+ }
+};
+
+static int cunn_Sqrt_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ float bias = (float) luaT_getfieldchecknumber(L,1,"eps");
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input, sqrtupdateOutput_functor(bias));
+ return 1;
+}
+
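+// d/dx sqrt(x + eps) = 1 / (2 * sqrt(x + eps)) = 1 / (2 * output); a zero
+// output maps to a zero gradient to avoid dividing by zero.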
+struct sqrtupdateGradInput_functor
+{
+ sqrtupdateGradInput_functor() {}
+
+ __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+ {
+ *gradInput = (*output == 0.0f) ? 0.0f : ((0.5f * *gradOutput) / *output);
+ }
+};
+
+static int cunn_Sqrt_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+ THCudaTensor_resizeAs(state, gradInput, output);
+ THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sqrtupdateGradInput_functor());
+ return 1;
+}
+
+static const struct luaL_Reg cunn_Sqrt__ [] = {
+ {"Sqrt_updateOutput", cunn_Sqrt_updateOutput},
+ {"Sqrt_updateGradInput", cunn_Sqrt_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_Sqrt_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_Sqrt__, "nn");
+ lua_pop(L,1);
+}
diff --git a/Square.cu b/Square.cu
new file mode 100644
index 0000000..565af8e
--- /dev/null
+++ b/Square.cu
@@ -0,0 +1,55 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct squareupdateOutput_functor
+{
+ __device__ void operator()(float* output, const float* input) const
+ {
+    *output = *input * *input;
+ }
+};
+
+static int cunn_Square_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input, squareupdateOutput_functor());
+ return 1;
+}
+
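+// d/dx x^2 = 2x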
+struct squareupdateGradInput_functor
+{
+ __device__ void operator()(float* gradInput, const float* input, const float* gradOutput) const
+ {
+ *gradInput = 2.0 * *gradOutput * *input;
+ }
+};
+
+static int cunn_Square_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 3, input, gradOutput, gradInput));
+ THCudaTensor_resizeAs(state, gradInput, input);
+ THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, squareupdateGradInput_functor());
+ return 1;
+}
+
+static const struct luaL_Reg cunn_Square__ [] = {
+ {"Square_updateOutput", cunn_Square_updateOutput},
+ {"Square_updateGradInput", cunn_Square_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_Square_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_Square__, "nn");
+ lua_pop(L,1);
+}
diff --git a/Tanh.cu b/Tanh.cu
new file mode 100644
index 0000000..cf526ad
--- /dev/null
+++ b/Tanh.cu
@@ -0,0 +1,55 @@
+#include "utils.h"
+#include "THCApply.cuh"
+
+struct tanhupdateOutput_functor
+{
+ __device__ void operator()(float* output, const float* input) const
+ {
+ *output = tanh(*input);
+ }
+};
+
+static int cunn_Tanh_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input, tanhupdateOutput_functor());
+ return 1;
+}
+
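+// d/dx tanh(x) = 1 - tanh(x)^2, written in terms of the saved output.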
+struct tanhupdateGradInput_functor
+{
+ __device__ void operator()(float* gradInput, const float* output, const float* gradOutput) const
+ {
+ *gradInput = *gradOutput * (1 - *output * *output);
+ }
+};
+
+static int cunn_Tanh_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));
+ THCudaTensor_resizeAs(state, gradInput, output);
+ THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor());
+ return 1;
+}
+
+static const struct luaL_Reg cunn_Tanh__ [] = {
+ {"Tanh_updateOutput", cunn_Tanh_updateOutput},
+ {"Tanh_updateGradInput", cunn_Tanh_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_Tanh_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_Tanh__, "nn");
+ lua_pop(L,1);
+}
diff --git a/Threshold.cu b/Threshold.cu
new file mode 100644
index 0000000..764dd9f
--- /dev/null
+++ b/Threshold.cu
@@ -0,0 +1,119 @@
+#include "THCApply.cuh"
+#include "utils.h"
+
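+// out = x where x > threshold, and the constant `val` elsewhere.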
+struct ThresholdUpdateOutput {
+ const float threshold_;
+ const float val_;
+
+ ThresholdUpdateOutput(float threshold, float val): threshold_(threshold),
+ val_(val) {}
+
+ __device__ __forceinline__ void operator()(float* out, float* in) {
+ float x = *in;
+ *out = (x > threshold_) ? x : val_;
+ }
+};
+
+// in-place variant
+struct ThresholdUpdateOutputIP {
+ const float threshold_;
+ const float val_;
+
+ ThresholdUpdateOutputIP(float threshold, float val): threshold_(threshold),
+ val_(val) {}
+
+ __device__ __forceinline__ void operator()(float* x) {
+ *x = (*x > threshold_) ? *x : val_;
+ }
+};
+
+static int cunn_Threshold_updateOutput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ double val = luaT_getfieldchecknumber(L, 1, "val");
+ double threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+ bool inPlace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+ THAssert(THCudaTensor_checkGPU(state, 2, input, output));
+
+ if (inPlace) {
+ THCudaTensor_pointwiseApply1(state, input,
+ ThresholdUpdateOutputIP(threshold, val));
+ THCudaTensor_set(state, output, input);
+ } else {
+ THCudaTensor_resizeAs(state, output, input);
+ THCudaTensor_pointwiseApply2(state, output, input,
+ ThresholdUpdateOutput(threshold, val));
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return 1;
+}
+
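+// The gradient passes through only where the input exceeded the threshold.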
+struct ThresholdUpdateGradInput
+{
+ const float threshold_;
+
+ ThresholdUpdateGradInput(float threshold) : threshold_(threshold) {}
+
+ __device__ __forceinline__ void operator()(float* gradInput, float* input,
+ float* gradOutput) const {
+ *gradInput = (*input > threshold_) ? *gradOutput : 0;
+ }
+};
+
+struct ThresholdUpdateGradInputIP
+{
+ const float threshold_;
+
+ ThresholdUpdateGradInputIP(float threshold) : threshold_(threshold) {}
+
+ __device__ __forceinline__ void operator()(float* gradOutput,
+ float* input) const {
+ *gradOutput = (*input > threshold_) ? *gradOutput : 0;
+ }
+};
+
+static int cunn_Threshold_updateGradInput(lua_State *L)
+{
+ THCState *state = getCutorchState(L);
+ THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
+ THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
+ THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
+ THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
+ double val = luaT_getfieldchecknumber(L, 1, "val");
+ double threshold = luaT_getfieldchecknumber(L, 1, "threshold");
+ bool inPlace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+ THAssert(THCudaTensor_checkGPU(state, 4, input, output, gradInput, gradOutput));
+
+ if (inPlace) {
+ THCudaTensor_pointwiseApply2(state, gradOutput, input,
+ ThresholdUpdateGradInputIP(threshold));
+ THCudaTensor_set(state, gradInput, gradOutput);
+ } else {
+ THCudaTensor_resizeAs(state, gradInput, output);
+ THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput,
+ ThresholdUpdateGradInput(threshold));
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return 1;
+}
+
+static const struct luaL_Reg cunn_Threshold__ [] = {
+ {"Threshold_updateOutput", cunn_Threshold_updateOutput},
+ {"Threshold_updateGradInput", cunn_Threshold_updateGradInput},
+ {NULL, NULL}
+};
+
+void cunn_Threshold_init(lua_State *L)
+{
+ luaT_pushmetatable(L, "torch.CudaTensor");
+ luaT_registeratname(L, cunn_Threshold__, "nn");
+ lua_pop(L,1);
+}