Add generic support for LogSigmoid.

This has the same logic as Sigmoid; i.e.,
math is done at double precision and then
stored back at the desired precision.
diff --git a/LogSigmoid.cu b/LogSigmoid.cu
index 2f56081..f008b63 100644
--- a/LogSigmoid.cu
+++ b/LogSigmoid.cu
@@ -1,35 +1,26 @@
 #include "THCUNN.h"
-#include "common.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
 
+template <typename T>
 struct logSigmoid_updateOutput_functor
 {
-  __device__ void operator()(float *output, const float *input) const
+  __device__ void operator()(T *output, const T *input) const
   {
-    float z = exp(-*input);
-    *output = -log(1. + z);
+    T z = exp(-*input);
+    *output = ScalarConvert<double, T>::to(-log(1. + z));
   }
 };
 
-void THNN_CudaLogSigmoid_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *buffer)
-{
-  THCUNN_assertSameGPU(state, 2, input, output);
-  THCudaTensor_resizeAs(state, output, input);
-  THC_pointwiseApply2(state, output, input, logSigmoid_updateOutput_functor());
-}
-
+template <typename T>
 struct logSigmoid_updateGradInput_functor
 {
-  __device__ void operator()(float *gradInput, const float *input, const float *gradOutput) const
+  __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const
   {
-    float z = exp(-*input);
-    *gradInput = *gradOutput * z / (1. + z);
+    T z = exp(-*input);
+    *gradInput = ScalarConvert<double, T>::to(*gradOutput * z / (1. + z));
   }
 };
 
-void THNN_CudaLogSigmoid_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput,
-  THCudaTensor *gradInput , THCudaTensor *buffer)
-{
-  THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
-  THCudaTensor_resizeAs(state, gradInput, input);
-  THC_pointwiseApply3(state, gradInput, input, gradOutput, logSigmoid_updateGradInput_functor());
-}
+#include "generic/LogSigmoid.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/THCHalfAutoNumerics.cuh b/THCHalfAutoNumerics.cuh
index 3c25eee..ad1f810 100644
--- a/THCHalfAutoNumerics.cuh
+++ b/THCHalfAutoNumerics.cuh
@@ -25,6 +25,10 @@
   return THCNumerics<half>::add(a, THCNumerics<half>::neg(ScalarConvert<int, half>::to(b)));
 }
 
+inline __host__ __device__ double operator-(half a, double b) {
+  return ScalarConvert<half, double>::to(a) - b;
+}
+
 inline __host__ __device__ half operator-(int a, half b) {
   return THCNumerics<half>::add(ScalarConvert<int, half>::to(a), THCNumerics<half>::neg(b));
 }
@@ -111,5 +115,13 @@
   return ScalarConvert<int, half>::to(a) / b;
 }
 
+inline __host__ __device__ double operator/(double a, half b) {
+  return a / ScalarConvert<half, double>::to(b);
+}
+
+inline __host__ __device__ double operator/(half a, double b) {
+  return ScalarConvert<half, double>::to(a) / b;
+}
+
 #endif
 #endif
diff --git a/THCUNN.h b/THCUNN.h
index ac77f61..6797dac 100644
--- a/THCUNN.h
+++ b/THCUNN.h
@@ -119,18 +119,6 @@
           double negval,
           bool inplace);
 
-TH_API void THNN_CudaLogSigmoid_updateOutput(
-          THCState *state,
-          THCudaTensor *input,
-          THCudaTensor *output,
-          THCudaTensor *buffer);
-TH_API void THNN_CudaLogSigmoid_updateGradInput(
-          THCState *state,
-          THCudaTensor *input,
-          THCudaTensor *gradOutput,
-          THCudaTensor *gradInput,
-          THCudaTensor *buffer);
-
 TH_API void THNN_CudaLogSoftMax_updateOutput(
           THCState *state,
           THCudaTensor *input,
diff --git a/generic/LogSigmoid.cu b/generic/LogSigmoid.cu
new file mode 100644
index 0000000..4a6a4c9
--- /dev/null
+++ b/generic/LogSigmoid.cu
@@ -0,0 +1,30 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/LogSigmoid.cu"
+#else
+
+#include "../common.h"
+
+void THNN_(LogSigmoid_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           THCTensor *buffer)
+{
+  THCUNN_assertSameGPU_generic(state, 2, input, output);
+  THCTensor_(resizeAs)(state, output, input);
+  THC_pointwiseApply2(state, output, input, logSigmoid_updateOutput_functor<real>());
+}
+
+void THNN_(LogSigmoid_updateGradInput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *gradOutput,
+           THCTensor *gradInput,
+           THCTensor *buffer)
+{
+  THCUNN_assertSameGPU_generic(state, 3, input, gradOutput, gradInput);
+  THCTensor_(resizeAs)(state, gradInput, input);
+  THC_pointwiseApply3(state, gradInput, input, gradOutput, logSigmoid_updateGradInput_functor<real>());
+}
+
+#endif
diff --git a/generic/THCUNN.h b/generic/THCUNN.h
index 15479c6..38b5cd5 100644
--- a/generic/THCUNN.h
+++ b/generic/THCUNN.h
@@ -30,6 +30,19 @@
                   real max_val,
                   bool inplace);
 
+TH_API void THNN_(LogSigmoid_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  THCTensor *buffer);
+
+TH_API void THNN_(LogSigmoid_updateGradInput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  THCTensor *buffer);
+
 TH_API void THNN_(Sigmoid_updateOutput)(
                   THCState *state,
                   THCTensor *input,