blob: 5b2f0e6e2339a113312d1d42b337f8175df3ee34 [file] [log] [blame]
#include "THCUNN.h"
#include "common.h"
#include "im2col.h"
/*
 * Forward pass of a 2D dilated convolution on CUDA float tensors.
 *
 * For every sample, the input is unfolded with im2col into `columns`, then
 * output = weight * columns is accumulated with one SGEMM. When `bias` is
 * given it is first broadcast over every output location (a GEMM against the
 * `ones` buffer); otherwise the per-sample output is zeroed first.
 *
 * Output size per spatial dimension:
 *   out = (in + 2*pad - (dilation*(k-1) + 1)) / stride + 1
 *
 * input   : 3D (nInputPlane x H x W) or 4D (batch x nInputPlane x H x W).
 *           A 3D input is temporarily viewed as a batch of one and restored
 *           before returning.
 * output  : resized to (batch x) nOutputPlane x outH x outW.
 * weight  : 4D (nOutputPlane x nInputPlane x kH x kW); its data pointer is
 *           handed to SGEMM as an nOutputPlane x (nInputPlane*kH*kW) matrix.
 * bias    : optional (NULL allowed); must have nOutputPlane entries.
 * columns : scratch, resized to (nInputPlane*kW*kH) x (outH*outW).
 * ones    : scratch of ones, grown (never shrunk) to >= outH*outW elements.
 *
 * All shape checks run before any tensor is resized, so a rejected call
 * leaves `input` with its original shape.
 *
 * NOTE(review): raw data pointers of input/weight/bias are passed directly to
 * im2col and THCudaBlas_Sgemm, so these presumably must be contiguous — confirm.
 */
