Spatial Depthwise Convolution on the GPU (#3057)

* THCUNN Skeleton for Depthwise Convolution port

* implement Depthwise Convolution CUDA Kernels (handles weight parameter only, not bias)

* working kernels and bindings for forward + backward for base conv, and integration

* add support for padding

* strides for weight kernel

* dilation for weight gradient, enable for others

* add support for depthwise multiplier

* remove old depthwise conv

* rename to SpatialDepthwiseConvolution

* clean up depthwise code, add shape asserts, more constrained thread count for accgradparams

* add bias for forward for depthwise conv

* add grad_bias, move bias for forward to CUDA

* fix eligibility test to guard against transposed, properly identify depth multiplier

* add basic unit test; make depthwise conv take priority over cudnn when appropriate

* add tests for depthwise permutations

* make cuda kernels calculate positions using mul instead of div

* remove unnecessary samegpu requirement

* use accreal, test for double type

* use THAssert instead of assert

* rename to is_depthwise

* half prec support for depthwise

* make certain computation more pythonic

* flake8
diff --git a/test/test_nn.py b/test/test_nn.py
index a08bc32..81867c4 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1942,6 +1942,49 @@
             self.assertEqual(m.weight.grad.data,
                              torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0))
 
+    # Very similar to test_Conv2d_naive_groups but with special care to handle
+    # the number of groups == number of input channels
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_Conv2d_depthwise_naive_groups(self):
+        types = [torch.cuda.FloatTensor, torch.cuda.DoubleTensor,
+                 torch.cuda.HalfTensor]
+        precs = [1e-5, 1e-5, 1e-2]
+        for tp, prec in zip(types, precs):
+            for depth_multiplier in [1, 2]:
+                m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).type(tp)
+                i = Variable(torch.randn(2, 2, 6, 6).type(tp), requires_grad=True)
+                output = m(i)
+                grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4).type(tp)
+                output.backward(grad_output)
+
+                offset = 1 * depth_multiplier
+
+                m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).type(tp)
+                m1.weight.data = m.weight.data[:offset].clone()
+                m1.bias.data = m.bias.data[:offset].clone()
+                i1 = Variable(i.data[:, :1].contiguous(), requires_grad=True)
+                output1 = m1(i1)
+                output1.backward(grad_output[:, :offset].contiguous())
+
+                m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).type(tp)
+                m2.weight.data.copy_(m.weight.data[offset:])
+                m2.bias.data.copy_(m.bias.data[offset:])
+                i2 = Variable(i.data[:, 1:].contiguous(), requires_grad=True)
+                output2 = m2(i2)
+                output2.backward(grad_output[:, offset:].contiguous())
+
+                self.assertEqual(output, torch.cat([output1, output2], 1),
+                                 prec=prec)
+                self.assertEqual(i.grad.data,
+                                 torch.cat([i1.grad.data, i2.grad.data], 1),
+                                 prec=prec)
+                self.assertEqual(m.bias.grad.data,
+                                 torch.cat([m1.bias.grad.data,
+                                            m2.bias.grad.data], 0), prec=prec)
+                self.assertEqual(m.weight.grad.data,
+                                 torch.cat([m1.weight.grad.data,
+                                            m2.weight.grad.data], 0), prec=prec)
+
     def test_MaxUnpool2d_output_size(self):
         m = nn.MaxPool2d(3, stride=2, return_indices=True)
         mu = nn.MaxUnpool2d(3, stride=2)
@@ -3830,6 +3873,31 @@
         check_gradgrad=False,
     ),
     dict(
+        fullname='Conv2d_depthwise',
+        constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
+        input_size=(2, 4, 6, 6),
+    ),
+    dict(
+        fullname='Conv2d_depthwise_with_multiplier',
+        constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
+        input_size=(2, 4, 6, 6),
+    ),
+    dict(
+        fullname='Conv2d_depthwise_strided',
+        constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
+        input_size=(2, 4, 6, 6),
+    ),
+    dict(
+        fullname='Conv2d_depthwise_padded',
+        constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
+        input_size=(2, 4, 6, 6),
+    ),
+    dict(
+        fullname='Conv2d_depthwise_dilated',
+        constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
+        input_size=(2, 4, 5, 5),
+    ),
+    dict(
         module_name='MaxPool2d',
         constructor_args=((3, 3), (2, 2), (1, 1)),
         input_size=(1, 3, 7, 7),
diff --git a/torch/csrc/autograd/functions/convolution.cpp b/torch/csrc/autograd/functions/convolution.cpp
index 4f4a7b5..d933443 100644
--- a/torch/csrc/autograd/functions/convolution.cpp
+++ b/torch/csrc/autograd/functions/convolution.cpp
@@ -54,6 +54,14 @@
   return is_dilated;
 }
 
