Spatial Depthwise Convolution on the GPU (#3057)
* THCUNN Skeleton for Depthwise Convolution port
* implement Depthwise Convolution CUDA Kernels (handles weight parameter only, not bias)
* working kernels and bindings for forward + backward for base conv, and integration
* add support for padding
* strides for weight kernel
* dilation for weight gradient, enable for others
* add support for depthwise multiplier
* remove old depthwise conv
* rename to SpatialDepthwiseConvolution
* clean up depthwise code, add shape asserts, more constrained thread count for accgradparams
* add bias for forward for depthwise conv
* add grad_bias, move bias for forward to CUDA
* fix eligibility test to guard against transposed, properly identify depth multiplier
* add basic unit test; make depthwise conv take priority over cudnn when appropriate
* add tests for depthwise permutations
* make cuda kernels calculate positions using mul instead of div
* remove unnecessary samegpu requirement
* use accreal, test for double type
* use THAssert instead of assert
* rename to is_depthwise
* half prec support for depthwise
* make certain computation more pythonic
* flake8
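
For context, a minimal usage sketch (mirroring the new unit test below; assumes a CUDA build): a `Conv2d` whose `groups` count equals its input channel count now dispatches to the depthwise kernels, and a depth multiplier of `k` corresponds to `out_channels = k * in_channels`.

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

in_channels, depth_multiplier = 4, 2
conv = nn.Conv2d(in_channels, in_channels * depth_multiplier,
                 kernel_size=3, groups=in_channels).cuda()
x = Variable(torch.randn(2, in_channels, 6, 6).cuda(), requires_grad=True)
y = conv(x)         # forward goes through at::conv_depthwise2d_forward
y.sum().backward()  # backward uses the new grad_input / grad_weight kernels
print(y.size())     # torch.Size([2, 8, 4, 4])
```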
diff --git a/test/test_nn.py b/test/test_nn.py
index a08bc32..81867c4 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1942,6 +1942,49 @@
self.assertEqual(m.weight.grad.data,
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0))
+ # Very similar to test_Conv2d_naive_groups but with special care to handle
+ # the case where the number of groups equals the number of input channels
+ @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+ def test_Conv2d_depthwise_naive_groups(self):
+ types = [torch.cuda.FloatTensor, torch.cuda.DoubleTensor,
+ torch.cuda.HalfTensor]
+ precs = [1e-5, 1e-5, 1e-2]
+ for tp, prec in zip(types, precs):
+ for depth_multiplier in [1, 2]:
+ m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).type(tp)
+ i = Variable(torch.randn(2, 2, 6, 6).type(tp), requires_grad=True)
+ output = m(i)
+ grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4).type(tp)
+ output.backward(grad_output)
+
+ offset = 1 * depth_multiplier
+
+ m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).type(tp)
+ m1.weight.data = m.weight.data[:offset].clone()
+ m1.bias.data = m.bias.data[:offset].clone()
+ i1 = Variable(i.data[:, :1].contiguous(), requires_grad=True)
+ output1 = m1(i1)
+ output1.backward(grad_output[:, :offset].contiguous())
+
+ m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).type(tp)
+ m2.weight.data.copy_(m.weight.data[offset:])
+ m2.bias.data.copy_(m.bias.data[offset:])
+ i2 = Variable(i.data[:, 1:].contiguous(), requires_grad=True)
+ output2 = m2(i2)
+ output2.backward(grad_output[:, offset:].contiguous())
+
+ self.assertEqual(output, torch.cat([output1, output2], 1),
+ prec=prec)
+ self.assertEqual(i.grad.data,
+ torch.cat([i1.grad.data, i2.grad.data], 1),
+ prec=prec)
+ self.assertEqual(m.bias.grad.data,
+ torch.cat([m1.bias.grad.data,
+ m2.bias.grad.data], 0), prec=prec)
+ self.assertEqual(m.weight.grad.data,
+ torch.cat([m1.weight.grad.data,
+ m2.weight.grad.data], 0), prec=prec)
+
def test_MaxUnpool2d_output_size(self):
m = nn.MaxPool2d(3, stride=2, return_indices=True)
mu = nn.MaxUnpool2d(3, stride=2)
@@ -3830,6 +3873,31 @@
check_gradgrad=False,
),
dict(
+ fullname='Conv2d_depthwise',
+ constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
+ input_size=(2, 4, 6, 6),
+ ),
+ dict(
+ fullname='Conv2d_depthwise_with_multiplier',
+ constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
+ input_size=(2, 4, 6, 6),
+ ),
+ dict(
+ fullname='Conv2d_depthwise_strided',
+ constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
+ input_size=(2, 4, 6, 6),
+ ),
+ dict(
+ fullname='Conv2d_depthwise_padded',
+ constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
+ input_size=(2, 4, 6, 6),
+ ),
+ dict(
+ fullname='Conv2d_depthwise_dilated',
+ constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
+ input_size=(2, 4, 5, 5),
+ ),
+ dict(
module_name='MaxPool2d',
constructor_args=((3, 3), (2, 2), (1, 1)),
input_size=(1, 3, 7, 7),
diff --git a/torch/csrc/autograd/functions/convolution.cpp b/torch/csrc/autograd/functions/convolution.cpp
index 4f4a7b5..d933443 100644
--- a/torch/csrc/autograd/functions/convolution.cpp
+++ b/torch/csrc/autograd/functions/convolution.cpp
@@ -54,6 +54,14 @@
return is_dilated;
}
+auto ConvParams::is_padded() const -> bool {
+ bool is_padded = false;
+ for (int p : padding) {
+ is_padded |= (p != 0);
+ }
+ return is_padded;
+}
+
auto ConvParams::is_output_padding_neg() const -> bool {
bool is_non_neg = false;
for (int p : output_padding) {
@@ -119,6 +127,19 @@
return false;
}
+// We currently only have depthwise support for the case where groups ==
+// nInputPlane, with the number of output channels an integer multiple of
+// nInputPlane (the depthwise multiplier)
+auto ConvParams::is_depthwise(
+ const at::Tensor& input, const at::Tensor& weight, int groups) const -> bool {
+ return input.type().isCuda() &&
+ !transposed &&
+ input.ndimension() == 4 &&
+ input.size(1) == groups &&
+ groups > 1 && // no point if there is only a single group
+ weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels
+}
+
std::string ConvForward::name() { return "ConvForward"; }
auto ConvForward::output_size(at::Tensor& input, at::Tensor& weight) const -> std::vector<int64_t> {
@@ -220,6 +241,14 @@
return result;
}
+static std::vector<int64_t> vecToInt64(const std::vector<int>& src) {
+ std::vector<int64_t> res(src.size());
+ for (size_t i = 0; i < src.size(); i++) {
+ res[i] = static_cast<int64_t>(src[i]);
+ }
+ return res;
+}
+
static at::Tensor cat(const tensor_list& tensors, int dim) {
int num_inputs = tensors.size();
if (num_inputs == 0) {
@@ -231,7 +260,6 @@
return output;
}
-
// ConvForward implementation
auto ConvForward::apply(const variable_list& inputs) -> variable_list {
@@ -260,7 +288,16 @@
tensor_list ones(groups);
std::unique_ptr<Convolution> convolution;
- if (use_cudnn(input)) {
+ if (is_depthwise(input, weight, groups)) {
+
+ auto kernel_size = weight.sizes().slice(2);
+ auto stride = vecToInt64(this->stride);
+ auto padding = vecToInt64(this->padding);
+ auto dilation = vecToInt64(this->dilation);
+
+ output = at::conv_depthwise2d_forward(input, weight, kernel_size, bias, stride, padding, dilation);
+ } else if (use_cudnn(input)) {
#ifdef WITH_CUDNN
if (input.type().ID() != weight.type().ID()){
std::stringstream ss;
@@ -325,6 +362,14 @@
});
};
+// For Convolution strategies that don't implicitly handle grad_bias, we add a helper
+// function here to compute it using simple Tensor operators
+static at::Tensor compute_grad_bias(const at::Tensor& grad_output) {
+ // grad_output is in N, C, H, W; re-shape to (C, N*H*W) and make it contiguous
+ at::Tensor transposed = grad_output.transpose(0, 1).contiguous();
+ // sum over every entry of each channel to produce grad_bias
+ return transposed.view({transposed.size(0), -1}).sum(1);
+}
// ConvBackward implementation
@@ -354,6 +399,8 @@
grad_output = view4d(grad_output);
}
+
+ bool use_depthwise = this->is_depthwise(input, weight, groups);
bool use_cudnn = this->use_cudnn(input);
at::Tensor grad_input;
@@ -366,7 +413,23 @@
should_compute_output(2) && bias.defined(),
};
- if (use_cudnn) {
+ if (use_depthwise) {
+ if (output_mask[0] || output_mask[1]) {
+ auto kernel_size = weight.sizes().slice(2);
+ auto stride = vecToInt64(this->stride);
+ auto padding = vecToInt64(this->padding);
+ auto dilation = vecToInt64(this->dilation);
+
+ std::tie(grad_input, grad_weight) = at::conv_depthwise2d_backward(
+ grad_output, input, weight, kernel_size, stride, padding, dilation,
+ {output_mask[0], output_mask[1]});
+ }
+
+ // THCUNN implementation does not handle bias, so we do it ourselves
+ if (output_mask[2]) {
+ grad_bias = compute_grad_bias(grad_output);
+ }
+ } else if (use_cudnn) {
#ifdef WITH_CUDNN
if (output_mask[0]) {
grad_input = input.type().tensor();
@@ -648,14 +711,6 @@
grad_output_.data.reset();
}
-static std::vector<int64_t> vecToInt64(const std::vector<int>& src) {
- std::vector<int64_t> res(src.size());
- for (size_t i = 0; i < src.size(); i++) {
- res[i] = static_cast<int64_t>(src[i]);
- }
- return res;
-}
-
// Forward and backward functions for Tensor
static at::Tensor compute_output(
@@ -791,10 +846,7 @@
}
if (output_mask[2]) {
- // grad_output is in N, C, H, W, we re-shape and make contiguous
- at::Tensor transposed = grad_output.transpose(0, 1).contiguous();
- // sum across all of the channels and add to grad_bias
- grad_bias = transposed.view({transposed.size(0), -1}).sum(1);
+ grad_bias = compute_grad_bias(grad_output);
}
return std::make_tuple(grad_input, grad_weight, grad_bias);
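
For reference, `compute_grad_bias` above reduces to a per-channel sum of `grad_output` over the batch and spatial dimensions; a small sketch on plain CPU tensors (illustration only, not part of the patch):

```python
import torch

grad_output = torch.randn(2, 3, 4, 5)          # N, C, H, W

# compute_grad_bias: view as (C, N*H*W) and reduce each row
t = grad_output.transpose(0, 1).contiguous()
grad_bias = t.view(t.size(0), -1).sum(1)

# same reduction written out per channel
for c in range(grad_output.size(1)):
    assert abs(float(grad_bias[c]) - float(grad_output[:, c].sum())) < 1e-4
```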
diff --git a/torch/csrc/autograd/functions/convolution.h b/torch/csrc/autograd/functions/convolution.h
index be5f6ae..f8ba1bc 100644
--- a/torch/csrc/autograd/functions/convolution.h
+++ b/torch/csrc/autograd/functions/convolution.h
@@ -34,12 +34,14 @@
bool is_strided() const;
bool is_dilated() const;
+ bool is_padded() const;
bool is_output_padding_neg() const;
bool is_output_padding_big() const;
bool is_padding_neg() const;
void view1d_as_2d();
bool use_cudnn(const at::Tensor& input) const;
bool use_nnpack(const at::Tensor& input) const;
+ bool is_depthwise(const at::Tensor& input, const at::Tensor& weight, int groups) const;
};
struct ConvForward : public ForwardFunction<>, public ConvParams, public HasSymbolic {
diff --git a/torch/lib/ATen/nn.yaml b/torch/lib/ATen/nn.yaml
index 6d23585..3ef6b76 100644
--- a/torch/lib/ATen/nn.yaml
+++ b/torch/lib/ATen/nn.yaml
@@ -138,6 +138,10 @@
cname: SpatialConvolutionMM
buffers: [finput, fgrad_input]
+- name: conv_depthwise2d(Tensor input, Tensor weight, IntList[2] kernel_size, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] dilation=1)
+ cname: SpatialDepthwiseConvolution
+ buffers: []
+
- name: conv3d(Tensor input, Tensor weight, IntList[3] kernel_size, Tensor bias={}, IntList[3] stride=1, IntList[3] padding=0)
cname: VolumetricConvolutionMM
buffers: [finput, fgrad_input]
diff --git a/torch/lib/ATen/nn_parse.py b/torch/lib/ATen/nn_parse.py
index 7267531..95cee94 100644
--- a/torch/lib/ATen/nn_parse.py
+++ b/torch/lib/ATen/nn_parse.py
@@ -54,6 +54,15 @@
cname = thnn_function.name
output_args = []
+ # function_wrapper expects everything in a declaration to be in
+ # the base type (i.e. THTensor*), but if we pull in a THCUNN-only
+ # implementation, it will have THCTensor* as the arg type. So we
+ # strip the THC here before returning
+ def map_to_th_type(t):
+ if t.startswith('THC'):
+ t = t.replace('THC', 'TH')
+ return t
+
def is_output_arg(arg_name, func_name):
if arg_name == 'output' and 'updateOutput' in cname:
return True
@@ -68,7 +77,7 @@
name = arg.name
if is_output_arg(name, cname):
desc = {
- 'type': arg.type,
+ 'type': map_to_th_type(arg.type),
'name': camel_to_snake(name),
'output': True,
}
@@ -206,7 +215,7 @@
return {
'mode': 'NN',
'name': name,
- 'types': ['Float', 'Double'],
+ 'types': ['Float', 'Double', 'Half'], # Half will be stripped for CPU backend
'arguments': arguments,
'return': get_return(arguments),
'buffers': buffers,
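
The `map_to_th_type` helper added above is a plain prefix rewrite; exercised standalone (illustration only):

```python
def map_to_th_type(t):
    # strip the CUDA-specific prefix, e.g. 'THCTensor*' -> 'THTensor*'
    if t.startswith('THC'):
        t = t.replace('THC', 'TH')
    return t

assert map_to_th_type('THCTensor*') == 'THTensor*'
assert map_to_th_type('THTensor*') == 'THTensor*'  # base types pass through
```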
diff --git a/torch/lib/THC/CMakeLists.txt b/torch/lib/THC/CMakeLists.txt
index a51586a..a46ec88 100644
--- a/torch/lib/THC/CMakeLists.txt
+++ b/torch/lib/THC/CMakeLists.txt
@@ -286,6 +286,7 @@
THCReduce.cuh
THCReduceAll.cuh
THCReduceApplyUtils.cuh
+ THCTensorMathReduce.cuh
THCAsmUtils.cuh
THCAtomics.cuh
THCScanUtils.cuh
diff --git a/torch/lib/THCUNN/SpatialDepthWiseConvolution.cu b/torch/lib/THCUNN/SpatialDepthWiseConvolution.cu
deleted file mode 100644
index 53c92ac..0000000
--- a/torch/lib/THCUNN/SpatialDepthWiseConvolution.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-#include "THCUNN.h"
-#include "common.h"
-#include "im2col.h"
-
-#include "THCHalf.h"
-#include "THCHalfAutoNumerics.cuh"
-
-#include "generic/SpatialDepthWiseConvolution.cu"
-#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/SpatialDepthwiseConvolution.cu b/torch/lib/THCUNN/SpatialDepthwiseConvolution.cu
new file mode 100644
index 0000000..3be6f8c
--- /dev/null
+++ b/torch/lib/THCUNN/SpatialDepthwiseConvolution.cu
@@ -0,0 +1,205 @@
+// updateOutput, updateGradInput Kernels ported from Sergey Zagoruyko's pyinn, which itself was a
+// port from Caffe
+
+#include "THCUNN.h"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCNumerics.cuh"
+#include "THCReduceApplyUtils.cuh"
+#include "THCSortUtils.cuh"
+#include "THCTensorMathReduce.cuh"
+#include "SharedMem.cuh"
+#include "common.h"
+
+template <typename T, typename AccT, typename IndexType>
+__global__ void spatialDepthwiseConvolutionUpdateOutput(
+ const THCDeviceTensor<T, 4> input,
+ THCDeviceTensor<T, 4> output,
+ const THCDeviceTensor<T, 4> weight,
+ const THCDeviceTensor<T, 1> bias,
+ bool biasEnabled,
+ IndexType totalElements,
+ const int outputChannels,
+ const int depthwiseMultiplier,
+ const int inputWidth, const int inputHeight,
+ const int outputWidth, const int outputHeight,
+ const int kernelWidth, const int kernelHeight,
+ const int strideWidth, const int strideHeight,
+ const int padWidth, const int padHeight,
+ const int dilationWidth, const int dilationHeight)
+{
+ const int channelStride = outputHeight * outputWidth;
+ const int batchStride = outputChannels * channelStride;
+
+ for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
+ linearIndex < totalElements;
+ linearIndex += gridDim.x * blockDim.x) {
+
+ const int n = linearIndex / batchStride;
+ const int c = (linearIndex / channelStride) % outputChannels;
+ const int h = (linearIndex / outputWidth) % outputHeight;
+ const int w = linearIndex % outputWidth;
+
+ const int inputChannel = c / depthwiseMultiplier;
+ const int inputChannels = outputChannels / depthwiseMultiplier;
+
+ int weightOffset = c * kernelHeight * kernelWidth;
+
+ AccT value = biasEnabled ? ScalarConvert<T, AccT>::to(bias.data()[c]) : ScalarConvert<int, AccT>::to(0);
+ for (int kH = 0; kH < kernelHeight; ++kH) {
+ for (int kW = 0; kW < kernelWidth; ++kW) {
+ const int h_in = -padHeight + h * strideHeight + kH * dilationHeight;
+ const int w_in = -padWidth + w * strideWidth + kW * dilationWidth;
+
+ if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) {
+ const IndexType offset = ((n * inputChannels + inputChannel) * inputHeight + h_in) *
+ inputWidth + w_in;
+ value = THCNumerics<AccT>::add(
+ value,
+ THCNumerics<AccT>::mul(
+ ScalarConvert<T, AccT>::to(weight.data()[weightOffset]),
+ ScalarConvert<T, AccT>::to(input.data()[offset])));
+ }
+ ++weightOffset;
+ }
+ }
+ output.data()[linearIndex] = ScalarConvert<AccT, T>::to(value);
+ }
+}
+
+template <typename T, typename AccT, typename IndexType>
+__global__ void spatialDepthwiseConvolutionUpdateGradInput(
+ const THCDeviceTensor<T, 4> gradOutput,
+ THCDeviceTensor<T, 4> gradInput,
+ const THCDeviceTensor<T, 4> weight,
+ IndexType totalElements,
+ const int inputChannels,
+ const int depthwiseMultiplier,
+ const int outputChannels,
+ const int inputWidth, const int inputHeight,
+ const int outputWidth, const int outputHeight,
+ const int kernelWidth, const int kernelHeight,
+ const int strideWidth, const int strideHeight,
+ const int padWidth, const int padHeight,
+ const int dilationWidth, const int dilationHeight)
+{
+ const int channelStride = inputHeight * inputWidth;
+ const int batchStride = inputChannels * channelStride;
+
+ for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
+ linearIndex < totalElements;
+ linearIndex += gridDim.x * blockDim.x) {
+
+ const int n = linearIndex / batchStride;
+ const int c = (linearIndex / channelStride) % inputChannels;
+ const int h = (linearIndex / inputWidth) % inputHeight;
+ const int w = linearIndex % inputWidth;
+
+ AccT value = ScalarConvert<int, AccT>::to(0);
+ for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
+ int och = (c * depthwiseMultiplier) + multiplier;
+ int weightOffset = och * kernelHeight * kernelWidth;
+ for (int kh = 0; kh < kernelHeight; ++kh) {
+ for (int kw = 0; kw < kernelWidth; ++kw) {
+ const int h_out_s = h + padHeight - kh * dilationHeight;
+ const int w_out_s = w + padWidth - kw * dilationWidth;
+
+ if (((h_out_s % strideHeight) == 0) && ((w_out_s % strideWidth) == 0)) {
+ const int h_out = h_out_s / strideHeight;
+ const int w_out = w_out_s / strideWidth;
+
+ if ((h_out >= 0) && (h_out < outputHeight)
+ && (w_out >= 0) && (w_out < outputWidth)) {
+
+ const int offset = ((n * outputChannels + och) * outputHeight + h_out)
+ * outputWidth + w_out;
+ value = THCNumerics<AccT>::add(
+ value,
+ THCNumerics<AccT>::mul(
+ ScalarConvert<T, AccT>::to(weight.data()[weightOffset]),
+ ScalarConvert<T, AccT>::to(gradOutput.data()[offset])));
+ }
+ }
+ ++weightOffset;
+ }
+ }
+ }
+ gradInput.data()[linearIndex] = ScalarConvert<AccT, T>::to(value);
+ }
+}
+
+template <typename T, typename AccT, typename IndexType>
+__global__ void spatialDepthwiseConvolutionAccGradParameters(
+ const THCDeviceTensor<T, 4> gradOutput,
+ const THCDeviceTensor<T, 4> input,
+ THCDeviceTensor<T, 4> gradWeight,
+ const int batchSize,
+ const int inputChannels,
+ const int kernelChannels,
+ const int depthwiseMultiplier,
+ IndexType blockElements,
+ const int inputWidth, const int inputHeight,
+ const int outputWidth, const int outputHeight,
+ const int kernelWidth, const int kernelHeight,
+ const int strideWidth, const int strideHeight,
+ const int padWidth, const int padHeight,
+ const int dilationWidth, const int dilationHeight)
+{
+ const int channelStride = kernelWidth * kernelHeight;
+
+ // Have to use a statically typed Shared Memory pointer
+ SharedMem<AccT> smem;
+
+ // Each Block is responsible for accumulating over a permutation of
+ // (channels x kH x kW), use blockIdx to determine which one
+ int bidx = blockIdx.x;
+ int kW = bidx % kernelWidth;
+ int kH = (bidx / kernelWidth) % kernelHeight;
+ int ch = (bidx / channelStride) % kernelChannels;
+
+ // Need to calculate which input channel is associated with this filter
+ // channel
+ int inputCh = ch / depthwiseMultiplier;
+
+ AccT grad = ScalarConvert<float, AccT>::to(0.0);
+
+ // Block-stride loop over the number of elements we need to reduce
+ for (IndexType idx = threadIdx.x; idx < blockElements; idx += blockDim.x) {
+ // Need to calculate the following: the batch position, and the offset into
+ // gradOutput in height and width. The corresponding input position follows
+ // from the stride, padding and dilation parameters
+ int go_w_offset = idx % outputWidth;
+ int go_h_offset = (idx / outputWidth) % outputHeight;
+ int batch = (idx / outputWidth / outputHeight) % batchSize;
+
+ int i_w_offset = (go_w_offset * strideWidth) + (kW * dilationWidth) - padWidth;
+ int i_h_offset = (go_h_offset * strideHeight) + (kH * dilationHeight) - padHeight;
+
+ if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < inputWidth && i_h_offset < inputHeight) {
+ int inputOffset = ((batch * inputChannels + inputCh) * inputHeight + i_h_offset) * inputWidth + i_w_offset;
+ int outputOffset = ((batch * kernelChannels + ch) * outputHeight + go_h_offset) * outputWidth + go_w_offset;
+ grad = THCNumerics<AccT>::add(
+ grad,
+ THCNumerics<AccT>::mul(
+ ScalarConvert<T, AccT>::to(input.data()[inputOffset]),
+ ScalarConvert<T, AccT>::to(gradOutput.data()[outputOffset])));
+ }
+ }
+ __syncthreads();
+
+ // At this point each thread in the block has a local gradient, which we need to
+ // accumulate prior to writing the global value
+ AccT *buf = smem.getPointer();
+ AccT tval = reduceBlock<AccT, ReduceAdd<AccT, AccT>>(
+ buf, blockDim.x, grad, ReduceAdd<AccT, AccT>(), ScalarConvert<float, AccT>::to(0));
+
+ // After the reduction, the first thread in the block has the gradient, so it is
+ // responsible for writing it to gradWeight
+ if (threadIdx.x == 0) {
+ int weightOffset = kW + (kernelWidth * kH) + (kernelWidth * kernelHeight * ch);
+ gradWeight.data()[weightOffset] = ScalarConvert<AccT, T>::to(tval);
+ }
+}
+
+#include "generic/SpatialDepthwiseConvolution.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu b/torch/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
deleted file mode 100644
index 1f8391d..0000000
--- a/torch/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
+++ /dev/null
@@ -1,652 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "generic/SpatialDepthWiseConvolution.cu"
-#else
-
-static inline void THNN_(SpatialDepthWiseConvolution_shapeCheck)(
- THCState *state,
- THCTensor *input, THCTensor *gradOutput,
- THCTensor *weight, THCTensor *bias,
- int kH, int kW, int dH, int dW, int padH, int padW) {
- THArgCheck(kW > 0 && kH > 0, 9,
- "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
- THArgCheck(dW > 0 && dH > 0, 11,
- "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
- THCUNN_argCheck(state, weight->nDimension == 4, 5, weight,
- "2D or 4D weight tensor expected, but got: %s");
-
- if (bias != NULL) {
- THCUNN_check_dim_size(state, bias, 2, 0, weight->size[0]);
- THCUNN_check_dim_size(state, bias, 2, 1, weight->size[1]);
- }
-
- int ndim = input->nDimension;
- int dimf = 0;
- int dimh = 1;
- int dimw = 2;
-
- if (ndim == 4) {
- dimf++;
- dimh++;
- dimw++;
- }
-
- THCUNN_argCheck(state, ndim == 3 || ndim == 4, 2, input,
- "3D or 4D input tensor expected but got: %s");
-
- int64_t nInputPlane = weight->size[1];
- int64_t inputHeight = input->size[dimh];
- int64_t inputWidth = input->size[dimw];
- int64_t nOutputPlane = weight->size[0];
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
-
- if (outputWidth < 1 || outputHeight < 1)
- THError("Given input size: (%d x %d x %d). "
- "Calculated output size: (%d x %d x %d). Output size is too small",
- nInputPlane,inputHeight,inputWidth,nOutputPlane*nInputPlane,outputHeight,outputWidth);
-
- THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane);
-
- if (gradOutput != NULL) {
- THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimf, nInputPlane);
- THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimh, nOutputPlane);
- THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimw, outputHeight);
- THCUNN_check_dim_size(state, gradOutput, ndim + 1, dimw + 1, outputWidth);
- }
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- THCTensor *weight,
- THCTensor *bias,
- THCTensor *columns,
- THCTensor *ones,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH) {
-
- THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones);
- if (bias) {
- THCUNN_assertSameGPU(state, 2, weight, bias);
- }
-
- // Params:
- int nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
- int nOutputPlane = weight->size[0];
- if (weight->nDimension == 2) {
- THCTensor_(resize4d)(state, weight, nOutputPlane, nInputPlane, kH, kW);
- }
-
- THNN_(SpatialDepthWiseConvolution_shapeCheck)
- (state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW);
-
-
- // Transpose weight & bias
- THCTensor *_weight = THCTensor_(newTranspose)(state, weight, 0, 1);
- weight = THCTensor_(newContiguous)(state, _weight);
-
- THCTensor *_bias = NULL;
- if(bias) {
- _bias = THCTensor_(newTranspose)(state, bias, 0, 1);
- bias = THCTensor_(newContiguous)(state, _bias);
- }
-
- // resize weight
- int64_t s1 = weight->size[0];
- int64_t s2 = weight->size[1];
- int64_t s3 = weight->size[2] * weight->size[3];
- weight = THCTensor_(newWithStorage3d)(state, weight->storage, weight->storageOffset,
- s1, -1, s2, -1, s3, -1);
-
- input = THCTensor_(newContiguous)(state, input);
-
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
- }
-
- int64_t inputWidth = input->size[3];
- int64_t inputHeight = input->size[2];
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-
- // Batch size + input planes
- int64_t batchSize = input->size[0];
-
- // Resize output
- THCTensor_(resize5d)(state, output, batchSize, nInputPlane, nOutputPlane, outputHeight, outputWidth);
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, columns, kW*kH, outputHeight*outputWidth);
-
- // Define a buffer of ones, for bias accumulation
- // Note: this buffer can be shared with other modules, it only ever gets increased,
- // and always contains ones.
- if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
- THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
- }
-
- // Helpers
- THCTensor *input_n = THCTensor_(new)(state);
- THCTensor *output_n = THCTensor_(new)(state);
-
-
- // Helpers for DepthWiseConvolution
- THCTensor *input_i = THCTensor_(new)(state);
- THCTensor *output_i = THCTensor_(new)(state);
- THCTensor *weight_i = THCTensor_(new)(state);
-
- THCTensor *bias_i = NULL;
- if(bias) {
- bias_i = THCTensor_(new)(state);
- }
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THCTensor_(select)(state, input_n, input, 0, elt);
- THCTensor_(select)(state, output_n, output, 0, elt);
-
-
- for (int ipelt = 0; ipelt < nInputPlane; ipelt++)
- {
- // Fetch ipelt-th input plane
- THCTensor_(narrow)(state, input_i, input_n, 0, ipelt, 1);
- THCTensor_(select)(state, output_i, output_n, 0, ipelt);
- THCTensor_(select)(state, weight_i, weight, 0, ipelt);
- if (bias) {
- THCTensor_(select)(state, bias_i, bias, 0, ipelt);
- }
- // Do Bias first:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- int64_t m_ = nOutputPlane;
- int64_t n_ = outputHeight * outputWidth;
- int64_t k_ = 1;
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- if (bias) {
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 't', 'n',
- n_, m_, k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, ones), k_,
- THCTensor_(data)(state, bias_i), k_,
- ScalarConvert<int, real>::to(0),
- THCTensor_(data)(state, output_i), n_
- );
- } else {
- THCTensor_(zero)(state, output_i);
- }
-
- // Extract columns:
- im2col(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, input_i),
- 1, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1, THCTensor_(data)(state, columns)
- );
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- int64_t m = nOutputPlane;
- int64_t n = columns->size[1];
- int64_t k = 1*kH*kW;
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 'n', 'n',
- n, m, k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, columns), n,
- THCTensor_(data)(state, weight_i), k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, output_i), n
- );
- }
- }
-
- // Free
- THCTensor_(free)(state, input_n);
- THCTensor_(free)(state, output_n);
-
- THCTensor_(free)(state, input_i);
- THCTensor_(free)(state, output_i);
-
- THCTensor_(free)(state, weight_i);
-
- THCTensor_(free)(state, weight);
- THCTensor_(free)(state, _weight);
-
- THCTensor_(free)(state, bias_i);
- THCTensor_(free)(state, bias);
- THCTensor_(free)(state, _bias);
- // Transpose output
- THCTensor_(resize4d)(state, output, batchSize, nInputPlane * nOutputPlane, outputHeight, outputWidth);
-
- // Make a contiguous copy of output (OPTIONAL)
- // THCTensor *_output = THCTensor_(newContiguous)(state, output);
-
- // Resize output
- if (batch == 0) {
- THCTensor_(select)(state, output, NULL, 0, 0);
- THCTensor_(select)(state, input, NULL, 0, 0);
- }
- //else
- //THCTensor_(resize5d)(state, output, batchSize, nOutputPlane, nInputPlane, outputHeight, outputWidth);
-
- // Copy output back
- // THCTensor_(freeCopyTo)(state, _output, output);
-
- THCTensor_(free)(state, input);
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- THCTensor *weight,
- THCTensor *gradColumns,
- THCTensor *ones,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH) {
-
- THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
- gradColumns, gradInput);
-
- // Params:
- int nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
- int nOutputPlane = weight->size[0];
- if (weight->nDimension == 2) {
- THCTensor_(resize4d)(state, weight, nOutputPlane, nInputPlane, kH, kW);
- }
-
- gradOutput = THCTensor_(newWithTensor)(state, gradOutput);
-
- if (input->nDimension == 3) {
- if (gradOutput->nDimension == 3) {
- THCTensor_(resize4d)(state, gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
- }
- }
- else
- {
- if (gradOutput->nDimension == 4) {
- THCTensor_(resize5d)(state, gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
- }
- }
-
- THNN_(SpatialDepthWiseConvolution_shapeCheck)
- (state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW);
-
- // Transpose weight
- THCTensor *_weight = THCTensor_(newTranspose)(state, weight, 0, 1);
- weight = THCTensor_(newContiguous)(state, _weight);
-
- // resize weight
- int64_t s1 = weight->size[0];
- int64_t s2 = weight->size[1];
- int64_t s3 = weight->size[2] * weight->size[3];
- weight = THCTensor_(newWithStorage3d)(state, weight->storage, weight->storageOffset,
- s1, -1, s2, -1, s3, -1);
-
-
-
- input = THCTensor_(newContiguous)(state, input);
-
-
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
- THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- int64_t inputWidth = input->size[3];
- int64_t inputHeight = input->size[2];
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-
- // Batch size + input planes
- int64_t batchSize = input->size[0];
-
- // Resize output
- THCTensor_(resize4d)(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth);
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, gradColumns, 1*kW*kH, outputHeight*outputWidth);
-
- // Helpers
- THCTensor *gradInput_n = THCTensor_(new)(state);
- THCTensor *gradOutput_n = THCTensor_(new)(state);
-
- // Helpers for DepthWiseConvolution
- THCTensor *gradOutput_i = THCTensor_(new)(state);
- THCTensor *gradInput_i = THCTensor_(new)(state);
- THCTensor *weight_i = THCTensor_(new)(state);
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per sample:
- THCTensor_(select)(state, gradInput_n, gradInput, 0, elt);
- THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
- for (int ipelt = 0; ipelt < nInputPlane; ipelt++)
- {
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
-
- // Fetch ipelt-th input plane
- THCTensor_(narrow)(state, gradInput_i, gradInput_n, 0, ipelt, 1);
- THCTensor_(select)(state, gradOutput_i, gradOutput_n, 0, ipelt);
- THCTensor_(select)(state, weight_i, weight, 0, ipelt);
-
- int64_t m = 1*kW*kH;
- int64_t n = gradColumns->size[1];
- int64_t k = nOutputPlane;
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 'n', 't',
- n, m, k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradOutput_i), n,
- THCTensor_(data)(state, weight_i), m,
- ScalarConvert<int, real>::to(0),
- THCTensor_(data)(state, gradColumns), n
- );
-
- // Unpack columns back into input:
- col2im<real, accreal>(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, gradColumns),
- 1, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1, THCTensor_(data)(state, gradInput_i)
- );
- }
- }
-
- // Free
- THCTensor_(free)(state, gradInput_n);
- THCTensor_(free)(state, gradOutput_n);
-
- THCTensor_(free)(state, gradInput_i);
- THCTensor_(free)(state, gradOutput_i);
- THCTensor_(free)(state, weight_i);
-
- // Resize output
- if (batch == 0) {
- THCTensor_(select)(state, gradOutput, NULL, 0, 0);
- THCTensor_(select)(state, input, NULL, 0, 0);
- THCTensor_(select)(state, gradInput, NULL, 0, 0);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, gradOutput);
- THCTensor_(free)(state, weight);
- THCTensor_(free)(state, _weight);
-}
-
-void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradWeight,
- THCTensor *gradBias,
- THCTensor *columns,
- THCTensor *ones,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH,
- accreal scale_) {
-
- real scale = ScalarConvert<accreal, real>::to(scale_);
-
- THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
- if (gradBias) {
- THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
- }
-
- // Params
- int nInputPlane = gradWeight->nDimension == 2 ? gradWeight->size[1]/(kW*kH) : gradWeight->size[1];
- int nOutputPlane = gradWeight->size[0];
- if (gradWeight->nDimension == 2) {
- THCTensor_(resize4d)(state, gradWeight, nOutputPlane, nInputPlane, kH, kW);
- }
-
- gradOutput = THCTensor_(newWithTensor)(state, gradOutput);
- if (input->nDimension == 3) {
- if (gradOutput->nDimension == 3) {
- THCTensor_(resize4d)(state, gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
- }
- }
- else
- {
- if (gradOutput->nDimension == 4) {
- THCTensor_(resize5d)(state, gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
- }
- }
-
-
- THNN_(SpatialDepthWiseConvolution_shapeCheck)
- (state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW);
-
- // Transpose gradWeight & gradBias
- THCTensor_(transpose)(state, gradWeight, NULL, 0, 1);
-
-
- THCTensor *_gradBias = NULL;
- if(gradBias) {
- THCTensor_(transpose)(state, gradBias, NULL, 0, 1);
- _gradBias = gradBias;
- gradBias = THCTensor_(newContiguous)(state, gradBias);
-
- }
-
- THCTensor *_gradWeight;
-
- _gradWeight = gradWeight;
-
- gradWeight = THCTensor_(newContiguous)(state, gradWeight);
-
-
- // resize gradWeight
- int64_t s1 = gradWeight->size[0];
- int64_t s2 = gradWeight->size[1];
- int64_t s3 = gradWeight->size[2] * gradWeight->size[3];
- gradWeight = THCTensor_(newWithStorage3d)(state, gradWeight->storage, gradWeight->storageOffset,
- s1, -1, s2, -1, s3, -1);
-
- input = THCTensor_(newContiguous)(state, input);
-
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THCTensor_(resize4d)(state, input, 1, input->size[0], input->size[1], input->size[2]);
- THCTensor_(resize5d)(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- int64_t inputWidth = input->size[3];
- int64_t inputHeight = input->size[2];
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
-
- // Batch size + input planes
- int64_t batchSize = input->size[0];
-
- // Define a buffer of ones, for bias accumulation
- if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
- // Resize plane and fill with ones...
- THCTensor_(resize2d)(state, ones, outputHeight, outputWidth);
- THCTensor_(fill)(state, ones, ScalarConvert<int, real>::to(1));
- }
-
- // Resize temporary columns
- THCTensor_(resize2d)(state, columns, 1*kW*kH, outputHeight*outputWidth);
-
- // Helpers
- THCTensor *input_n = THCTensor_(new)(state);
- THCTensor *gradOutput_n = THCTensor_(new)(state);
-
- // Helpers for DepthWiseConvolution
- THCTensor *gradOutput_i = THCTensor_(new)(state);
- THCTensor *input_i = THCTensor_(new)(state);
- THCTensor *gradWeight_i = THCTensor_(new)(state);
-
- THCTensor *gradBias_i = NULL;
- if(gradBias) {
- gradBias_i = THCTensor_(new)(state);
- }
-
- // For each elt in batch, do:
- for (int elt = 0; elt < batchSize; elt ++) {
- // Matrix mulitply per output:
- THCTensor_(select)(state, input_n, input, 0, elt);
- THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt);
-
- for (int ipelt = 0; ipelt < nInputPlane; ipelt++)
- {
- THCTensor_(narrow)(state, input_i, input_n, 0, ipelt, 1);
- THCTensor_(select)(state, gradOutput_i, gradOutput_n, 0, ipelt);
- THCTensor_(select)(state, gradWeight_i, gradWeight, 0, ipelt);
- if (gradBias) {
- THCTensor_(select)(state, gradBias_i, gradBias, 0, ipelt);
- }
-
- // Extract columns:
- im2col(
- THCState_getCurrentStream(state),
- THCTensor_(data)(state, input_i),
- 1, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
- 1, 1, THCTensor_(data)(state, columns)
- );
-
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- int64_t m = nOutputPlane;
- int64_t n = 1*kW*kH;
- int64_t k = columns->size[1];
-
- // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemm(
- #elif defined(THC_REAL_IS_HALF)
- THCudaBlas_Hgemm(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemm(
- #endif
- state,
- 't', 'n',
- n, m, k,
- scale,
- THCTensor_(data)(state, columns), k,
- THCTensor_(data)(state, gradOutput_i), k,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradWeight_i), n
- );
-
- // Do Bias:
- // M,N,K are dims of matrix A and B
- // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
- int64_t m_ = nOutputPlane;
- int64_t k_ = outputHeight * outputWidth;
-
- // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
- if (gradBias) {
- #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
- #ifdef THC_REAL_IS_FLOAT
- THCudaBlas_Sgemv(
- #elif defined(THC_REAL_IS_DOUBLE)
- THCudaBlas_Dgemv(
- #endif
- state,
- 't',
- k_, m_,
- scale,
- THCTensor_(data)(state, gradOutput_i), k_,
- THCTensor_(data)(state, ones), 1,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradBias_i), 1
- );
- #endif
- #ifdef THC_REAL_IS_HALF
- THCudaBlas_Hgemm(
- state,
- 't', 'n',
- m_, 1, k_,
- scale,
- THCTensor_(data)(state, gradOutput_i), k_,
- THCTensor_(data)(state, ones), k_,
- ScalarConvert<int, real>::to(1),
- THCTensor_(data)(state, gradBias_i), m_
- );
- #endif
- }
- }
- }
-
-
- // Copy back and transpose back
- THCTensor_(transpose)(state, _gradWeight, NULL, 0, 1);
- THCTensor_(resize4d)(state, _gradWeight, nInputPlane, nOutputPlane, kH, kW);
- THCTensor_(copy)(state, _gradWeight, gradWeight);
- THCTensor_(transpose)(state, _gradWeight, NULL, 0, 1);
-
- if(gradBias) {
- THCTensor_(transpose)(state, _gradBias, NULL, 0, 1);
- THCTensor_(resize2d)(state, _gradBias, nInputPlane, nOutputPlane);
- THCTensor_(copy)(state, _gradBias, gradBias);
- THCTensor_(transpose)(state, _gradBias, NULL, 0, 1);
- }
-
-
- // Free
- THCTensor_(free)(state, input_n);
- THCTensor_(free)(state, gradOutput_n);
- THCTensor_(free)(state, input_i);
- THCTensor_(free)(state, gradOutput_i);
- THCTensor_(free)(state, gradWeight_i);
- THCTensor_(free)(state, gradWeight);
- THCTensor_(free)(state, gradBias_i);
- THCTensor_(free)(state, gradBias);
-
- // Resize
- if (batch == 0) {
- THCTensor_(select)(state, gradOutput, NULL, 0, 0);
- THCTensor_(select)(state, input, NULL, 0, 0);
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, gradOutput);
-}
-
-#endif
diff --git a/torch/lib/THCUNN/generic/SpatialDepthwiseConvolution.cu b/torch/lib/THCUNN/generic/SpatialDepthwiseConvolution.cu
new file mode 100644
index 0000000..98ad76e
--- /dev/null
+++ b/torch/lib/THCUNN/generic/SpatialDepthwiseConvolution.cu
@@ -0,0 +1,201 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/SpatialDepthwiseConvolution.cu"
+#else
+
+void THNN_(SpatialDepthwiseConvolution_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH)
+{
+ THCUNN_assertSameGPU(state, 3, input, output, weight);
+
+ // Only handle 4D Input Tensors for now
+ THAssert(THCTensor_(nDimension)(state, input) == 4);
+ THAssert(THCTensor_(nDimension)(state, weight) == 4);
+
+ // The caller is expected to shape the input and weight Tensors properly;
+ // we verify that here to some extent
+
+ // Weight Tensor is shape (output_channels, 1, kH, kW)
+ THAssert(weight->size[1] == 1);
+
+ // Input Tensor is shape (N, input_channels, H, W)
+ // We verify that the # of output_channels is a multiple of input_channels
+ THAssert(weight->size[0] % input->size[1] == 0);
+
+ // Bias has same # of channels as output
+ if (bias) {
+ THAssert(bias->size[0] == weight->size[0]);
+ }
+
+ // Following the behavior of other THCUNN functions, we shape the output
+ // Tensor ourselves
+
+ int batchSize = input->size[0];
+ int height = input->size[2];
+ int width = input->size[3];
+ int outputHeight = (height + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ int outputWidth = (width + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ int outputChannels = weight->size[0];
+
+ THCTensor_(resize4d)(state, output, batchSize, outputChannels, outputHeight, outputWidth);
+
+ THCDeviceTensor<real, 4> dInput = toDeviceTensor<real, 4>(state, input);
+ THCDeviceTensor<real, 4> dWeight = toDeviceTensor<real, 4>(state, weight);
+ THCDeviceTensor<real, 4> dOutput = toDeviceTensor<real, 4>(state, output);
+ THCDeviceTensor<real, 1> dBias;
+ if (bias) {
+ dBias = toDeviceTensor<real, 1>(state, bias);
+ }
+
+ // Kernel currently relies upon all the Tensors to be contiguous
+ THAssert(dInput.isContiguous());
+ THAssert(dWeight.isContiguous());
+ THAssert(dOutput.isContiguous());
+
+ int inputChannels = input->size[1];
+ int depthwiseMultiplier = outputChannels / inputChannels;
+
+ // One thread per output value
+ int n = THCTensor_(nElement)(state, output);
+ int blocks = GET_BLOCKS(n);
+ dim3 grid(blocks);
+ dim3 block(CUDA_NUM_THREADS);
+
+ spatialDepthwiseConvolutionUpdateOutput<real, accreal, unsigned int><<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+ dInput, dOutput, dWeight, dBias, bias != NULL, n, outputChannels, depthwiseMultiplier,
+ width, height, outputWidth, outputHeight,
+ kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+
+ THCudaCheck(cudaGetLastError());
+}
+
+void THNN_(SpatialDepthwiseConvolution_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ THCTensor *weight,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH)
+{
+ THCUNN_assertSameGPU(state, 3, gradOutput, gradInput, weight);
+
+ // Only handle 4D Input Tensors for now
+ THAssert(THCTensor_(nDimension)(state, input) == 4);
+ THAssert(THCTensor_(nDimension)(state, weight) == 4);
+ THAssert(THCTensor_(nDimension)(state, gradOutput) == 4);
+
+ // Minimal shape checking, as above
+ // Same # of elements in batch
+ THAssert(input->size[0] == gradOutput->size[0]);
+ // Same # of filters as outputChannels
+ THAssert(weight->size[0] == gradOutput->size[1]);
+
+ // Resize GradInput
+ THCTensor_(resizeAs)(state, gradInput, input);
+
+ int inputChannels = input->size[1];
+ int height = input->size[2];
+ int width = input->size[3];
+
+ int outputChannels = gradOutput->size[1];
+ int outputHeight = gradOutput->size[2];
+ int outputWidth = gradOutput->size[3];
+
+ int depthwiseMultiplier = outputChannels / inputChannels;
+
+ THCDeviceTensor<real, 4> dGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
+ THCDeviceTensor<real, 4> dGradInput = toDeviceTensor<real, 4>(state, gradInput);
+ THCDeviceTensor<real, 4> dWeight = toDeviceTensor<real, 4>(state, weight);
+
+ // Kernel currently relies upon all the Tensors to be contiguous
+ THAssert(dGradOutput.isContiguous());
+ THAssert(dGradInput.isContiguous());
+ THAssert(dWeight.isContiguous());
+
+ // One thread per gradInput value
+ int n = THCTensor_(nElement)(state, gradInput);
+ int blocks = GET_BLOCKS(n);
+ dim3 grid(blocks);
+ dim3 block(CUDA_NUM_THREADS);
+
+ spatialDepthwiseConvolutionUpdateGradInput<real, accreal, unsigned int><<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+ dGradOutput, dGradInput, dWeight, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+ height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+
+ THCudaCheck(cudaGetLastError());
+}
+
+void THNN_(SpatialDepthwiseConvolution_accGradParameters)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH)
+{
+ THCUNN_assertSameGPU(state, 3, input, gradOutput, gradWeight);
+
+ // Only handle 4D Input Tensors for now
+ THAssert(THCTensor_(nDimension)(state, input) == 4);
+ THAssert(THCTensor_(nDimension)(state, gradOutput) == 4);
+ THAssert(THCTensor_(nDimension)(state, gradWeight) == 4);
+
+ // Minimal shape checking as above
+ // Same # of elements in batch
+ THAssert(input->size[0] == gradOutput->size[0]);
+ // Same # of filters as outputChannels
+ THAssert(gradWeight->size[0] == gradOutput->size[1]);
+
+ int batchSize = input->size[0];
+ int inputChannels = input->size[1];
+ int height = input->size[2];
+ int width = input->size[3];
+
+ int outputChannels = gradOutput->size[1];
+ int outputHeight = gradOutput->size[2];
+ int outputWidth = gradOutput->size[3];
+
+ int depthwiseMultiplier = outputChannels / inputChannels;
+
+ THCDeviceTensor<real, 4> dGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
+ THCDeviceTensor<real, 4> dInput = toDeviceTensor<real, 4>(state, input);
+ THCDeviceTensor<real, 4> dGradWeight = toDeviceTensor<real, 4>(state, gradWeight);
+
+ // Kernel currently relies upon all the Tensors to be contiguous
+ THAssert(dGradOutput.isContiguous());
+ THAssert(dInput.isContiguous());
+ THAssert(dGradWeight.isContiguous());
+
+ // We parallelize so that each block computes a single value in gradWeight
+ int blocks = outputChannels * kH * kW;
+
+ // Each weight position is computed by convolving gradOutput over the input,
+ // so we need batchSize * outputHeight * outputWidth individual products
+ int n = batchSize * outputHeight * outputWidth;
+
+ // Make sure we have enough threads to perform the reduction, and use that
+ // count to size the shared memory for the reduction
+ dim3 grid(blocks);
+ dim3 block(std::min(nextHighestPowerOf2(n), (uint64_t) CUDA_NUM_THREADS));
+ int smem = block.x * sizeof(accreal);
+
+ spatialDepthwiseConvolutionAccGradParameters<real, accreal, unsigned int><<<grid, block, smem, THCState_getCurrentStream(state)>>>(
+ dGradOutput, dInput, dGradWeight, batchSize, inputChannels, outputChannels, depthwiseMultiplier, n,
+ width, height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+
+ THCudaCheck(cudaGetLastError());
+}
+
+#endif
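
Similarly, a scalar model of one `gradWeight` entry from the `accGradParameters` launch above: each CUDA block owns one `(channel, kH, kW)` position and reduces over `batchSize * outputHeight * outputWidth` products (illustration only, nested lists standing in for tensors):

```python
# One gradWeight entry: reduce input * gradOutput products over the batch and
# all output positions, the same sum each CUDA block accumulates.
def grad_weight_entry(inp, grad_out, ch, kh, kw, stride, pad, dilation, multiplier):
    out_h, out_w = len(grad_out[0][0]), len(grad_out[0][0][0])
    in_h, in_w = len(inp[0][0]), len(inp[0][0][0])
    in_ch = ch // multiplier             # input channel feeding filter ch
    grad = 0.0
    for n in range(len(grad_out)):
        for oh in range(out_h):
            for ow in range(out_w):
                ih = oh * stride[0] + kh * dilation[0] - pad[0]
                iw = ow * stride[1] + kw * dilation[1] - pad[1]
                if 0 <= ih < in_h and 0 <= iw < in_w:
                    grad += inp[n][in_ch][ih][iw] * grad_out[n][ch][oh][ow]
    return grad
```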
diff --git a/torch/lib/THCUNN/generic/THCUNN.h b/torch/lib/THCUNN/generic/THCUNN.h
index cabc322..aa1842f 100644
--- a/torch/lib/THCUNN/generic/THCUNN.h
+++ b/torch/lib/THCUNN/generic/THCUNN.h
@@ -664,43 +664,37 @@
int padW, int padH,
accreal scale);
-TH_API void THNN_(SpatialDepthWiseConvolution_updateOutput)(
+TH_API void THNN_(SpatialDepthwiseConvolution_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
THCTensor *weight,
THCTensor *bias, // [OPTIONAL]
- THCTensor *columns,
- THCTensor *ones,
int kW, int kH,
int dW, int dH,
- int padW, int padH);
+ int padW, int padH,
+ int dilationW, int dilationH);
-TH_API void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
+TH_API void THNN_(SpatialDepthwiseConvolution_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *weight,
- THCTensor *columns,
- THCTensor *ones,
int kW, int kH,
int dW, int dH,
- int padW, int padH);
+ int padW, int padH,
+ int dilationW, int dilationH);
-TH_API void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
+TH_API void THNN_(SpatialDepthwiseConvolution_accGradParameters)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradWeight,
- THCTensor *gradBias, // [OPTIONAL]
- THCTensor *columns,
- THCTensor *ones,
int kW, int kH,
int dW, int dH,
int padW, int padH,
- accreal scale);
-
+ int dilationW, int dilationH);
TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
THCState *state,
diff --git a/torch/lib/THNN/generic/SpatialDepthWiseConvolution.c b/torch/lib/THNN/generic/SpatialDepthWiseConvolution.c
deleted file mode 100644
index 85a9666..0000000
--- a/torch/lib/THNN/generic/SpatialDepthWiseConvolution.c
+++ /dev/null
@@ -1,528 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "generic/SpatialDepthWiseConvolution.c"
-#else
-
-static inline void THNN_(SpatialDepthWiseConvolution_shapeCheck)(
- THTensor *input, THTensor *gradOutput,
- THTensor *weight, THTensor *bias,
- int kH, int kW, int dH, int dW, int padH, int padW) {
-
- THArgCheck(kW > 0 && kH > 0, 9,
- "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
- THArgCheck(dW > 0 && dH > 0, 11,
- "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
- THNN_ARGCHECK(weight->nDimension == 4, 5, weight,
- "2D or 4D weight tensor expected, but got: %s");
-
- if (bias != NULL) {
- THNN_CHECK_DIM_SIZE(bias, 2, 0, weight->size[0]);
- THNN_CHECK_DIM_SIZE(bias, 2, 1, weight->size[1]);
- }
-
- int ndim = input->nDimension;
- int dimf = 0;
- int dimh = 1;
- int dimw = 2;
-
- if (ndim == 4) {
- dimf++;
- dimh++;
- dimw++;
- }
-
- THNN_ARGCHECK(ndim == 3 || ndim == 4, 2, input,
- "3D or 4D input tensor expected but got: %s");
-
- int64_t nInputPlane = weight->size[1];
- int64_t inputHeight = input->size[dimh];
- int64_t inputWidth = input->size[dimw];
- int64_t nOutputPlane = weight->size[0];
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
-
- if (outputWidth < 1 || outputHeight < 1)
- THError("Given input size: (%d x %d x %d). "
- "Calculated output size: (%d x %d x %d). Output size is too small",
- nInputPlane,inputHeight,inputWidth,nOutputPlane*nInputPlane,outputHeight,outputWidth);
-
- THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane);
-
- if (gradOutput != NULL) {
- THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimf, nInputPlane);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimh, nOutputPlane);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimw, outputHeight);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim + 1, dimw + 1, outputWidth);
- }
-}
-
-static void THNN_(SpatialDepthWiseConvolution_updateOutput_frame)(
- THTensor *input,
- THTensor *output,
- THTensor *weight,
- THTensor *bias,
- THTensor *finput,
- int kW,
- int kH,
- int dW,
- int dH,
- int padW,
- int padH,
- int64_t nInputPlane,
- int64_t inputWidth,
- int64_t inputHeight,
- int64_t nOutputPlane,
- int64_t outputWidth,
- int64_t outputHeight)
-{
- int64_t i;
- THTensor *output2d;
-
- THNN_(unfolded_copy)(finput, input, kW, kH, dW, dH, padW, padH,
- nInputPlane, inputWidth, inputHeight,
- outputWidth, outputHeight);
-
- output2d = THTensor_(newWithStorage2d)(output->storage, output->storageOffset,
- nOutputPlane, -1,
- outputHeight*outputWidth, -1);
- if (bias) {
- for(i = 0; i < nOutputPlane; i++)
- THVector_(fill)
- (output->storage->data + output->storageOffset + output->stride[0] * i,
- THTensor_(get1d)(bias, i), outputHeight*outputWidth);
- } else {
- THTensor_(zero)(output);
- }
-
- THTensor_(addmm)(output2d, 1, output2d, 1, weight, finput);
-
- THTensor_(free)(output2d);
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateOutput)(
- THNNState *state,
- THTensor *input,
- THTensor *output,
- THTensor *weight,
- THTensor *bias,
- THTensor *finput,
- THTensor *fgradInput,
- int kW,
- int kH,
- int dW,
- int dH,
- int padW,
- int padH)
-{
- int64_t nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
- int64_t nOutputPlane = weight->size[0];
- if (weight->nDimension == 2) {
- THTensor_(resize4d)(weight, nOutputPlane, nInputPlane, kH, kW);
- }
-
- THNN_(SpatialDepthWiseConvolution_shapeCheck)
- (input, NULL, weight, bias, kH, kW, dH, dW, padH, padW);
-
- THTensor *_weight = THTensor_(newTranspose)(weight, 0, 1);
- weight = THTensor_(newContiguous)(_weight);
-
- THTensor *_bias = NULL;
- if(bias) {
- _bias = THTensor_(newTranspose)(bias, 0, 1);
- bias = THTensor_(newContiguous)(_bias);
- }
-
- // resize weight
- int64_t s1 = weight->size[0];
- int64_t s2 = weight->size[1];
- int64_t s3 = weight->size[2] * weight->size[3];
- weight = THTensor_(newWithStorage3d)(weight->storage, weight->storageOffset,
- s1, -1, s2, -1, s3, -1);
-
- input = THTensor_(newContiguous)(input);
-
- int ndim = input->nDimension;
-
- int batch = 1;
- if (ndim == 3) {
- // Force batch
- batch = 0;
- THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
- }
-
- int64_t inputHeight = input->size[3];
- int64_t inputWidth = input->size[2];
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
-
- int64_t T = input->size[0];
- int64_t t;
-
- THTensor_(resize5d)(output, T, nInputPlane, nOutputPlane, outputHeight, outputWidth);
- THTensor_(resize4d)(finput, T, nInputPlane, kW*kH*1, outputHeight*outputWidth);
-
-#pragma omp parallel for private(t)
- for(t = 0; t < T; t++)
- {
- THTensor *input_t = THTensor_(newSelect)(input, 0, t);
- THTensor *output_t = THTensor_(newSelect)(output, 0, t);
- THTensor *finput_t = THTensor_(newSelect)(finput, 0, t);
-
- int64_t i;
-#pragma omp parallel for private(i)
- for(i = 0; i < nInputPlane; i++)
- {
- THTensor *weight_i = THTensor_(newSelect)(weight, 0, i);
- THTensor *input_i = THTensor_(newNarrow)(input_t, 0, i, 1);
- THTensor *output_i = THTensor_(newSelect)(output_t, 0, i);
- THTensor *finput_i = THTensor_(newSelect)(finput_t, 0, i);
- THTensor *bias_i = NULL;
- if(bias) {
- bias_i = THTensor_(newSelect)(bias, 0, i);
- }
- THNN_(SpatialDepthWiseConvolution_updateOutput_frame)
- (input_i, output_i, weight_i, bias_i, finput_i,
- kW, kH, dW, dH, padW, padH,
- 1, inputWidth, inputHeight,
- nOutputPlane, outputWidth, outputHeight);
-
- THTensor_(free)(input_i);
- THTensor_(free)(weight_i);
- THTensor_(free)(bias_i);
- THTensor_(free)(output_i);
- THTensor_(free)(finput_i);
- }
- THTensor_(free)(input_t);
- THTensor_(free)(output_t);
- THTensor_(free)(finput_t);
- }
-
- THTensor_(free)(weight);
- THTensor_(free)(_weight);
- THTensor_(free)(bias);
- THTensor_(free)(_bias);
- THTensor_(resize4d)(output, T, nInputPlane * nOutputPlane, outputHeight, outputWidth);
-
- if (batch == 0) {
- THTensor_(select)(output, NULL, 0, 0);
- THTensor_(select)(input, NULL, 0, 0);
- THTensor_(select)(finput, NULL, 0, 0);
- }
- THTensor_(free)(input);
-}
-
-static void THNN_(SpatialDepthWiseConvolution_updateGradInput_frame)(
- THTensor *gradInput,
- THTensor *gradOutput,
- THTensor *weight,
- THTensor *fgradInput,
- int kW,
- int kH,
- int dW,
- int dH,
- int padW,
- int padH)
-{
- THTensor *gradOutput2d = THTensor_(newWithStorage2d)
- (gradOutput->storage, gradOutput->storageOffset,
- gradOutput->size[0], -1,
- gradOutput->size[1]*gradOutput->size[2], -1);
- THTensor_(addmm)(fgradInput, 0, fgradInput, 1, weight, gradOutput2d);
- THTensor_(free)(gradOutput2d);
-
- THTensor_(zero)(gradInput);
-
- THNN_(unfolded_acc)(fgradInput, gradInput, kW, kH, dW, dH,
- padW, padH,
- gradInput->size[0], gradInput->size[2], gradInput->size[1],
- gradOutput->size[2], gradOutput->size[1]);
-}
-
-void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- THTensor *weight,
- THTensor *finput,
- THTensor *fgradInput,
- int kW,
- int kH,
- int dW,
- int dH,
- int padW,
- int padH)
-{
- int64_t nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kH*kW) : weight->size[1];
- int64_t nOutputPlane = weight->size[0];
- if (weight->nDimension == 2) {
- THTensor_(resize4d)(weight, nOutputPlane, nInputPlane, kH, kW);
- }
- gradOutput = THTensor_(newWithTensor)(gradOutput);
-
- if (input->nDimension == 3) {
- if (gradOutput->nDimension == 3) {
- THTensor_(resize4d)(gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
- }
- }
- else
- {
- if (gradOutput->nDimension == 4) {
- THTensor_(resize5d)(gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
- }
- }
-
- THNN_(SpatialDepthWiseConvolution_shapeCheck)
- (input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW);
-
- THTensor *_weight = THTensor_(newTranspose)(weight, 0, 1);
- weight = THTensor_(newContiguous)(_weight);
-
- // view the transposed weight as 3D: nInputPlane x nOutputPlane x (kH*kW)
- int64_t s1 = weight->size[0];
- int64_t s2 = weight->size[1];
- int64_t s3 = weight->size[2] * weight->size[3];
- weight = THTensor_(newWithStorage3d)(weight->storage, weight->storageOffset,
- s1, -1, s2, -1, s3, -1);
-
- input = THTensor_(newContiguous)(input);
-
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
- THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- int64_t inputHeight = input->size[2];
- int64_t inputWidth = input->size[3];
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
-
- int64_t T = input->size[0];
- int64_t t;
-
- THTensor_(resizeAs)(gradInput, input);
- THTensor_(resize4d)(fgradInput, T, nInputPlane, kW*kH*1, outputHeight*outputWidth);
-
- // Depending on the BLAS library, fgradInput (the result tensor) might
- // be left uninitialized when alpha is zero, which can lead to weird
- // behavior; to be safe, zero it explicitly.
- THTensor_(zero)(fgradInput);
-
-#pragma omp parallel for private(t)
- for(t = 0; t < T; t++)
- {
- THTensor *gradInput_t = THTensor_(newSelect)(gradInput, 0, t);
- THTensor *gradOutput_t = THTensor_(newSelect)(gradOutput, 0, t);
- THTensor *fgradInput_t = THTensor_(newSelect)(fgradInput, 0, t);
-
- int64_t i;
-#pragma omp parallel for private(i)
- for(i = 0; i < nInputPlane; i++)
- {
- THTensor *weight_i = THTensor_(newSelect)(weight, 0, i);
- THTensor *gradInput_i = THTensor_(newNarrow)(gradInput_t, 0, i, 1);
- THTensor *gradOutput_i = THTensor_(newSelect)(gradOutput_t, 0, i);
- THTensor *fgradInput_i = THTensor_(newSelect)(fgradInput_t, 0, i);
-
- THTensor_(transpose)(weight_i, weight_i, 0, 1);
-
- THNN_(SpatialDepthWiseConvolution_updateGradInput_frame)(gradInput_i, gradOutput_i,
- weight_i, fgradInput_i,
- kW, kH, dW, dH, padW, padH);
-
- THTensor_(free)(gradInput_i);
- THTensor_(free)(weight_i);
- THTensor_(free)(gradOutput_i);
- THTensor_(free)(fgradInput_i);
- }
-
- THTensor_(free)(gradInput_t);
- THTensor_(free)(gradOutput_t);
- THTensor_(free)(fgradInput_t);
- }
-
- if (batch == 0) {
- THTensor_(select)(gradOutput, NULL, 0, 0);
- THTensor_(select)(input, NULL, 0, 0);
- THTensor_(select)(gradInput, NULL, 0, 0);
- THTensor_(select)(fgradInput, NULL, 0, 0);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(gradOutput);
- THTensor_(free)(weight);
- THTensor_(free)(_weight);
-}
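
Both backward entry points begin by re-exposing the depthwise structure: the resize5d calls above view the flattened (T, nInputPlane * nOutputPlane, oH, oW) gradient as 5D. In PyTorch terms, with illustrative names:

    # filter j of input plane c lives at flat channel index c * m + j
    go5d = grad_output.view(T, in_planes, m, oH, oW)
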
-
-static void THNN_(SpatialDepthWiseConvolution_accGradParameters_frame)(
- THTensor *gradOutput,
- THTensor *gradWeight,
- THTensor *gradBias,
- THTensor *finput,
- accreal scale)
-{
- int64_t i;
- THTensor *gradOutput2d = THTensor_(newWithStorage2d)
- (gradOutput->storage, gradOutput->storageOffset,
- gradOutput->size[0], -1,
- gradOutput->size[1]*gradOutput->size[2], -1);
-
- THTensor_(transpose)(finput, finput, 0, 1);
- THTensor_(addmm)(gradWeight, 1, gradWeight, scale, gradOutput2d, finput);
- THTensor_(transpose)(finput, finput, 0, 1);
-
- if (gradBias) {
- for(i = 0; i < gradBias->size[0]; i++)
- {
- int64_t k;
- real sum = 0;
- real *data = gradOutput2d->storage->data + gradOutput2d->storageOffset + i*gradOutput2d->stride[0];
- for(k = 0; k < gradOutput2d->size[1]; k++)
- sum += data[k];
- (gradBias->storage->data + gradBias->storageOffset)[i] += scale*sum;
- }
- }
-
- THTensor_(free)(gradOutput2d);
-}
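
The parameter-gradient frame is likewise one GEMM plus a spatial row-sum for the bias. Sketched with the same shapes, assuming PyTorch and illustrative names:

    import torch

    def acc_grad_params_frame(go2d, cols, grad_w2d, grad_b, scale):
        # go2d: (m, oH * oW); cols: (kH * kW, oH * oW) unfolded input
        grad_w2d += scale * (go2d @ cols.t())  # the scaled addmm above
        if grad_b is not None:
            grad_b += scale * go2d.sum(dim=1)  # per-filter spatial sum
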
-
-void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradWeight,
- THTensor *gradBias,
- THTensor *finput,
- THTensor *fgradInput,
- int kW,
- int kH,
- int dW,
- int dH,
- int padW,
- int padH,
- accreal scale)
-{
- int64_t nInputPlane = gradWeight->nDimension == 2 ? gradWeight->size[1]/(kH*kW) : gradWeight->size[1];
- int64_t nOutputPlane = gradWeight->size[0];
- if (gradWeight->nDimension == 2) {
- THTensor_(resize4d)(gradWeight, nOutputPlane, nInputPlane, kH, kW);
- }
-
- gradOutput = THTensor_(newWithTensor)(gradOutput);
- if (input->nDimension == 3) {
- if (gradOutput->nDimension == 3) {
- THTensor_(resize4d)(gradOutput, nInputPlane, nOutputPlane, gradOutput->size[1], gradOutput->size[2]);
- }
- }
- else
- {
- if (gradOutput->nDimension == 4) {
- THTensor_(resize5d)(gradOutput, gradOutput->size[0], nInputPlane, nOutputPlane, gradOutput->size[2], gradOutput->size[3]);
- }
- }
-
- THNN_(SpatialDepthWiseConvolution_shapeCheck)
- (input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW);
-
- // Transpose gradWeight & gradBias
- THTensor_(transpose)(gradWeight, NULL, 0, 1);
- THTensor *_gradWeight;
- _gradWeight = gradWeight;
- gradWeight = THTensor_(newContiguous)(gradWeight);
-
- THTensor *_gradBias = NULL;
- if(gradBias) {
- THTensor_(transpose)(gradBias, NULL, 0, 1);
- _gradBias = gradBias;
- gradBias = THTensor_(newContiguous)(gradBias);
- }
-
- // view the transposed gradWeight as 3D: nInputPlane x nOutputPlane x (kH*kW)
- int64_t s1 = gradWeight->size[0];
- int64_t s2 = gradWeight->size[1];
- int64_t s3 = gradWeight->size[2] * gradWeight->size[3];
- gradWeight = THTensor_(newWithStorage3d)(gradWeight->storage, gradWeight->storageOffset,
- s1, -1, s2, -1, s3, -1);
-
- input = THTensor_(newContiguous)(input);
-
- int batch = 1;
- if (input->nDimension == 3) {
- // Force batch
- batch = 0;
- THTensor_(resize4d)(input, 1, input->size[0], input->size[1], input->size[2]);
- THTensor_(resize5d)(gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]);
- }
-
- int64_t inputHeight = input->size[2];
- int64_t inputWidth = input->size[3];
- int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
- int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
-
- int64_t T = input->size[0];
- int64_t t;
- THTensor_(resize4d)(finput, T, nInputPlane, kW*kH*1, outputHeight*outputWidth);
-
- for(t = 0; t < T; t++)
- {
- THTensor *gradOutput_t = THTensor_(newSelect)(gradOutput, 0, t);
- THTensor *finput_t = THTensor_(newSelect)(finput, 0, t);
- int64_t i;
-#pragma omp parallel for private(i)
- for(i = 0; i < nInputPlane; i++)
- {
- THTensor *finput_i = THTensor_(newSelect)(finput_t, 0, i);
- THTensor *gradOutput_i = THTensor_(newSelect)(gradOutput_t, 0, i);
- THTensor *gradWeight_i = THTensor_(newSelect)(gradWeight, 0, i);
- THTensor *gradBias_i = NULL;
- if(gradBias) {
- gradBias_i = THTensor_(newSelect)(gradBias, 0, i);
- }
- THNN_(SpatialDepthWiseConvolution_accGradParameters_frame)(gradOutput_i, gradWeight_i,
- gradBias_i, finput_i, scale);
-
- THTensor_(free)(finput_i);
- THTensor_(free)(gradOutput_i);
- THTensor_(free)(gradWeight_i);
- THTensor_(free)(gradBias_i);
- }
-
- THTensor_(free)(gradOutput_t);
- THTensor_(free)(finput_t);
- }
-
- // Copy back and transpose back
- THTensor_(transpose)(_gradWeight, NULL, 0, 1);
- THTensor_(resize4d)(_gradWeight, nInputPlane, nOutputPlane, kH, kW);
- THTensor_(copy)(_gradWeight, gradWeight);
- THTensor_(transpose)(_gradWeight, NULL, 0, 1);
-
- if(gradBias) {
- THTensor_(transpose)(_gradBias, NULL, 0, 1);
- THTensor_(resize2d)(_gradBias, nInputPlane, nOutputPlane);
- THTensor_(copy)(_gradBias, gradBias);
- THTensor_(transpose)(_gradBias, NULL, 0, 1);
- }
-
- if (batch == 0) {
- THTensor_(select)(gradOutput, NULL, 0, 0);
- THTensor_(select)(input, NULL, 0, 0);
- THTensor_(select)(finput, NULL, 0, 0);
- }
-
- THTensor_(free)(input);
- THTensor_(free)(gradOutput);
- THTensor_(free)(gradWeight);
- THTensor_(free)(gradBias);
-}
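
Because the public parameter layout is (nOutputPlane, nInputPlane, kH, kW) while the per-plane loop wants plane-major storage, the function accumulates into a transposed contiguous copy and restores the layout afterwards, which is what the "Copy back and transpose back" block does. Roughly, in PyTorch terms with illustrative names:

    gw = grad_weight.transpose(0, 1).contiguous()  # (in_planes, m, kH, kW)
    gw3d = gw.view(in_planes, m, kH * kW)          # per-plane GEMM target
    # ... the loop accumulates into gw3d[c] for each input plane c ...
    grad_weight.copy_(gw.transpose(0, 1))          # back to public layout
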
-
-#endif
diff --git a/torch/lib/THNN/generic/THNN.h b/torch/lib/THNN/generic/THNN.h
index 0cf9c69..dbbf5f1 100644
--- a/torch/lib/THNN/generic/THNN.h
+++ b/torch/lib/THNN/generic/THNN.h
@@ -839,41 +839,6 @@
int padW, int padH,
accreal scale);
-TH_API void THNN_(SpatialDepthWiseConvolution_updateOutput)(
- THNNState *state,
- THTensor *input,
- THTensor *output,
- THTensor *weight,
- THTensor *bias, // [OPTIONAL]
- THTensor *finput,
- THTensor *fgradInput,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH);
-TH_API void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- THTensor *weight,
- THTensor *finput,
- THTensor *fgradInput,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH);
-TH_API void THNN_(SpatialDepthWiseConvolution_accGradParameters)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradWeight,
- THTensor *gradBias, // [OPTIONAL]
- THTensor *finput,
- THTensor *fgradInput,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH,
- accreal scale);
-
TH_API void THNN_(SpatialConvolutionLocal_updateOutput)(
THNNState *state,
THTensor *input,
diff --git a/torch/lib/THNN/init.c b/torch/lib/THNN/init.c
index 326e630..f4093df 100644
--- a/torch/lib/THNN/init.c
+++ b/torch/lib/THNN/init.c
@@ -200,9 +200,6 @@
#include "generic/SpatialConvolutionMM.c"
#include "THGenerateFloatTypes.h"
-#include "generic/SpatialDepthWiseConvolution.c"
-#include "THGenerateFloatTypes.h"
-
#include "generic/SpatialConvolutionLocal.c"
#include "THGenerateFloatTypes.h"