void THNN_CudaSpatialDilatedConvolution_updateOutput(THCState *state,
    THCudaTensor *input, THCudaTensor *output, THCudaTensor *weight,
    THCudaTensor *bias, THCudaTensor *columns,
    THCudaTensor *ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH) {
  THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones);
  if (bias) {
    THCUNN_assertSameGPU(state, 2, weight, bias);
  }
  THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
  THArgCheck(weight->nDimension == 4, 4, "weight tensor must be 4D (nOutputPlane,nInputPlane,kH,kW)");
  THArgCheck(!bias || weight->size[0] == bias->size[0], 4, "nOutputPlane mismatch in weight and bias");
  THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero");
  THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero");
  THArgCheck(dilationW > 0 && dilationH > 0, 14, "dilation should be greater than zero");

  // Params:
  int nInputPlane = weight->size[1];
  int nOutputPlane = weight->size[0];
  // batch == 0 means a single 3D sample; dimf is the index of the channel dim.
  int batch = (input->nDimension == 4);
  int dimf = batch ? 1 : 0;
  THArgCheck(input->size[dimf] == nInputPlane, 2, "input channels and nInputPlane dont match");

  long inputHeight = input->size[dimf + 1];
  long inputWidth = input->size[dimf + 2];
  long outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  // The spatial sizes are long, so they need %ld (a %d/long mismatch is UB).
  // Checked before any resize so a failure does not leave `input` reshaped.
  if (outputWidth < 1 || outputHeight < 1)
    THError("Given input size: (%dx%ldx%ld). Calculated output size: (%dx%ldx%ld). Output size is too small",
            nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, outputWidth);

  if (!batch) {
    // Force batch: view the single sample as a batch of one (undone at the end).
    THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], input->size[2]);
  }

  // Batch size + input planes
  long batchSize = input->size[0];

  // Resize output
  THCudaTensor_resize4d(state, output, batchSize, nOutputPlane, outputHeight, outputWidth);

  // Resize temporary columns
  THCudaTensor_resize2d(state, columns, nInputPlane*kW*kH, outputHeight*outputWidth);

  // Define a buffer of ones, for bias accumulation.
  // Note: this buffer can be shared with other modules, it only ever gets increased,
  // and always contains ones.
  if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
    // Resize plane and fill with ones...
    THCudaTensor_resize2d(state, ones, outputHeight, outputWidth);
    THCudaTensor_fill(state, ones, 1);
  }

  // Helpers: per-sample views, re-pointed by select() on every iteration.
  THCudaTensor *input_n = THCudaTensor_new(state);
  THCudaTensor *output_n = THCudaTensor_new(state);

  // For each elt in batch, do:
  for (int elt = 0; elt < batchSize; elt ++) {
    THCudaTensor_select(state, input_n, input, 0, elt);
    THCudaTensor_select(state, output_n, output, 0, elt);

    // Do bias first. GEMM is column-major, so the row-major
    // nOutputPlane x (outH*outW) output is addressed as its transpose:
    // output_n[o][p] = ones[p] * bias[o]  (beta = 0 overwrites output_n).
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    long m_ = nOutputPlane;
    long n_ = outputHeight * outputWidth;
    long k_ = 1;
    if (bias) {
      THCudaBlas_Sgemm(
          state,
          't', 'n',
          n_, m_, k_,
          1,
          THCudaTensor_data(state, ones), k_,
          THCudaTensor_data(state, bias), k_,
          0,
          THCudaTensor_data(state, output_n), n_
      );
    } else {
      THCudaTensor_zero(state, output_n);
    }

    // Extract columns:
    im2col(
        THCState_getCurrentStream(state),
        THCudaTensor_data(state, input_n),
        nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
        dilationH, dilationW,
        THCudaTensor_data(state, columns)
    );

    // output_n += weight * columns, expressed in column-major terms;
    // beta = 1 keeps the bias (or zeros) written above.
    long m = nOutputPlane;
    long n = columns->size[1];
    long k = nInputPlane*kH*kW;
    THCudaBlas_Sgemm(
        state,
        'n', 'n',
        n, m, k,
        1,
        THCudaTensor_data(state, columns), n,
        THCudaTensor_data(state, weight), k,
        1,
        THCudaTensor_data(state, output_n), n
    );
  }

  // Free
  THCudaTensor_free(state, input_n);
  THCudaTensor_free(state, output_n);

  // Undo the forced batch dimension on both output and input.
  if (batch == 0) {
    THCudaTensor_resize3d(state, output, nOutputPlane, outputHeight, outputWidth);
    THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
  }
}
/*
 * Backward pass w.r.t. the input of a 2D dilated convolution.
 *
 * For every sample, gradColumns = weight^T * gradOutput is computed with one
 * SGEMM (beta = 0, so gradColumns is overwritten) and then folded back into
 * the per-sample gradInput with col2im.
 *
 * input       : 3D or 4D; only its shape is used here. A 3D input (and the
 *               matching gradOutput) is temporarily viewed as a batch of one
 *               and restored before returning.
 * gradOutput  : same rank as input, nOutputPlane x outH x outW per sample,
 *               with the same batch size as input.
 * gradInput   : resized to input's (batched) shape and filled via col2im.
 * weight      : 4D (nOutputPlane x nInputPlane x kH x kW).
 * gradColumns : scratch, resized to (nInputPlane*kW*kH) x (outH*outW).
 *
 * All shape checks run before any tensor is resized. Without the gradOutput
 * check, a mis-shaped gradOutput would make the SGEMM read past its storage.
 *
 * NOTE(review): gradOutput/weight data pointers go straight to SGEMM, so
 * they are presumably expected to be contiguous — confirm.
 */