+auto ConvParams::is_padded() const -> bool {
+  bool is_padded = false;
+  for (int p : padding) {
+    is_padded |= (p != 0);
+  }
+  return is_padded;
+}
+
 auto ConvParams::is_output_padding_neg() const -> bool {
   bool is_non_neg = false;
   for (int p : output_padding) {
@@ -119,6 +127,19 @@
   return false;
 }
 
+// We currently only have depthwise support for the case where groups ==
+// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
+// a depthwise multiplier)
+auto ConvParams::is_depthwise(
+        const at::Tensor& input, const at::Tensor& weight, int groups) const -> bool {
+  return input.type().isCuda() &&
+         !transposed &&
+         input.ndimension() == 4 &&
+         input.size(1) == groups &&
+         groups > 1 && // no point if there is only a single group
+         weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels
+}
+
 std::string ConvForward::name() { return "ConvForward"; }
 
 auto ConvForward::output_size(at::Tensor& input, at::Tensor& weight) const -> std::vector<int64_t> {
@@ -220,6 +241,14 @@
   return result;
 }
 
+static std::vector<int64_t> vecToInt64(const std::vector<int>& src) {
+  std::vector<int64_t> res(src.size());
+  for (size_t i = 0; i < src.size(); i++) {
+    res[i] = static_cast<int64_t>(src[i]);
+  }
+  return res;
+}
+
 static at::Tensor cat(const tensor_list& tensors, int dim) {
   int num_inputs = tensors.size();
   if (num_inputs == 0) {
@@ -231,7 +260,6 @@
   return output;
 }
 
-
 // ConvForward implementation
 
 auto ConvForward::apply(const variable_list& inputs) -> variable_list {
@@ -260,7 +288,16 @@
   tensor_list ones(groups);
   std::unique_ptr<Convolution> convolution;
 
-  if (use_cudnn(input)) {
+  if (is_depthwise(input, weight, groups)) {
+      /* output.resize_(output_size(input, weight)); */
+
+      auto kernel_size = weight.sizes().slice(2);
+      auto stride = vecToInt64(this->stride);
+      auto padding = vecToInt64(this->padding);
+      auto dilation = vecToInt64(this->dilation);
+
+      output = at::conv_depthwise2d_forward(input, weight, kernel_size, bias, stride, padding, dilation);
+  } else if (use_cudnn(input)) {
 #ifdef WITH_CUDNN
     if (input.type().ID() != weight.type().ID()){
       std::stringstream ss;
@@ -325,6 +362,14 @@
   });
 };
 
+// For Convolution strategies that don't implicitly handle grad_bias, we add a helper
+// function here to perform it using simple Tensor operators
+static at::Tensor compute_grad_bias(const at::Tensor& grad_output) {
+  // grad_output is in N, C, H, W, we re-shape and make contiguous
+  at::Tensor transposed = grad_output.transpose(0, 1).contiguous();
+  // sum across all of the channels and add to grad_bias
+  return transposed.view({transposed.size(0), -1}).sum(1);
+}
 
 // ConvBackward implementation
 
@@ -354,6 +399,8 @@
     grad_output = view4d(grad_output);
   }
 
+
+  bool use_depthwise = this->is_depthwise(input, weight, groups);
   bool use_cudnn = this->use_cudnn(input);
 
   at::Tensor grad_input;
@@ -366,7 +413,23 @@
     should_compute_output(2) && bias.defined(),
   };
 
-  if (use_cudnn) {
+  if (use_depthwise) {
+    if (output_mask[0] || output_mask[1]) {
+      auto kernel_size = weight.sizes().slice(2);
+      auto stride = vecToInt64(this->stride);
+      auto padding = vecToInt64(this->padding);
+      auto dilation = vecToInt64(this->dilation);
+
+      std::tie(grad_input, grad_weight) = at::conv_depthwise2d_backward(
+          grad_output, input, weight, kernel_size, stride, padding, dilation,
+          {output_mask[0], output_mask[1]});
+      }
+
+      // THCUNN implementation does not handle bias, so we do it ourselves
+      if (output_mask[2]) {
+        grad_bias = compute_grad_bias(grad_output);
+      }
+    } else if (use_cudnn) {
 #ifdef WITH_CUDNN
     if (output_mask[0]) {
       grad_input = input.type().tensor();
@@ -648,14 +711,6 @@
   grad_output_.data.reset();
 }
 
-static std::vector<int64_t> vecToInt64(const std::vector<int>& src) {
-  std::vector<int64_t> res(src.size());
-  for (size_t i = 0; i < src.size(); i++) {
-    res[i] = static_cast<int64_t>(src[i]);
-  }
-  return res;
-}
-
 // Forward and backward functions for Tensor
 
 static at::Tensor compute_output(
@@ -791,10 +846,7 @@
           }
 
           if (output_mask[2]) {
-            // grad_output is in N, C, H, W, we re-shape and make contiguous
-            at::Tensor transposed = grad_output.transpose(0, 1).contiguous();
-            // sum across all of the channels and add to grad_bias
-            grad_bias = transposed.view({transposed.size(0), -1}).sum(1);
+            grad_bias = compute_grad_bias(grad_output);
           }
 
           return std::make_tuple(grad_input, grad_weight, grad_bias);
diff --git a/torch/csrc/autograd/functions/convolution.h b/torch/csrc/autograd/functions/convolution.h
index be5f6ae..f8ba1bc 100644
--- a/torch/csrc/autograd/functions/convolution.h
+++ b/torch/csrc/autograd/functions/convolution.h
@@ -34,12 +34,14 @@
 
   bool is_strided() const;
   bool is_dilated() const;
+  bool is_padded() const;
   bool is_output_padding_neg() const;
   bool is_output_padding_big() const;
   bool is_padding_neg() const;
   void view1d_as_2d();
   bool use_cudnn(const at::Tensor& input) const;
   bool use_nnpack(const at::Tensor& input) const;
+  bool is_depthwise(const at::Tensor& input, const at::Tensor& weight, int groups) const;
 };
 
 struct ConvForward : public ForwardFunction<>, public ConvParams, public HasSymbolic {
diff --git a/torch/lib/ATen/gen.py b/torch/lib/ATen/gen.py
index 48de0b2..3093060 100644
--- a/torch/lib/ATen/gen.py
+++ b/torch/lib/ATen/gen.py
@@ -222,6 +222,7 @@
 declarations = [d
                 for file in cwrap_files
                 for d in cwrap_parser.parse(file)]
+print(nn_files)
 declarations += nn_parse.run(nn_files)
 declarations = preprocess_declarations.run(declarations)
 for fname, env in generators.items():
diff --git a/torch/lib/ATen/nn.yaml b/torch/lib/ATen/nn.yaml
index 6d23585..3ef6b76 100644
--- a/torch/lib/ATen/nn.yaml
+++ b/torch/lib/ATen/nn.yaml
@@ -138,6 +138,10 @@
   cname: SpatialConvolutionMM
   buffers: [finput, fgrad_input]
 
+- name: conv_depthwise2d(Tensor input, Tensor weight, IntList[2] kernel_size, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] dilation=1)
+  cname: SpatialDepthwiseConvolution
+  buffers: []
+
 - name: conv3d(Tensor input, Tensor weight, IntList[3] kernel_size, Tensor bias={}, IntList[3] stride=1, IntList[3] padding=0)
   cname: VolumetricConvolutionMM
   buffers: [finput, fgrad_input]
diff --git a/torch/lib/ATen/nn_parse.py b/torch/lib/ATen/nn_parse.py
index 7267531..95cee94 100644
--- a/torch/lib/ATen/nn_parse.py
+++ b/torch/lib/ATen/nn_parse.py
@@ -54,6 +54,15 @@
     cname = thnn_function.name
     output_args = []
 
+    # function_wrapper expects everything in a declaration to be in
+    # the base type (i.e. THTensor*), but if we pull a THCUNN only
+    # implementation, it will have THCTensor* as the arg type. So we
+    # strip the THC here before returning
+    def map_to_th_type(t):
+        if t.startswith('THC'):
+            t = t.replace('THC', 'TH')
+        return t
+
     def is_output_arg(arg_name, func_name):
         if arg_name == 'output' and 'updateOutput' in cname:
             return True
@@ -68,7 +77,7 @@
         name = arg.name
         if is_output_arg(name, cname):
             desc = {
-                'type': arg.type,
+                'type': map_to_th_type(arg.type),
                 'name': camel_to_snake(name),
                 'output': True,
             }
@@ -206,7 +215,7 @@
     return {
         'mode': 'NN',
         'name': name,
-        'types': ['Float', 'Double'],
+        'types': ['Float', 'Double', 'Half'],  # Half will be stripped for CPU backend
         'arguments': arguments,
         'return': get_return(arguments),
         'buffers': buffers,
diff --git a/torch/lib/THC/CMakeLists.txt b/torch/lib/THC/CMakeLists.txt
index a51586a..a46ec88 100644
--- a/torch/lib/THC/CMakeLists.txt
+++ b/torch/lib/THC/CMakeLists.txt
@@ -286,6 +286,7 @@
           THCReduce.cuh
           THCReduceAll.cuh
           THCReduceApplyUtils.cuh
+          THCTensorMathReduce.cuh
           THCAsmUtils.cuh
           THCAtomics.cuh
           THCScanUtils.cuh
diff --git a/torch/lib/THCUNN/SpatialDepthWiseConvolution.cu b/torch/lib/THCUNN/SpatialDepthWiseConvolution.cu
deleted file mode 100644
index 53c92ac..0000000
--- a/torch/lib/THCUNN/SpatialDepthWiseConvolution.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-#include "THCUNN.h"
-#include "common.h"
-#include "im2col.h"
-
-#include "THCHalf.h"
-#include "THCHalfAutoNumerics.cuh"
-
-#include "generic/SpatialDepthWiseConvolution.cu"
-#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/SpatialDepthwiseConvolution.cu b/torch/lib/THCUNN/SpatialDepthwiseConvolution.cu
new file mode 100644
index 0000000..3be6f8c
--- /dev/null
+++ b/torch/lib/THCUNN/SpatialDepthwiseConvolution.cu
@@ -0,0 +1,205 @@
+// updateOutput, updateGradInput Kernels ported from Sergey Zagoruyko's pyinn, which itself was a
+// port from Caffe
+
+#include "THCUNN.h"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCNumerics.cuh"
+#include "THCReduceApplyUtils.cuh"
+#include "THCSortUtils.cuh"
+#include "THCTensorMathReduce.cuh"
+#include "SharedMem.cuh"
+#include "common.h"
+
+template <typename T, typename AccT, typename IndexType>
+__global__ void spatialDepthwiseConvolutionUpdateOutput(
+    const THCDeviceTensor<T, 4> input,
+    THCDeviceTensor<T, 4> output,
+    const THCDeviceTensor<T, 4> weight,
+    const THCDeviceTensor<T, 1> bias,
+    bool biasEnabled,
+    IndexType totalElements,
+    const int outputChannels,
+    const int depthwiseMultiplier,
+    const int inputWidth, const int inputHeight,
+    const int outputWidth, const int outputHeight,
+    const int kernelWidth, const int kernelHeight,
+    const int strideWidth, const int strideHeight,
+    const int padWidth, const int padHeight,
+    const int dilationWidth, const int dilationHeight)
+{
+  const int channelStride = outputHeight * outputWidth;
+  const int batchStride = outputChannels * channelStride;
+
+  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
+       linearIndex < totalElements;
+       linearIndex += gridDim.x * blockDim.x) {
+
+    const int n = linearIndex / batchStride;
+    const int c = (linearIndex / channelStride) % outputChannels;
+    const int h = (linearIndex / outputWidth) % outputHeight;
+    const int w = linearIndex % outputWidth;
+
+    const int inputChannel = c / depthwiseMultiplier;
+    const int inputChannels = outputChannels / depthwiseMultiplier;
+
+    int weightOffset = c * kernelHeight * kernelWidth;
+
+    AccT value = biasEnabled ? ScalarConvert<T, AccT>::to(bias.data()[c]) : ScalarConvert<int, AccT>::to(0);
+    for (int kH = 0; kH < kernelHeight; ++kH) {
+      for (int kW = 0; kW < kernelWidth; ++kW) {
+        const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
+        const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;
+
+        if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) {
+          const IndexType offset = ((n * inputChannels + inputChannel) * inputHeight + h_in) *
+                                    inputWidth + w_in;
+          value = THCNumerics<AccT>::add(
+            value,
+            THCNumerics<AccT>::mul(
+              ScalarConvert<T, AccT>::to(weight.data()[weightOffset]),
+              ScalarConvert<T, AccT>::to(input.data()[offset])));
+        }
+        ++weightOffset;
+      }
+    }
+    output.data()[linearIndex] = ScalarConvert<AccT, T>::to(value);
+  }
+}
+
+template <typename T, typename AccT, typename IndexType>
+__global__ void spatialDepthwiseConvolutionUpdateGradInput(
+    const THCDeviceTensor<T, 4> gradOutput,
+    THCDeviceTensor<T, 4> gradInput,
+    const THCDeviceTensor<T, 4> weight,
+    IndexType totalElements,
+    const int inputChannels,
+    const int depthwiseMultiplier,
+    const int outputChannels,
+    const int inputWidth, const int inputHeight,
+    const int outputWidth, const int outputHeight,
+    const int kernelWidth, const int kernelHeight,
+    const int strideWidth, const int strideHeight,
+    const int padWidth, const int padHeight,
+    const int dilationWidth, const int dilationHeight)
+{
+  const int channelStride = inputHeight * inputWidth;
+  const int batchStride = inputChannels * channelStride;
+
+  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
+       linearIndex < totalElements;
+       linearIndex += gridDim.x * blockDim.x) {
+
+    const int n = linearIndex / batchStride;
+    const int c = (linearIndex / channelStride) % inputChannels;
+    const int h = (linearIndex / inputWidth) % inputHeight;
+    const int w = linearIndex % inputWidth;
+
+    AccT value = ScalarConvert<int, AccT>::to(0);
+    for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
+      int och = (c * depthwiseMultiplier) + multiplier;
+      int weightOffset = och * kernelHeight * kernelWidth;
+      for (int kh = 0; kh < kernelHeight; ++kh) {
+        for (int kw = 0; kw < kernelWidth; ++kw) {
+          const int h_out_s = h + padHeight - kh * dilationHeight;
+          const int w_out_s = w + padWidth - kw * dilationWidth;
+
+          if (((h_out_s % strideHeight) == 0) && ((w_out_s % strideWidth) == 0)) {
+            const int h_out = h_out_s / strideHeight;
+            const int w_out = w_out_s / strideWidth;
+
+            if ((h_out >= 0) && (h_out < outputHeight)
+                  && (w_out >= 0) && (w_out < outputWidth)) {
+
+              const int offset = ((n * outputChannels + och) * outputHeight + h_out)
+                    * outputWidth + w_out;
+              value = THCNumerics<AccT>::add(
+                value,
+                THCNumerics<AccT>::mul(
+                  ScalarConvert<T, AccT>::to(weight.data()[weightOffset]),
+                  ScalarConvert<T, AccT>::to(gradOutput.data()[offset])));
+            }
+          }
+          ++weightOffset;
+        }
+      }
+    }
+    gradInput.data()[linearIndex] = ScalarConvert<AccT, T>::to(value);
+  }
+}
+
+template <typename T, typename AccT, typename IndexType>
+__global__ void spatialDepthwiseConvolutionAccGradParameters(
+    const THCDeviceTensor<T, 4> gradOutput,
+    const THCDeviceTensor<T, 4> input,
+    THCDeviceTensor<T, 4> gradWeight,
+    const int batchSize,
+    const int inputChannels,
+    const int kernelChannels,
+    const int depthwiseMultiplier,
+    IndexType blockElements,
+    const int inputWidth, const int inputHeight,
+    const int outputWidth, const int outputHeight,
+    const int kernelWidth, const int kernelHeight,
+    const int strideWidth, const int strideHeight,
+    const int padWidth, const int padHeight,
+    const int dilationWidth, const int dilationHeight)
+{
+  const int channelStride = kernelWidth * kernelHeight;
+
+  // Have to use a statically typed Shared Memory pointer
+  SharedMem<AccT> smem;
+
+  // Each Block is responsible for accumulating over a permutation of
+  // (channels x kH x kW), use blockIdx to determine which one
+  int bidx = blockIdx.x;
+  int kW = bidx % kernelWidth;
+  int kH = (bidx / kernelWidth) % kernelHeight;
+  int ch = (bidx / channelStride) % kernelChannels;
+
+  // Need to calculate which input channel is associated with this filter
+  // channel
+  int inputCh = ch / depthwiseMultiplier;
+
+  AccT grad = ScalarConvert<float, AccT>::to(0.0);
+
+  // Block-stride loop over the number of elements we need to reduce
+  for (IndexType idx = threadIdx.x; idx < blockElements; idx += blockDim.x) {
+    // Need to calculate the following: batch position, and offset into the gradOutput
+    // in height, and width. We can intuit the corresponding position in the input from
+    // the other parameters we have
+    int go_w_offset = idx % outputWidth;
+    int go_h_offset = (idx / outputWidth) % outputHeight;
+    int batch = (idx / outputWidth / outputHeight) % batchSize;
+
+    int i_w_offset = (go_w_offset * strideWidth) + (kW * dilationWidth) - padWidth;
+    int i_h_offset = (go_h_offset * strideHeight) + (kH * dilationHeight) - padHeight;
+
+    if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < inputWidth && i_h_offset < inputHeight) {
+      int inputOffset = ((batch * inputChannels + inputCh) * inputHeight + i_h_offset) * inputWidth + i_w_offset;
+      int outputOffset = ((batch * kernelChannels + ch) * outputHeight + go_h_offset) * outputWidth + go_w_offset;
+      grad = THCNumerics<AccT>::add(
+          grad,
+          THCNumerics<AccT>::mul(
+            ScalarConvert<T, AccT>::to(input.data()[inputOffset]),
+            ScalarConvert<T, AccT>::to(gradOutput.data()[outputOffset])));
+    }
+  }
+  __syncthreads();
+
+  // At this point each thread in the block has a local gradient, which we need to
+  // accumulate prior to writing the global value
+  AccT *buf = smem.getPointer();
+  AccT tval = reduceBlock<AccT, ReduceAdd<AccT, AccT>>(
+      buf, blockDim.x, grad, ReduceAdd<AccT, AccT>(), ScalarConvert<float, AccT>::to(0));
+
+  // After reduction, first thread in the block has the gradient, so its responsible
+  // for writing it to gradWeight
+  if (threadIdx.x == 0) {
+    int weightOffset = kW + (kernelWidth * kH) + (kernelWidth * kernelHeight * ch);
+    gradWeight.data()[weightOffset] = ScalarConvert<AccT, T>::to(tval);
+  }
+}
+
+#include "generic/SpatialDepthwiseConvolution.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu b/torch/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
deleted file mode 100644
index 1f8391d..0000000
--- a/torch/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
+++ /dev/null
@@ -1,652 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "generic/SpatialDepthWiseConvolution.cu"
-#else
-
-static inline void THNN_(SpatialDepthWiseConvolution_shapeCheck)(
-                         THCState *state,
-                         THCTensor *input, THCTensor *gradOutput,
-                         THCTensor *weight, THCTensor *bias,
-                         int kH, int kW, int dH, int dW, int padH, int padW) {
-  THArgCheck(kW > 0 && kH > 0, 9,
-             "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
-  THArgCheck(dW > 0 && dH > 0, 11,
-             "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
-  THCUNN_argCheck(state, weight->nDimension == 4, 5, weight,
-                  "2D or 4D weight tensor expected, but got: %s");
-
-  if (bias != NULL) {
-    THCUNN_check_dim_size(state, bias, 2, 0, weight->size[0]);
-    THCUNN_check_dim_size(state, bias, 2, 1, weight->size[1]);
-  }
-
-  int ndim = input->nDimension;
-  int dimf = 0;
-  int dimh = 1;
-  int dimw = 2;
-
-  if (ndim == 4) {
-    dimf++;
-    dimh++;
-    dimw++;
-  }
-
-  THCUNN_argCheck(state, ndim == 3 || ndim == 4, 2, input,
-                  "3D or 4D input tensor expected but got: %s");
-
-  int64_t nInputPlane  = weight->size[1];
-  int64_t inputHeight  = input->size[dimh];
-  int64_t inputWidth   = input->size[dimw];
-  int64_t nOutputPlane = weight->size[0];
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-
-  if (outputWidth < 1 || outputHeight < 1)
-      THError("Given input size: (%d x %d x %d). "
-              "Calculated output size: (%d x %d x %d). Output size is too small",
-              nInputPlane,inputHeight,inputWidth,nOutputPlane*nInputPlane,outputHeight,outputWidth);
-
-  THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane);
-
-  if (gradOutput != NULL) {
-    THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimf, nInputPlane);
-    THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimh, nOutputPlane);
-    THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimw, outputHeight);
-    THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimw + 1, outputWidth);
-  }
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateOutput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           THCTensor *weight,
-           THCTensor *bias,
-           THCTensor *columns,
-           THCTensor *ones,
-           int kW, int kH,
-           int dW, int dH,
-           int padW, int padH) {
-
-  THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones);
-  if (bias) {
-    THCUNN_assertSameGPU(state, 2, weight, bias);
-  }
-
-  // Params:
-  int nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
-  int nOutputPlane = weight->size[0];
-  if (weight->nDimension == 2) {
-    THCTensor_(resize4d)(state, weight, nOutputPlane, nInputPlane, kH, kW);
-  }
-
-  THNN_(SpatialDepthWiseConvolution_shapeCheck)
-       (state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW);
-
-
-  // Transpose weight & bias
-  THCTensor *_weight = THCTensor_(newTranspose)(state, weight, 0, 1);
-  weight = THCTensor_(newContiguous)(state, _weight);
-
-  THCTensor *_bias = NULL;
-  if(bias) {
-    _bias = THCTensor_(newTranspose)(state, bias, 0, 1);
-    bias = THCTensor_(newContiguous)(state, _bias);
-  }
-
-  // resize weight
-  int64_t s1 = weight->size[0];
-  int64_t s2 = weight->size[1];
-  int64_t s3 = weight->size[2] * weight->size[3];
-  weight = THCTensor_(newWithStorage3d)(state, weight->storage, weight->storageOffset,
-          s1, -1, s2, -1, s3, -1);
-
-  input = THCTensor_(newContiguous)(state, input);
-
-  int batch = 1;
-  if (input->nDimension == 3) {
-    // Force batch
-    batch = 0;
-    THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
-  }
-
-  int64_t inputWidth   = input->size[3];
-  int64_t inputHeight  = input->size[2];
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-
-  // Batch size + input planes
-  int64_t batchSize = input->size[0];
-
-  // Resize output
-  THCTensor_(resize5d)(state, output, batchSize, nInputPlane, nOutputPlane, outputHeight, outputWidth);
-
-  // Resize temporary columns
-  THCTensor_(resize2d)(state, columns, kW*kH, outputHeight*outputWidth);
-
-  // Define a buffer of ones, for bias accumulation
-  // Note: this buffer can be shared with other modules, it only ever gets increased,
-  // and always contains ones.
-  if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
-    // Resize plane and fill with ones...
-    THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
-    THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
-  }
-
-  // Helpers
-  THCTensor *input_n = THCTensor_(new)(state);
-  THCTensor *output_n = THCTensor_(new)(state);
-
-
-  // Helpers for DepthWiseConvolution
-  THCTensor *input_i = THCTensor_(new)(state);
-  THCTensor *output_i = THCTensor_(new)(state);
-  THCTensor *weight_i = THCTensor_(new)(state);
-
-  THCTensor *bias_i = NULL;
-  if(bias) {
-    bias_i = THCTensor_(new)(state);
-  }
-  // For each elt in batch, do:
-  for (int elt = 0; elt < batchSize; elt ++) {
-    // Matrix mulitply per output:
-    THCTensor_(select)(state, input_n, input, 0, elt);
-    THCTensor_(select)(state, output_n, output, 0, elt);
-
-
-    for (int ipelt = 0; ipelt < nInputPlane; ipelt++)
-    {
-      // Fetch ipelt-th input plane
-      THCTensor_(narrow)(state, input_i, input_n, 0, ipelt, 1);
-      THCTensor_(select)(state, output_i, output_n, 0, ipelt);
-      THCTensor_(select)(state, weight_i, weight, 0, ipelt);
-      if (bias) {
-        THCTensor_(select)(state, bias_i, bias, 0, ipelt);
-      }
-      // Do Bias first:
-      // M,N,K are dims of matrix A and B
-      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
-      int64_t m_ = nOutputPlane;
-      int64_t n_ = outputHeight * outputWidth;
-      int64_t k_ = 1;
-
-      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
-      if (bias) {
-        #ifdef THC_REAL_IS_FLOAT
-        THCudaBlas_Sgemm(
-        #elif defined(THC_REAL_IS_HALF)
-        THCudaBlas_Hgemm(
-        #elif defined(THC_REAL_IS_DOUBLE)
-        THCudaBlas_Dgemm(
-        #endif
-            state,
-            't', 'n',
-            n_, m_, k_,
-            ScalarConvert<int, real>::to(1),
-            THCTensor_(data)(state, ones), k_,
-            THCTensor_(data)(state, bias_i), k_,
-            ScalarConvert<int, real>::to(0),
-            THCTensor_(data)(state, output_i), n_
-        );
-      } else {
-        THCTensor_(zero)(state, output_i);
-      }
-
-      // Extract columns:
-      im2col(
-        THCState_getCurrentStream(state),
-        THCTensor_(data)(state, input_i),
-        1, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
-        1, 1, THCTensor_(data)(state, columns)
-      );
-
-      // M,N,K are dims of matrix A and B
-      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
-      int64_t m = nOutputPlane;
-      int64_t n = columns->size[1];
-      int64_t k = 1*kH*kW;
-
-      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
-      #ifdef THC_REAL_IS_FLOAT
-      THCudaBlas_Sgemm(
-      #elif defined(THC_REAL_IS_HALF)
-      THCudaBlas_Hgemm(
-      #elif defined(THC_REAL_IS_DOUBLE)
-      THCudaBlas_Dgemm(
-      #endif
-          state,
-          'n', 'n',
-          n, m, k,
-          ScalarConvert<int, real>::to(1),
-          THCTensor_(data)(state, columns), n,
-          THCTensor_(data)(state, weight_i), k,
-          ScalarConvert<int, real>::to(1),
-          THCTensor_(data)(state, output_i), n
-      );
-    }
-  }
-
-  // Free
-  THCTensor_(free)(state, input_n);
-  THCTensor_(free)(state, output_n);
-
-  THCTensor_(free)(state, input_i);
-  THCTensor_(free)(state, output_i);
-
-  THCTensor_(free)(state, weight_i);
-
-  THCTensor_(free)(state, weight);
-  THCTensor_(free)(state, _weight);
-
-  THCTensor_(free)(state, bias_i);
-  THCTensor_(free)(state, bias);
-  THCTensor_(free)(state, _bias);
-  // Transpose output
-  THCTensor_(resize4d)(state, output, batchSize, nInputPlane * nOutputPlane, outputHeight, outputWidth);
-
-  // Make a contiguous copy of output (OPTIONAL)
-  // THCTensor *_output = THCTensor_(newContiguous)(state, output);
-
-  // Resize output
-  if (batch == 0) {
-    THCTensor_(select)(state, output, NULL, 0, 0);
-    THCTensor_(select)(state, input, NULL, 0, 0);
-  }
-  //else
-    //THCTensor_(resize5d)(state, output, batchSize, nOutputPlane, nInputPlane, outputHeight, outputWidth);
-
-  // Copy output back
-  // THCTensor_(freeCopyTo)(state, _output, output);
-
-  THCTensor_(free)(state, input);
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           THCTensor *weight,
-           THCTensor *gradColumns,
-           THCTensor *ones,
-           int kW, int kH,
-           int dW, int dH,
-           int padW, int padH) {
-
-  THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
-                       gradColumns, gradInput);
-
-  // Params:
-  int nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
-  int nOutputPlane = weight->size[0];
-  if (weight->nDimension == 2) {
-    THCTensor_(resize4d)(state, weight, nOutputPlane, nInputPlane, kH, kW);
-  }
-
-  gradOutput = THCTensor_(newWithTensor)(state, gradOutput);
-
-  if (input->nDimension == 3) {
-    if (gradOutput->nDimension == 3) {
-      THCTensor_(resize4d)(state, gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
-    }
-  }
-  else
-  {
-    if (gradOutput->nDimension == 4) {
-      THCTensor_(resize5d)(state, gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
-    }
-  }
-
-  THNN_(SpatialDepthWiseConvolution_shapeCheck)
-       (state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW);
-
-  // Transpose weight
-  THCTensor *_weight = THCTensor_(newTranspose)(state, weight, 0, 1);
-  weight = THCTensor_(newContiguous)(state, _weight);
-
-  // resize weight
-  int64_t s1 = weight->size[0];
-  int64_t s2 = weight->size[1];
-  int64_t s3 = weight->size[2] * weight->size[3];
-  weight = THCTensor_(newWithStorage3d)(state, weight->storage, weight->storageOffset,
-          s1, -1, s2, -1, s3, -1);
-
-
-
-  input = THCTensor_(newContiguous)(state, input);
-
-
-  int batch = 1;
-  if (input->nDimension == 3) {
-    // Force batch
-    batch = 0;
-    THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
-    THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
-  }
-
-  int64_t inputWidth   = input->size[3];
-  int64_t inputHeight  = input->size[2];
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-
-  // Batch size + input planes
-  int64_t batchSize = input->size[0];
-
-  // Resize output
-  THCTensor_(resize4d)(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth);
-
-  // Resize temporary columns
-  THCTensor_(resize2d)(state, gradColumns, 1*kW*kH, outputHeight*outputWidth);
-
-  // Helpers
-  THCTensor *gradInput_n = THCTensor_(new)(state);
-  THCTensor *gradOutput_n = THCTensor_(new)(state);
-
-  // Helpers for DepthWiseConvolution
-  THCTensor *gradOutput_i = THCTensor_(new)(state);
-  THCTensor *gradInput_i = THCTensor_(new)(state);
-  THCTensor *weight_i = THCTensor_(new)(state);
-
-  // For each elt in batch, do:
-  for (int elt = 0; elt < batchSize; elt ++) {
-    // Matrix mulitply per sample:
-    THCTensor_(select)(state, gradInput_n, gradInput, 0, elt);
-    THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
-    for (int ipelt = 0; ipelt < nInputPlane; ipelt++)
-      {
-      // M,N,K are dims of matrix A and B
-      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
-
-      // Fetch ipelt-th input plane
-      THCTensor_(narrow)(state, gradInput_i, gradInput_n, 0, ipelt, 1);
-      THCTensor_(select)(state, gradOutput_i, gradOutput_n, 0, ipelt);
-      THCTensor_(select)(state, weight_i, weight, 0, ipelt);
-
-      int64_t m = 1*kW*kH;
-      int64_t n = gradColumns->size[1];
-      int64_t k = nOutputPlane;
-
-      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
-      #ifdef THC_REAL_IS_FLOAT
-      THCudaBlas_Sgemm(
-      #elif defined(THC_REAL_IS_HALF)
-      THCudaBlas_Hgemm(
-      #elif defined(THC_REAL_IS_DOUBLE)
-      THCudaBlas_Dgemm(
-      #endif
-          state,
-          'n', 't',
-          n, m, k,
-          ScalarConvert<int, real>::to(1),
-          THCTensor_(data)(state, gradOutput_i), n,
-          THCTensor_(data)(state, weight_i), m,
-          ScalarConvert<int, real>::to(0),
-          THCTensor_(data)(state, gradColumns), n
-      );
-
-      // Unpack columns back into input:
-      col2im<real, accreal>(
-        THCState_getCurrentStream(state),
-        THCTensor_(data)(state, gradColumns),
-        1, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
-        1, 1, THCTensor_(data)(state, gradInput_i)
-      );
-      }
-  }
-
-  // Free
-  THCTensor_(free)(state, gradInput_n);
-  THCTensor_(free)(state, gradOutput_n);
-
-  THCTensor_(free)(state, gradInput_i);
-  THCTensor_(free)(state, gradOutput_i);
-  THCTensor_(free)(state, weight_i);
-
-  // Resize output
-  if (batch == 0) {
-    THCTensor_(select)(state, gradOutput, NULL, 0, 0);
-    THCTensor_(select)(state, input, NULL, 0, 0);
-    THCTensor_(select)(state, gradInput, NULL, 0, 0);
-  }
-
-  THCTensor_(free)(state, input);
-  THCTensor_(free)(state, gradOutput);
-  THCTensor_(free)(state, weight);
-  THCTensor_(free)(state, _weight);
-}
-
-void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradWeight,
-           THCTensor *gradBias,
-           THCTensor *columns,
-           THCTensor *ones,
-           int kW, int kH,
-           int dW, int dH,
-           int padW, int padH,
-           accreal scale_) {
-
-  real scale = ScalarConvert<accreal, real>::to(scale_);
-
-  THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
-  if (gradBias) {
-   THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
-  }
-
-  // Params
-  int nInputPlane = gradWeight->nDimension == 2 ? gradWeight->size[1]/(kW*kH) : gradWeight->size[1];
-  int nOutputPlane = gradWeight->size[0];
-  if (gradWeight->nDimension == 2) {
-    THCTensor_(resize4d)(state, gradWeight, nOutputPlane, nInputPlane, kH, kW);
-  }
-
- gradOutput = THCTensor_(newWithTensor)(state, gradOutput);
-  if (input->nDimension == 3) {
-    if (gradOutput->nDimension == 3) {
-      THCTensor_(resize4d)(state, gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
-    }
-  }
-  else
-  {
-    if (gradOutput->nDimension == 4) {
-      THCTensor_(resize5d)(state, gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
-    }
-  }
-
-
-  THNN_(SpatialDepthWiseConvolution_shapeCheck)
-       (state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW);
-
-  // Transpose gradWeight & gradBias
-  THCTensor_(transpose)(state, gradWeight, NULL, 0, 1);
-
-
-  THCTensor *_gradBias = NULL;
-  if(gradBias) {
-    THCTensor_(transpose)(state, gradBias, NULL, 0, 1);
-    _gradBias = gradBias;
-    gradBias = THCTensor_(newContiguous)(state, gradBias);
-
-  }
-
-  THCTensor *_gradWeight;
-
-  _gradWeight = gradWeight;
-
-  gradWeight = THCTensor_(newContiguous)(state, gradWeight);
-
-
-  // resize gradWeight
-  int64_t s1 = gradWeight->size[0];
-  int64_t s2 = gradWeight->size[1];
-  int64_t s3 = gradWeight->size[2] * gradWeight->size[3];
-  gradWeight = THCTensor_(newWithStorage3d)(state, gradWeight->storage, gradWeight->storageOffset,
-          s1, -1, s2, -1, s3, -1);
-
-  input = THCTensor_(newContiguous)(state, input);
-
-  int batch = 1;
-  if (input->nDimension == 3) {
-    // Force batch
-    batch = 0;
-    THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
-    THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
-  }
-
-  int64_t inputWidth   = input->size[3];
-  int64_t inputHeight  = input->size[2];
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-
-  // Batch size + input planes
-  int64_t batchSize = input->size[0];
-
-  // Define a buffer of ones, for bias accumulation
-  if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
-    // Resize plane and fill with ones...
-    THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
-    THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
-  }
-
-  // Resize temporary columns
-  THCTensor_(resize2d)(state, columns, 1*kW*kH, outputHeight*outputWidth);
-
-  // Helpers
-  THCTensor *input_n = THCTensor_(new)(state);
-  THCTensor *gradOutput_n = THCTensor_(new)(state);
-
-  // Helpers for DepthWiseConvolution
-  THCTensor *gradOutput_i = THCTensor_(new)(state);
-  THCTensor *input_i = THCTensor_(new)(state);
-  THCTensor *gradWeight_i = THCTensor_(new)(state);
-
-  THCTensor *gradBias_i = NULL;
-  if(gradBias) {
-    gradBias_i = THCTensor_(new)(state);
-  }
-
-  // For each elt in batch, do:
-  for (int elt = 0; elt < batchSize; elt ++) {
-    // Matrix mulitply per output:
-    THCTensor_(select)(state, input_n, input, 0, elt);
-    THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
-    for (int ipelt = 0; ipelt < nInputPlane; ipelt++)
-    {
-      THCTensor_(narrow)(state, input_i, input_n, 0, ipelt, 1);
-      THCTensor_(select)(state, gradOutput_i, gradOutput_n, 0, ipelt);
-      THCTensor_(select)(state, gradWeight_i, gradWeight, 0, ipelt);
-      if (gradBias) {
-        THCTensor_(select)(state, gradBias_i, gradBias, 0, ipelt);
-      }
-
-      // Extract columns:
-      im2col(
-        THCState_getCurrentStream(state),
-        THCTensor_(data)(state, input_i),
-        1, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
-        1, 1, THCTensor_(data)(state, columns)
-      );
-
-      // M,N,K are dims of matrix A and B
-      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
-      int64_t m = nOutputPlane;
-      int64_t n = 1*kW*kH;
-      int64_t k = columns->size[1];
-
-      // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
-      #ifdef THC_REAL_IS_FLOAT
-      THCudaBlas_Sgemm(
-      #elif defined(THC_REAL_IS_HALF)
-      THCudaBlas_Hgemm(
-      #elif defined(THC_REAL_IS_DOUBLE)
-      THCudaBlas_Dgemm(
-      #endif
-          state,
-          't', 'n',
-          n, m, k,
-          scale,
-          THCTensor_(data)(state, columns), k,
-          THCTensor_(data)(state, gradOutput_i), k,
-          ScalarConvert<int, real>::to(1),
-          THCTensor_(data)(state, gradWeight_i), n
-      );
-
-      // Do Bias:
-      // M,N,K are dims of matrix A and B
-      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
-      int64_t m_ = nOutputPlane;
-      int64_t k_ = outputHeight * outputWidth;
-
-      // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
-      if (gradBias) {
-        #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
-        #ifdef THC_REAL_IS_FLOAT
-        THCudaBlas_Sgemv(
-        #elif defined(THC_REAL_IS_DOUBLE)
-        THCudaBlas_Dgemv(
-        #endif
-            state,
-            't',
-            k_, m_,
-            scale,
-            THCTensor_(data)(state, gradOutput_i), k_,
-            THCTensor_(data)(state, ones), 1,
-            ScalarConvert<int, real>::to(1),
-            THCTensor_(data)(state, gradBias_i), 1
-        );
-        #endif
-        #ifdef THC_REAL_IS_HALF
-        THCudaBlas_Hgemm(
-            state,
-            't', 'n',
-            m_, 1, k_,
-            scale,
-            THCTensor_(data)(state, gradOutput_i), k_,
-            THCTensor_(data)(state, ones), k_,
-            ScalarConvert<int, real>::to(1),
-            THCTensor_(data)(state, gradBias_i), m_
-        );
-        #endif
-      }
-    }
-  }
-
-
-  // Copy back and transpose back
-  THCTensor_(transpose)(state, _gradWeight, NULL, 0, 1);
-  THCTensor_(resize4d)(state, _gradWeight, nInputPlane, nOutputPlane, kH, kW);
-  THCTensor_(copy)(state, _gradWeight, gradWeight);
-  THCTensor_(transpose)(state, _gradWeight, NULL, 0, 1);
-
-  if(gradBias) {
-    THCTensor_(transpose)(state, _gradBias, NULL, 0, 1);
-    THCTensor_(resize2d)(state, _gradBias, nInputPlane, nOutputPlane);
-    THCTensor_(copy)(state, _gradBias, gradBias);
-    THCTensor_(transpose)(state, _gradBias, NULL, 0, 1);
-  }
-
-
-  // Free
-  THCTensor_(free)(state, input_n);
-  THCTensor_(free)(state, gradOutput_n);
-  THCTensor_(free)(state, input_i);
-  THCTensor_(free)(state, gradOutput_i);
-  THCTensor_(free)(state, gradWeight_i);
-  THCTensor_(free)(state, gradWeight);
-  THCTensor_(free)(state, gradBias_i);
-  THCTensor_(free)(state, gradBias);
-
-  // Resize
-  if (batch == 0) {
-    THCTensor_(select)(state, gradOutput, NULL, 0, 0);
-    THCTensor_(select)(state, input, NULL, 0, 0);
-  }
-
-  THCTensor_(free)(state, input);
-  THCTensor_(free)(state, gradOutput);
-}
-
-#endif
diff --git a/torch/lib/THCUNN/generic/SpatialDepthwiseConvolution.cu b/torch/lib/THCUNN/generic/SpatialDepthwiseConvolution.cu
new file mode 100644
index 0000000..98ad76e
--- /dev/null
+++ b/torch/lib/THCUNN/generic/SpatialDepthwiseConvolution.cu
@@ -0,0 +1,201 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/SpatialDepthwiseConvolution.cu"
+#else
+
+void THNN_(SpatialDepthwiseConvolution_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  THCTensor *weight,
+                  THCTensor *bias,
+                  int kW, int kH,
+                  int dW, int dH,
+                  int padW, int padH,
+                  int dilationW, int dilationH)
+{
+  THCUNN_assertSameGPU(state, 3, input, output, weight);
+
+  // Only handle 4D Input Tensors for now
+  THAssert(THCTensor_(nDimension)(state, input) == 4);
+  THAssert(THCTensor_(nDimension)(state, weight) == 4);
+
+  // We assume that the input and weight Tensors are shaped properly by
+  // the caller, so we verify that here to some extent
+
+  // Weight Tensor is shape (output_channels, 1, kH, kW)
+  THAssert(weight->size[1] == 1);
+
+  // Input Tensor is shape (N, input_channels, H, W)
+  // We verify that the # of output_channels is a multiple of input_channels
+  THAssert(weight->size[0] % input->size[1] == 0);
+
+  // Bias has same # of channels as output
+  if (bias) {
+    THAssert(bias->size[0] == weight->size[0]);
+  }
+
+  // Following the behvaior of other THCUNN functions, we shape the output
+  // Tensor ourselves
+
+  int batchSize = input->size[0];
+  int height = input->size[2];
+  int width = input->size[3];
+  int outputHeight = (height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+  int outputWidth = (width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  int outputChannels = weight->size[0];
+
+  THCTensor_(resize4d)(state, output, batchSize, outputChannels, outputHeight, outputWidth);
+
+  THCDeviceTensor<real, 4> dInput = toDeviceTensor<real, 4>(state, input);
+  THCDeviceTensor<real, 4> dWeight = toDeviceTensor<real, 4>(state, weight);
+  THCDeviceTensor<real, 4> dOutput = toDeviceTensor<real, 4>(state, output);
+  THCDeviceTensor<real, 1> dBias;
+  if (bias) {
+    dBias = toDeviceTensor<real, 1>(state, bias);
+  }
+
+  // Kernel currently relies upon all the Tensors to be contiguous
+  THAssert(dInput.isContiguous());
+  THAssert(dWeight.isContiguous());
+  THAssert(dOutput.isContiguous());
+
+  int inputChannels = input->size[1];
+  int depthwiseMultiplier = outputChannels / inputChannels;
+
+  // One thread per output value
+  int n = THCTensor_(nElement)(state, output);
+  int blocks = GET_BLOCKS(n);
+  dim3 grid(blocks);
+  dim3 block(CUDA_NUM_THREADS);
+
+  spatialDepthwiseConvolutionUpdateOutput<real, accreal, unsigned int><<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+    dInput, dOutput, dWeight, dBias, bias != NULL, n, outputChannels, depthwiseMultiplier,
+    width, height, outputWidth, outputHeight,
+    kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+void THNN_(SpatialDepthwiseConvolution_updateGradInput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  THCTensor *weight,
+                  int kW, int kH,
+                  int dW, int dH,
+                  int padW, int padH,
+                  int dilationW, int dilationH)
+{
+  THCUNN_assertSameGPU(state, 3, gradOutput, gradInput, weight);
+
+  // Only handle 4D Input Tensors for now
+  THAssert(THCTensor_(nDimension)(state, input) == 4);
+  THAssert(THCTensor_(nDimension)(state, weight) == 4);
+  THAssert(THCTensor_(nDimension)(state, gradOutput) == 4);
+
+  // Minimal shape checking, as above
+  // Same # of elements in batch
+  THAssert(input->size[0] == gradOutput->size[0]);
+  // Same # of filters as outputChannels
+  THAssert(weight->size[0] == gradOutput->size[1]);
+
+  // Resize GradInput
+  THCTensor_(resizeAs)(state, gradInput, input);
+
+  int inputChannels = input->size[1];
+  int height = input->size[2];
+  int width = input->size[3];
+
+  int outputChannels = gradOutput->size[1];
+  int outputHeight = gradOutput->size[2];
+  int outputWidth = gradOutput->size[3];
+
+  int depthwiseMultiplier = outputChannels / inputChannels;
+
+  THCDeviceTensor<real, 4> dGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
+  THCDeviceTensor<real, 4> dGradInput = toDeviceTensor<real, 4>(state, gradInput);
+  THCDeviceTensor<real, 4> dWeight = toDeviceTensor<real, 4>(state, weight);
+
+  // Kernel currently relies upon all the Tensors to be contiguous
+  THAssert(dGradOutput.isContiguous());
+  THAssert(dGradInput.isContiguous());
+  THAssert(dWeight.isContiguous());
+
+  // One thread per gradInput value
+  int n = THCTensor_(nElement)(state, gradInput);
+  int blocks = GET_BLOCKS(n);
+  dim3 grid(blocks);
+  dim3 block(CUDA_NUM_THREADS);
+
+  spatialDepthwiseConvolutionUpdateGradInput<real, accreal, unsigned int><<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+    dGradOutput, dGradInput, dWeight, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+    height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+void THNN_(SpatialDepthwiseConvolution_accGradParameters)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *gradOutput,
+                  THCTensor *gradWeight,
+                  int kW, int kH,
+                  int dW, int dH,
+                  int padW, int padH,
+                  int dilationW, int dilationH)
+{
+  THCUNN_assertSameGPU(state, 3, input, gradOutput, gradWeight);
+
+  // Only handle 4D Input Tensors for now
+  THAssert(THCTensor_(nDimension)(state, input) == 4);
+  THAssert(THCTensor_(nDimension)(state, gradOutput) == 4);
+  THAssert(THCTensor_(nDimension)(state, gradWeight) == 4);
+
+  // Minimal shape checking as above
+  // Same # of elements in batch
+  THAssert(input->size[0] == gradOutput->size[0]);
+  // Same # of filters as outputChannels
+  THAssert(gradWeight->size[0] == gradOutput->size[1]);
+
+  int batchSize = input->size[0];
+  int inputChannels = input->size[1];
+  int height = input->size[2];
+  int width = input->size[3];
+
+  int outputChannels = gradOutput->size[1];
+  int outputHeight = gradOutput->size[2];
+  int outputWidth = gradOutput->size[3];
+
+  int depthwiseMultiplier = outputChannels / inputChannels;
+
+  THCDeviceTensor<real, 4> dGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
+  THCDeviceTensor<real, 4> dInput = toDeviceTensor<real, 4>(state, input);
+  THCDeviceTensor<real, 4> dGradWeight = toDeviceTensor<real, 4>(state, gradWeight);
+
+  // Kernel currently relies upon all the Tensors to be contiguous
+  THAssert(dGradOutput.isContiguous());
+  THAssert(dInput.isContiguous());
+  THAssert(dGradWeight.isContiguous());
+
+  // We parallelize so that each block computes a single value in gradWeight
+  int blocks = outputChannels * kH * kW;
+
+  // Because each weight position is a function of convolving the gradOutput over
+  // the input, we need batchSize * outputHeight * outputWidth individual calculations
+  int n = batchSize * outputHeight * outputWidth;
+
+  // Make sure we have enough threads to perform the reduction, and use this number
+  // to create the shared memory size for the reduction
+  dim3 grid(blocks);
+  dim3 block(std::min(nextHighestPowerOf2(n), (unsigned int64_t) CUDA_NUM_THREADS));
+  int smem = block.x * sizeof(accreal);
+
+  spatialDepthwiseConvolutionAccGradParameters<real, accreal, unsigned int><<<grid, block, smem, THCState_getCurrentStream(state)>>>(
+      dGradOutput, dInput, dGradWeight, batchSize, inputChannels, outputChannels, depthwiseMultiplier, n,
+      width, height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+
+  THCudaCheck(cudaGetLastError());
+}
+
+#endif
diff --git a/torch/lib/THCUNN/generic/THCUNN.h b/torch/lib/THCUNN/generic/THCUNN.h
index cabc322..aa1842f 100644
--- a/torch/lib/THCUNN/generic/THCUNN.h
+++ b/torch/lib/THCUNN/generic/THCUNN.h
@@ -664,43 +664,37 @@
                   int padW, int padH,
                   accreal scale);
 
-TH_API void THNN_(SpatialDepthWiseConvolution_updateOutput)(
+TH_API void THNN_(SpatialDepthwiseConvolution_updateOutput)(
                   THCState *state,
                   THCTensor *input,
                   THCTensor *output,
                   THCTensor *weight,
                   THCTensor *bias,              // [OPTIONAL]
-                  THCTensor *columns,
-                  THCTensor *ones,
                   int kW, int kH,
                   int dW, int dH,
-                  int padW, int padH);
+                  int padW, int padH,
+                  int dilationW, int dilationH);
 
-TH_API void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
+TH_API void THNN_(SpatialDepthwiseConvolution_updateGradInput)(
                   THCState *state,
                   THCTensor *input,
                   THCTensor *gradOutput,
                   THCTensor *gradInput,
                   THCTensor *weight,
-                  THCTensor *columns,
-                  THCTensor *ones,
                   int kW, int kH,
                   int dW, int dH,
-                  int padW, int padH);
+                  int padW, int padH,
+                  int dilationW, int dilationH);
 
-TH_API void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
+TH_API void THNN_(SpatialDepthwiseConvolution_accGradParameters)(
                   THCState *state,
                   THCTensor *input,
                   THCTensor *gradOutput,
                   THCTensor *gradWeight,
-                  THCTensor *gradBias,          // [OPTIONAL]
-                  THCTensor *columns,
-                  THCTensor *ones,
                   int kW, int kH,
                   int dW, int dH,
                   int padW, int padH,
-                  accreal scale);
-
+                  int dilationW, int dilationH);
 
 TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
                   THCState *state,
diff --git a/torch/lib/THNN/generic/SpatialDepthWiseConvolution.c b/torch/lib/THNN/generic/SpatialDepthWiseConvolution.c
deleted file mode 100644
index 85a9666..0000000
--- a/torch/lib/THNN/generic/SpatialDepthWiseConvolution.c
+++ /dev/null
@@ -1,528 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "generic/SpatialDepthWiseConvolution.c"
-#else
-
-static inline void THNN_(SpatialDepthWiseConvolution_shapeCheck)(
-	THTensor *input, THTensor *gradOutput,
-	THTensor *weight, THTensor *bias,
-	int kH, int kW, int dH, int dW, int padH, int padW) {
-
-  THArgCheck(kW > 0 && kH > 0, 9,
-	       "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
-  THArgCheck(dW > 0 && dH > 0, 11,
-	     "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
-  THNN_ARGCHECK(weight->nDimension == 4, 5, weight,
-		"2D or 4D weight tensor expected, but got: %s");
-
-  if (bias != NULL) {
-    THNN_CHECK_DIM_SIZE(bias, 2, 0, weight->size[0]);
-    THNN_CHECK_DIM_SIZE(bias, 2, 1, weight->size[1]);
-  }
-
-  int ndim = input->nDimension;
-  int dimf = 0;
-  int dimh = 1;
-  int dimw = 2;
-
-  if (ndim == 4) {
-    dimf++;
-    dimh++;
-    dimw++;
-  }
-
-  THNN_ARGCHECK(ndim == 3 || ndim == 4, 2, input,
-		"3D or 4D input tensor expected but got: %s");
-
-  int64_t nInputPlane  = weight->size[1];
-  int64_t inputHeight  = input->size[dimh];
-  int64_t inputWidth   = input->size[dimw];
-  int64_t nOutputPlane = weight->size[0];
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-
-  if (outputWidth < 1 || outputHeight < 1)
-    THError("Given input size: (%d x %d x %d). "
-	    "Calculated output size: (%d x %d x %d). Output size is too small",
-	    nInputPlane,inputHeight,inputWidth,nOutputPlane*nInputPlane,outputHeight,outputWidth);
-
-  THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
-
-  if (gradOutput != NULL) {
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimf, nInputPlane);
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimh, nOutputPlane);
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimw, outputHeight);
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimw + 1, outputWidth);
-  }
-}
-
-static void THNN_(SpatialDepthWiseConvolution_updateOutput_frame)(
-          THTensor *input,
-          THTensor *output,
-          THTensor *weight,
-          THTensor *bias,
-          THTensor *finput,
-          int kW,
-          int kH,
-          int dW,
-          int dH,
-          int padW,
-          int padH,
-          int64_t nInputPlane,
-          int64_t inputWidth,
-          int64_t inputHeight,
-          int64_t nOutputPlane,
-          int64_t outputWidth,
-          int64_t outputHeight)
-{
-  int64_t i;
-  THTensor *output2d;
-
-  THNN_(unfolded_copy)(finput, input, kW, kH, dW, dH, padW, padH,
-		       nInputPlane, inputWidth, inputHeight,
-		       outputWidth, outputHeight);
-
-  output2d = THTensor_(newWithStorage2d)(output->storage, output->storageOffset,
-                                         nOutputPlane, -1,
-                                         outputHeight*outputWidth, -1);
-  if (bias) {
-    for(i = 0; i < nOutputPlane; i++)
-        THVector_(fill)
-	  (output->storage->data + output->storageOffset + output->stride[0] * i,
-	   THTensor_(get1d)(bias, i), outputHeight*outputWidth);
-  } else {
-    THTensor_(zero)(output);
-  }
-
-  THTensor_(addmm)(output2d, 1, output2d, 1, weight, finput);
-
-  THTensor_(free)(output2d);
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          THTensor *weight,
-          THTensor *bias,
-          THTensor *finput,
-          THTensor *fgradInput,
-          int kW,
-          int kH,
-          int dW,
-          int dH,
-          int padW,
-          int padH)
-{
-  int64_t nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
-  int64_t nOutputPlane = weight->size[0];
-  if (weight->nDimension == 2) {
-    THTensor_(resize4d)(weight, nOutputPlane, nInputPlane, kH, kW);
-  }
-
-  THNN_(SpatialDepthWiseConvolution_shapeCheck)
-    (input, NULL, weight, bias, kH, kW, dH, dW, padH, padW);
-
-  THTensor *_weight = THTensor_(newTranspose)(weight, 0, 1);
-  weight = THTensor_(newContiguous)(_weight);
-
-  THTensor *_bias = NULL;
-  if(bias) {
-  	_bias = THTensor_(newTranspose)(bias, 0, 1);
-  	bias = THTensor_(newContiguous)(_bias);
-  }
-
-  // resize weight
-  int64_t s1 = weight->size[0];
-  int64_t s2 = weight->size[1];
-  int64_t s3 = weight->size[2] * weight->size[3];
-  weight = THTensor_(newWithStorage3d)(weight->storage, weight->storageOffset,
-          s1, -1, s2, -1, s3, -1);
-
-  input = THTensor_(newContiguous)(input);
-
-  int ndim = input->nDimension;
-
-  int batch = 1;
-  if (ndim == 3) {
-    // Force batch
-    batch = 0;
-    THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
-  }
-
-  int64_t inputHeight  = input->size[3];
-  int64_t inputWidth   = input->size[2];
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-
-  int64_t T = input->size[0];
-  int64_t t;
-
-  THTensor_(resize5d)(output, T, nInputPlane, nOutputPlane, outputHeight, outputWidth);
-  THTensor_(resize4d)(finput, T, nInputPlane, kW*kH*1, outputHeight*outputWidth);
-
-#pragma omp parallel for private(t)
-  for(t = 0; t < T; t++)
-  {
-    THTensor *input_t = THTensor_(newSelect)(input, 0, t);
-    THTensor *output_t = THTensor_(newSelect)(output, 0, t);
-    THTensor *finput_t = THTensor_(newSelect)(finput, 0, t);
-
-    int64_t i;
-#pragma omp parallel for private(i)
-    for(i = 0; i < nInputPlane; i++)
-    {
-      THTensor *weight_i = THTensor_(newSelect)(weight, 0, i);
-      THTensor *input_i = THTensor_(newNarrow)(input_t, 0, i, 1);
-      THTensor *output_i = THTensor_(newSelect)(output_t, 0, i);
-      THTensor *finput_i = THTensor_(newSelect)(finput_t, 0, i);
-      THTensor *bias_i = NULL;
-      if(bias) {
-        bias_i = THTensor_(newSelect)(bias, 0, i);
-      }
-      THNN_(SpatialDepthWiseConvolution_updateOutput_frame)
-	(input_i, output_i, weight_i, bias_i, finput_i,
-	 kW, kH, dW, dH, padW, padH,
-	 1, inputWidth, inputHeight,
-	 nOutputPlane, outputWidth, outputHeight);
-
-      THTensor_(free)(input_i);
-      THTensor_(free)(weight_i);
-      THTensor_(free)(bias_i);
-      THTensor_(free)(output_i);
-      THTensor_(free)(finput_i);
-    }
-    THTensor_(free)(input_t);
-    THTensor_(free)(output_t);
-    THTensor_(free)(finput_t);
-  }
-
-  THTensor_(free)(weight);
-  THTensor_(free)(_weight);
-  THTensor_(free)(bias);
-  THTensor_(free)(_bias);
-  THTensor_(resize4d)(output, T, nInputPlane * nOutputPlane, outputHeight, outputWidth);
-
-  if (batch == 0) {
-    THTensor_(select)(output, NULL, 0, 0);
-    THTensor_(select)(input, NULL, 0, 0);
-    THTensor_(select)(finput, NULL, 0, 0);
-  }
-  THTensor_(free)(input);
-}
-
-static void THNN_(SpatialDepthWiseConvolution_updateGradInput_frame)(
-          THTensor *gradInput,
-          THTensor *gradOutput,
-          THTensor *weight,
-          THTensor *fgradInput,
-          int kW,
-          int kH,
-          int dW,
-          int dH,
-          int padW,
-          int padH)
-{
-  THTensor *gradOutput2d = THTensor_(newWithStorage2d)
-    (gradOutput->storage, gradOutput->storageOffset,
-     gradOutput->size[0], -1,
-     gradOutput->size[1]*gradOutput->size[2], -1);
-  THTensor_(addmm)(fgradInput, 0, fgradInput, 1, weight, gradOutput2d);
-  THTensor_(free)(gradOutput2d);
-
-  THTensor_(zero)(gradInput);
-
-  THNN_(unfolded_acc)(fgradInput, gradInput, kW, kH, dW, dH,
-		      padW, padH,
-		      gradInput->size[0], gradInput->size[2], gradInput->size[1],
-		      gradOutput->size[2], gradOutput->size[1]);
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          THTensor *weight,
-          THTensor *finput,
-          THTensor *fgradInput,
-          int kW,
-          int kH,
-          int dW,
-          int dH,
-          int padW,
-          int padH)
-{
-  int64_t nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
-  int64_t nOutputPlane = weight->size[0];
-  if (weight->nDimension == 2) {
-    THTensor_(resize4d)(weight, nOutputPlane, nInputPlane, kH, kW);
-  }
-  gradOutput = THTensor_(newWithTensor)(gradOutput);
-
-  if (input->nDimension == 3) {
-    if (gradOutput->nDimension == 3) {
-      THTensor_(resize4d)(gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
-    }
-  }
-  else
-  {
-    if (gradOutput->nDimension == 4) {
-      THTensor_(resize5d)(gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
-    }
-  }
-
-
-  THNN_(SpatialDepthWiseConvolution_shapeCheck)
-    (input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW);
-
-  THTensor *_weight = THTensor_(newTranspose)(weight, 0, 1);
-  weight = THTensor_(newContiguous)(_weight);
-
-
-  // resize weight
-  int64_t s1 = weight->size[0];
-  int64_t s2 = weight->size[1];
-  int64_t s3 = weight->size[2] * weight->size[3];
-  weight = THTensor_(newWithStorage3d)(weight->storage, weight->storageOffset,
-          s1, -1, s2, -1, s3, -1);
-
-  input = THTensor_(newContiguous)(input);
-
-  int batch = 1;
-  if (input->nDimension == 3) {
-    // Force batch
-    batch = 0;
-    THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
-    THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
-  }
-
-  int64_t inputHeight  = input->size[3];
-  int64_t inputWidth   = input->size[2];
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-
-  int64_t T = input->size[0];
-  int64_t t;
-
-  THTensor_(resizeAs)(gradInput, input);
-  THTensor_(resize4d)(fgradInput, T, nInputPlane, kW*kH*1, outputHeight*outputWidth);
-
-  // depending on the BLAS library, fgradInput (result tensor) might
-  // be left uninitialized on zero alpha, which might lead to weird behavior
-  // hence, to be safe, zero it
-  THTensor_(zero)(fgradInput);
-
-
-
-#pragma omp parallel for private(t)
-  for(t = 0; t < T; t++)
-  {
-    THTensor *gradInput_t = THTensor_(newSelect)(gradInput, 0, t);
-    THTensor *gradOutput_t = THTensor_(newSelect)(gradOutput, 0, t);
-    THTensor *fgradInput_t = THTensor_(newSelect)(fgradInput, 0, t);
-
-
-    int64_t i;
-#pragma omp parallel for private(i)
-    for(i = 0; i < nInputPlane; i++)
-    {
-      THTensor *weight_i = THTensor_(newSelect)(weight, 0, i);
-      THTensor *gradInput_i = THTensor_(newNarrow)(gradInput_t, 0, i, 1);
-      THTensor *gradOutput_i = THTensor_(newSelect)(gradOutput_t, 0, i);
-      THTensor *fgradInput_i = THTensor_(newSelect)(fgradInput_t, 0, i);
-
-      THTensor_(transpose)(weight_i, weight_i, 0, 1);
-
-      THNN_(SpatialDepthWiseConvolution_updateGradInput_frame)(gradInput_i, gradOutput_i,
-              weight_i, fgradInput_i,
-              kW, kH, dW, dH, padW, padH);
-
-      THTensor_(free)(gradInput_i);
-      THTensor_(free)(weight_i);
-      THTensor_(free)(gradOutput_i);
-      THTensor_(free)(fgradInput_i);
-    }
-
-    THTensor_(free)(gradInput_t);
-    THTensor_(free)(gradOutput_t);
-    THTensor_(free)(fgradInput_t);
-  }
-
-  if (batch == 0) {
-    THTensor_(select)(gradOutput, NULL, 0, 0);
-    THTensor_(select)(input, NULL, 0, 0);
-    THTensor_(select)(gradInput, NULL, 0, 0);
-    THTensor_(select)(fgradInput, NULL, 0, 0);
-  }
-
-  THTensor_(free)(input);
-  THTensor_(free)(gradOutput);
-  THTensor_(free)(weight);
-  THTensor_(free)(_weight);
-}
-
-static void THNN_(SpatialDepthWiseConvolution_accGradParameters_frame)(
-          THTensor *gradOutput,
-          THTensor *gradWeight,
-          THTensor *gradBias,
-          THTensor *finput,
-          accreal scale)
-{
-  int64_t i;
-  THTensor *gradOutput2d = THTensor_(newWithStorage2d)
-    (gradOutput->storage, gradOutput->storageOffset,
-     gradOutput->size[0], -1,
-     gradOutput->size[1]*gradOutput->size[2], -1);
-
-  THTensor_(transpose)(finput, finput, 0, 1);
-  THTensor_(addmm)(gradWeight, 1, gradWeight, scale, gradOutput2d, finput);
-  THTensor_(transpose)(finput, finput, 0, 1);
-
-  if (gradBias) {
-    for(i = 0; i < gradBias->size[0]; i++)
-    {
-      int64_t k;
-      real sum = 0;
-      real *data = gradOutput2d->storage->data + gradOutput2d->storageOffset + i*gradOutput2d->stride[0];
-      for(k = 0; k < gradOutput2d->size[1]; k++)
-        sum += data[k];
-      (gradBias->storage->data + gradBias->storageOffset)[i] += scale*sum;
-    }
-  }
-
-  THTensor_(free)(gradOutput2d);
-}
-
-void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradWeight,
-          THTensor *gradBias,
-          THTensor *finput,
-          THTensor *fgradInput,
-          int kW,
-          int kH,
-          int dW,
-          int dH,
-          int padW,
-          int padH,
-          accreal scale)
-{
-  int64_t nInputPlane = gradWeight->nDimension == 2 ? gradWeight->size[1]/(kH*kW) : gradWeight->size[1];
-  int64_t nOutputPlane = gradWeight->size[0];
-  if (gradWeight->nDimension == 2) {
-    THTensor_(resize4d)(gradWeight, nOutputPlane, nInputPlane, kH, kW);
-  }
-
-  gradOutput = THTensor_(newWithTensor)(gradOutput);
-  if (input->nDimension == 3) {
-    if (gradOutput->nDimension == 3) {
-      THTensor_(resize4d)(gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
-    }
-  }
-  else
-  {
-    if (gradOutput->nDimension == 4) {
-      THTensor_(resize5d)(gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
-    }
-  }
-
-
-  THNN_(SpatialDepthWiseConvolution_shapeCheck)
-    (input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW);
-
-  // Transpose gradWeight & gradBias
-  THTensor_(transpose)(gradWeight, NULL, 0, 1);
-  THTensor *_gradWeight;
-  _gradWeight = gradWeight;
-  gradWeight = THTensor_(newContiguous)(gradWeight);
-
-  THTensor *_gradBias = NULL;
-  if(gradBias) {
-	  THTensor_(transpose)(gradBias, NULL, 0, 1);
-	  _gradBias = gradBias;
-	  gradBias = THTensor_(newContiguous)(gradBias);
-  }
-
-  // resize gradWeight
-  int64_t s1 = gradWeight->size[0];
-  int64_t s2 = gradWeight->size[1];
-  int64_t s3 = gradWeight->size[2] * gradWeight->size[3];
-  gradWeight = THTensor_(newWithStorage3d)(gradWeight->storage, gradWeight->storageOffset,
-          s1, -1, s2, -1, s3, -1);
-
-  input = THTensor_(newContiguous)(input);
-
-
-  int batch = 1;
-  if (input->nDimension == 3) {
-    // Force batch
-    batch = 0;
-    THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
-    THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
-  }
-
-  int64_t inputHeight  = input->size[3];
-  int64_t inputWidth   = input->size[2];
-  int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-  int64_t outputWidth  = (inputWidth + 2*padW - kW) / dW + 1;
-
-  int64_t T = input->size[0];
-  int64_t t;
-  THTensor_(resize4d)(finput, T, nInputPlane, kW*kH*1, outputHeight*outputWidth);
-
-  for(t = 0; t < T; t++)
-  {
-    THTensor *gradOutput_t = THTensor_(newSelect)(gradOutput, 0, t);
-    THTensor *finput_t = THTensor_(newSelect)(finput, 0, t);
-    int64_t i;
-#pragma omp parallel for private(i)
-    for(i = 0; i < nInputPlane; i++)
-    {
-      THTensor *finput_i = THTensor_(newSelect)(finput_t, 0, i);
-      THTensor *gradOutput_i = THTensor_(newSelect)(gradOutput_t, 0, i);
-      THTensor *gradWeight_i = THTensor_(newSelect)(gradWeight, 0, i);
-      THTensor *gradBias_i = NULL;
-      if(gradBias) {
-      	gradBias_i = THTensor_(newSelect)(gradBias, 0, i);
-      }
-      THNN_(SpatialDepthWiseConvolution_accGradParameters_frame)(gradOutput_i, gradWeight_i,
-                gradBias_i, finput_i, scale);
-
-      THTensor_(free)(finput_i);
-      THTensor_(free)(gradOutput_i);
-      THTensor_(free)(gradWeight_i);
-      THTensor_(free)(gradBias_i);
-    }
-
-    THTensor_(free)(gradOutput_t);
-    THTensor_(free)(finput_t);
-  }
-
-  // Copy back and transpose back
-  THTensor_(transpose)(_gradWeight, NULL, 0, 1);
-  THTensor_(resize4d)(_gradWeight, nInputPlane, nOutputPlane, kH, kW);
-  THTensor_(copy)(_gradWeight, gradWeight);
-  THTensor_(transpose)(_gradWeight, NULL, 0, 1);
-
-  if(gradBias) {
-	  THTensor_(transpose)(_gradBias, NULL, 0, 1);
-	  THTensor_(resize2d)(_gradBias, nInputPlane, nOutputPlane);
-	  THTensor_(copy)(_gradBias, gradBias);
-	  THTensor_(transpose)(_gradBias, NULL, 0, 1);
-  }
-
-  if (batch == 0) {
-    THTensor_(select)(gradOutput, NULL, 0, 0);
-    THTensor_(select)(input, NULL, 0, 0);
-    THTensor_(select)(finput, NULL, 0, 0);
-  }
-
-  THTensor_(free)(input);
-  THTensor_(free)(gradOutput);
-  THTensor_(free)(gradWeight);
-  THTensor_(free)(gradBias);
-}
-
-#endif
diff --git a/torch/lib/THNN/generic/THNN.h b/torch/lib/THNN/generic/THNN.h
index 0cf9c69..dbbf5f1 100644
--- a/torch/lib/THNN/generic/THNN.h
+++ b/torch/lib/THNN/generic/THNN.h
@@ -839,41 +839,6 @@
           int padW, int padH,
           accreal scale);
 
-TH_API void THNN_(SpatialDepthWiseConvolution_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          THTensor *weight,
-          THTensor *bias,         // [OPTIONAL]
-          THTensor *finput,
-          THTensor *fgradInput,
-          int kW, int kH,
-          int dW, int dH,
-          int padW, int padH);
-TH_API void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          THTensor *weight,
-          THTensor *finput,
-          THTensor *fgradInput,
-          int kW, int kH,
-          int dW, int dH,
-          int padW, int padH);
-TH_API void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradWeight,
-          THTensor *gradBias,     // [OPTIONAL]
-          THTensor *finput,
-          THTensor *fgradInput,
-          int kW, int kH,
-          int dW, int dH,
-          int padW, int padH,
-          accreal scale);
-
 TH_API void THNN_(SpatialConvolutionLocal_updateOutput)(
           THNNState *state,
           THTensor *input,
diff --git a/torch/lib/THNN/init.c b/torch/lib/THNN/init.c
index 326e630..f4093df 100644
--- a/torch/lib/THNN/init.c
+++ b/torch/lib/THNN/init.c
@@ -200,9 +200,6 @@
 #include "generic/SpatialConvolutionMM.c"
 #include "THGenerateFloatTypes.h"
 
-#include "generic/SpatialDepthWiseConvolution.c"
-#include "THGenerateFloatTypes.h"
-
 #include "generic/SpatialConvolutionLocal.c"
 #include "THGenerateFloatTypes.h"