void THNN_CudaSpatialDilatedConvolution_updateGradInput(THCState *state,
    THCudaTensor *input, THCudaTensor *gradOutput,
    THCudaTensor *gradInput, THCudaTensor *weight,
    THCudaTensor *gradColumns,
    int kW, int kH, int dW, int dH, int padW, int padH,
    int dilationW, int dilationH ) {
  THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
                       gradColumns, gradInput);
  THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
  THArgCheck(gradOutput->nDimension == input->nDimension, 3, "gradOutput must have the same number of dimensions as input");
  THArgCheck(weight->nDimension == 4, 5, "weight tensor must be 4D (nOutputPlane,nInputPlane,kH,kW)");
  THArgCheck(kW > 0 && kH > 0, 7, "kernel size should be greater than zero");
  THArgCheck(dW > 0 && dH > 0, 9, "stride should be greater than zero");
  THArgCheck(dilationW > 0 && dilationH > 0, 13, "dilation should be greater than zero");

  // Params
  int nInputPlane = weight->size[1];
  int nOutputPlane = weight->size[0];
  // batch == 0 means a single 3D sample; dimf is the index of the channel dim.
  int batch = (input->nDimension == 4);
  int dimf = batch ? 1 : 0;
  THArgCheck(input->size[dimf] == nInputPlane, 2, "input channels and nInputPlane dont match");

  long inputHeight = input->size[dimf + 1];
  long inputWidth = input->size[dimf + 2];
  long outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  if (outputWidth < 1 || outputHeight < 1)
    THError("Given input size: (%dx%ldx%ld). Calculated output size: (%dx%ldx%ld). Output size is too small",
            nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, outputWidth);

  // gradOutput must match exactly what updateOutput would have produced.
  THArgCheck(gradOutput->size[dimf] == nOutputPlane
             && gradOutput->size[dimf + 1] == outputHeight
             && gradOutput->size[dimf + 2] == outputWidth
             && (!batch || gradOutput->size[0] == input->size[0]),
             3, "gradOutput size does not match the calculated output size");

  if (!batch) {
    // Force batch: view the single sample as a batch of one (undone at the end).
    THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], input->size[2]);
    THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
  }

  // Batch size + input planes
  long batchSize = input->size[0];

  // Resize output
  THCudaTensor_resize4d(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth);

  // Resize temporary columns
  THCudaTensor_resize2d(state, gradColumns, nInputPlane*kW*kH, outputHeight*outputWidth);

  // Helpers: per-sample views, re-pointed by select() on every iteration.
  THCudaTensor *gradInput_n = THCudaTensor_new(state);
  THCudaTensor *gradOutput_n = THCudaTensor_new(state);

  // For each elt in batch, do:
  for (int elt = 0; elt < batchSize; elt ++) {
    THCudaTensor_select(state, gradInput_n, gradInput, 0, elt);
    THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt);

    // gradColumns = weight^T * gradOutput_n, expressed in column-major terms.
    // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
    long m = nInputPlane*kW*kH;
    long n = gradColumns->size[1];
    long k = nOutputPlane;
    THCudaBlas_Sgemm(
        state,
        'n', 't',
        n, m, k,
        1,
        THCudaTensor_data(state, gradOutput_n), n,
        THCudaTensor_data(state, weight), m,
        0,
        THCudaTensor_data(state, gradColumns), n
    );

    // Unpack columns back into input:
    col2im(
        THCState_getCurrentStream(state),
        THCudaTensor_data(state, gradColumns),
        nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
        dilationH, dilationW,
        THCudaTensor_data(state, gradInput_n)
    );
  }

  // Free
  THCudaTensor_free(state, gradInput_n);
  THCudaTensor_free(state, gradOutput_n);

  // Undo the forced batch dimension.
  if (batch == 0) {
    THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
    THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
    THCudaTensor_resize3d(state, gradInput, nInputPlane, inputHeight, inputWidth);
  }
}
/*
 * Accumulates the parameter gradients of a 2D dilated convolution. For each
 * sample (summed over the batch):
 *   gradWeight += scale * gradOutput_n * columns^T   (columns = im2col(input_n))
 *   gradBias   += scale * gradOutput_n * ones         (sum over output locations)
 * Both GEMM/GEMV calls use beta = 1, so existing gradient contents are
 * accumulated into, not overwritten.
 *
 * input      : 3D or 4D; a 3D input (and matching gradOutput) is temporarily
 *              viewed as a batch of one and restored before returning.
 * gradOutput : same rank as input, nOutputPlane x outH x outW per sample,
 *              with the same batch size as input.
 * gradWeight : 4D (nOutputPlane x nInputPlane x kH x kW).
 * gradBias   : optional (NULL allowed); nOutputPlane entries.
 * columns    : scratch, resized to (nInputPlane*kW*kH) x (outH*outW).
 * ones       : scratch of ones, grown (never shrunk) to >= outH*outW elements.
 * scale      : multiplier applied to this call's contribution.
 *
 * All shape checks run before any tensor is resized. Without the gradOutput
 * check, a mis-shaped gradOutput would make the GEMM/GEMV read past its storage.
 *
 * NOTE(review): input/gradOutput data pointers go straight to im2col and
 * cuBLAS-style calls, so they are presumably expected to be contiguous — confirm.
 */
void THNN_CudaSpatialDilatedConvolution_accGradParameters(THCState *state,
    THCudaTensor *input, THCudaTensor *gradOutput,
    THCudaTensor *gradWeight, THCudaTensor *gradBias,
    THCudaTensor *columns, THCudaTensor *ones,
    int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, float scale) {
  THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
  if (gradBias) {
    THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
  }
  THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
  THArgCheck(gradOutput->nDimension == input->nDimension, 3, "gradOutput must have the same number of dimensions as input");
  THArgCheck(gradWeight->nDimension == 4, 4, "gradWeight tensor must be 4D (nOutputPlane,nInputPlane,kH,kW)");
  THArgCheck(!gradBias || gradWeight->size[0] == gradBias->size[0], 4, "nOutputPlane mismatch in gradWeight and gradBias");
  THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero");
  THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero");
  THArgCheck(dilationW > 0 && dilationH > 0, 14, "dilation should be greater than zero");

  // Params
  int nInputPlane = gradWeight->size[1];
  int nOutputPlane = gradWeight->size[0];
  // batch == 0 means a single 3D sample; dimf is the index of the channel dim.
  int batch = (input->nDimension == 4);
  int dimf = batch ? 1 : 0;
  THArgCheck(input->size[dimf] == nInputPlane, 2, "input channels and nInputPlane dont match");

  long inputHeight = input->size[dimf + 1];
  long inputWidth = input->size[dimf + 2];
  long outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  if (outputWidth < 1 || outputHeight < 1)
    THError("Given input size: (%dx%ldx%ld). Calculated output size: (%dx%ldx%ld). Output size is too small",
            nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, outputWidth);

  // gradOutput must match exactly what updateOutput would have produced.
  THArgCheck(gradOutput->size[dimf] == nOutputPlane
             && gradOutput->size[dimf + 1] == outputHeight
             && gradOutput->size[dimf + 2] == outputWidth
             && (!batch || gradOutput->size[0] == input->size[0]),
             3, "gradOutput size does not match the calculated output size");

  if (!batch) {
    // Force batch: view the single sample as a batch of one (undone at the end).
    THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], input->size[2]);
    THCudaTensor_resize4d(state, gradOutput, 1, gradOutput->size[0], gradOutput->size[1], gradOutput->size[2]);
  }

  // Batch size + input planes
  long batchSize = input->size[0];

  // Define a buffer of ones, for bias accumulation (only ever grows).
  if (ones->nDimension != 2 || ones->size[0]*ones->size[1] < outputHeight*outputWidth) {
    // Resize plane and fill with ones...
    THCudaTensor_resize2d(state, ones, outputHeight, outputWidth);
    THCudaTensor_fill(state, ones, 1);
  }

  // Resize temporary columns
  THCudaTensor_resize2d(state, columns, nInputPlane*kW*kH, outputHeight*outputWidth);

  // Helpers: per-sample views, re-pointed by select() on every iteration.
  THCudaTensor *input_n = THCudaTensor_new(state);
  THCudaTensor *gradOutput_n = THCudaTensor_new(state);

  // For each elt in batch, do:
  for (int elt = 0; elt < batchSize; elt ++) {
    THCudaTensor_select(state, input_n, input, 0, elt);
    THCudaTensor_select(state, gradOutput_n, gradOutput, 0, elt);

    // Extract columns:
    im2col(
        THCState_getCurrentStream(state),
        THCudaTensor_data(state, input_n),
        nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
        dilationH, dilationW,
        THCudaTensor_data(state, columns)
    );

    // gradWeight += scale * gradOutput_n * columns^T, expressed in
    // column-major terms (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm).
    long m = nOutputPlane;
    long n = nInputPlane*kW*kH;
    long k = columns->size[1];
    THCudaBlas_Sgemm(
        state,
        't', 'n',
        n, m, k,
        scale,
        THCudaTensor_data(state, columns), k,
        THCudaTensor_data(state, gradOutput_n), k,
        1,
        THCudaTensor_data(state, gradWeight), n
    );

    // gradBias += scale * (gradOutput_n summed over all output locations),
    // done as a GEMV against the ones buffer.
    long m_ = nOutputPlane;
    long k_ = outputHeight * outputWidth;
    if (gradBias) {
      THCudaBlas_Sgemv(
          state,
          't',
          k_, m_,
          scale,
          THCudaTensor_data(state, gradOutput_n), k_,
          THCudaTensor_data(state, ones), 1,
          1,
          THCudaTensor_data(state, gradBias), 1
      );
    }
  }

  // Free
  THCudaTensor_free(state, input_n);
  THCudaTensor_free(state, gradOutput_n);

  // Undo the forced batch dimension.
  if (batch == 0) {
    THCudaTensor_resize3d(state, gradOutput, nOutputPlane, outputHeight, outputWidth);
    THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth);
  }
}