#include "THCTensorConv.h"
#include "THCGeneral.h"
#include <stdio.h>

/*
 * Description:
 *   This code provides convolutions and xcorrelations that are API compatible with
 *   the ones in THLabConv.
 *
 * History:
 *   Sept 11, 2011, 11:59PM - Clement Farabet - Optimized RevConv by a good 2x
 *   July 22, 2011, 8:38PM - Clement Farabet - All Valid/Full/XCORR/CONV implemented
 *   July 22, 2011, 4:00PM - Clement Farabet - Rewrote for loop to ensure memory coalescing
 *   July 21, 2011, 11:21PM - Clement Farabet - Creation, based on the conv2d routine
 */

#define CUDA_SHARED_MEM_SIZE (12*1024-32) // usable shared memory per block, in floats (48KB minus a small margin)

/*
 * Description:
 *   base conv2D routine: 3D input, 3D output, 4D kernel
 *
 *   - all chunks of data should be contiguous
 *   - the swapkernel flag can be used to generate a conv2 instead of xcorr2
 *   - the templated kernel size is useful to generate code that's 2x faster
 *     but can be set to 0 to allow arbitrary kernel sizes
 */
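
/*
 * Worked example (illustrative only, not part of the original comments):
 * for a 'valid' xcorr2 with stride 1, input_h = input_w = 16 and
 * kernel_h = kernel_w = 5, the formulas below give
 * output_h = output_w = (16 - 5) / 1 + 1 = 12. kernel_n is expected to be
 * nOutputPlane * nInputPlane, so output_n = kernel_n / input_n recovers the
 * number of output planes; each block (blockIdx.x) handles one output plane.
 */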
template <bool swapkernel, int T_kernel_h, int T_kernel_w>
__global__ void conv2generic(float *input, float *kernel, float *output,
                             int input_n, int input_h, int input_w,
                             int kernel_n, int kernel_h, int kernel_w,
                             int stride_h, int stride_w)
{
  // output dimensions
  int output_h = (input_h - kernel_h) / stride_h + 1;
  int output_w = (input_w - kernel_w) / stride_w + 1;

  // xcorr or conv
  int koffset = swapkernel ? kernel_w*kernel_h-1 : 0;

  // nb outputs
  int output_n = kernel_n / input_n;

  // generate offsets according to block/thread ids
  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y*gridDim.y;

  int oo_start = blockIdx.x;
  int oo_end = oo_start+1;

  int ii_start = (blockIdx.x / output_n) * input_n;
  int ii_end = ii_start + input_n;

  // nb threads, unique thread id
  int tid = blockDim.x*blockDim.y*threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x;
  int nthreads = blockDim.x * blockDim.y * blockDim.z;

  // iterators
  int oo, ii, xx, yy, kx, ky, kk;

  // do the kernels fit in shared mem ?
  if (input_n*kernel_w*kernel_h <= CUDA_SHARED_MEM_SIZE) {
    // put the kernel in shared memory
    __shared__ float shared_kernel[CUDA_SHARED_MEM_SIZE];

    // all threads of the block cooperatively copy the kernels into shared memory
    for (kk = tid; kk < kernel_w*kernel_h*input_n; kk += nthreads) {
      shared_kernel[kk] = kernel[input_n*kernel_w*kernel_h*(oo_start % output_n) + kk];
    }
    __syncthreads();
    // templated kernel size
    if ((T_kernel_w > 0) && (T_kernel_h > 0)) {
      // unrolled convolution loop
      for(oo = oo_start; oo < oo_end; oo++) {
        for(ii = ii_start; ii < ii_end; ii++) {
          for(yy = yy_start; yy < yy_end; yy+=yy_step) {
            for(xx = xx_start; xx < xx_end; xx+=xx_step) {
              // Dot product in two dimensions... (between input image and the mask)
              float *input_p = input + ii*input_h*input_w + yy*stride_h*input_w + xx*stride_w;
              float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
              float *kernel_p = shared_kernel + (ii % input_n)*kernel_w*kernel_h + koffset;
              float sum = 0;
              if (swapkernel) {
#pragma unroll
                for(ky = 0; ky < T_kernel_h; ky++) {
#pragma unroll
                  for(kx = 0; kx < T_kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p--);
                  }
                  input_p += input_w;
                }
              } else {
#pragma unroll
                for(ky = 0; ky < T_kernel_h; ky++) {
#pragma unroll
                  for(kx = 0; kx < T_kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p++);
                  }
                  input_p += input_w;
                }
              }
              *output_p += sum;
            }
          }
        }
      }
    } else {
      // default convolution loop
      for(oo = oo_start; oo < oo_end; oo++) {
        for(ii = ii_start; ii < ii_end; ii++) {
          for(yy = yy_start; yy < yy_end; yy+=yy_step) {
            for(xx = xx_start; xx < xx_end; xx+=xx_step) {
              // Dot product in two dimensions... (between input image and the mask)
              float *input_p = input + ii*input_h*input_w + yy*stride_h*input_w + xx*stride_w;
              float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
              float *kernel_p = shared_kernel + (ii % input_n) * kernel_w * kernel_h + koffset;
              float sum = 0;
              if (swapkernel) {
                for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                  for(kx = 0; kx < kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p--);
                  }
                  input_p += input_w;
                }
              } else {
                for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                  for(kx = 0; kx < kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p++);
                  }
                  input_p += input_w;
                }
              }
              *output_p += sum;
            }
          }
        }
      }
    }
  } else { // not enough shared mem for kernels, simply stream them
    // convolution loop
    for(oo = oo_start; oo < oo_end; oo++) {
      for(ii = ii_start; ii < ii_end; ii++) {
        for(yy = yy_start; yy < yy_end; yy+=yy_step) {
          for(xx = xx_start; xx < xx_end; xx+=xx_step) {
            // Dot product in two dimensions... (between input image and the mask)
            float *input_p = input + ii*input_h*input_w + yy*stride_h*input_w + xx*stride_w;
            float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
            float *kernel_p = kernel + ((oo % output_n) * input_n + (ii % input_n))*kernel_w*kernel_h + koffset;
            float sum = 0;
            if (swapkernel) {
              for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                for(kx = 0; kx < kernel_w; kx++) {
                  sum += input_p[kx]*(*kernel_p--);
                }
                input_p += input_w;
              }
            } else {
              for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                for(kx = 0; kx < kernel_w; kx++) {
                  sum += input_p[kx]*(*kernel_p++);
                }
                input_p += input_w;
              }
            }
            *output_p += sum;
          }
        }
      }
    }
  }
}

/*
 * Description:
 *   base conv2D routine with reversed stride: 3D input, 4D output, 3D kernel
 *   this is useful for computing gradients with respect to kernels, where:
 *   input=input, kernel=gradOutput, output=gradWeight
 *
 *   - all chunks of data should be contiguous
 *   - the swapkernel flag can be used to generate a conv2 instead of xcorr2
 */
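
/*
 * Worked example (illustrative only): when computing gradWeight for a 5x5
 * 'valid' forward convolution with stride 1 on a 16x16 input, the gradOutput
 * passed in as "kernel" is 12x12, and the formulas below give
 * output_h = output_w = 16 - (12 - 1) * 1 = 5, i.e. one 5x5 gradient slice
 * per (gradOutput plane, input plane) pair.
 */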
__global__ void conv2genericrev(float *input, float *kernel, float *output,
                                int input_n, int input_h, int input_w,
                                int kernel_n, int kernel_h, int kernel_w,
                                float alpha, int stride_h, int stride_w)
{
  // output dimensions
  int output_h = input_h - (kernel_h - 1) * stride_h;
  int output_w = input_w - (kernel_w - 1) * stride_w;

  // this block only processes one output slice, defined by the block ids
  int kk = blockIdx.x;
  int ii = blockIdx.y;

  // batch id
  int batch = threadIdx.z;

  // kernel id
  int kid = threadIdx.x;
  int nkids = blockDim.x;

  // thread ID
  int tid = kid + batch*blockDim.x;
  int nthreads = blockDim.x * blockDim.z;

  // each block only writes one output slice
  output = output + (kk * input_n + ii) * output_h*output_w;

  // put the output in shared memory
  __shared__ float shared_output[CUDA_SHARED_MEM_SIZE];

  // each (kid, batch) pair accumulates into its own output tile in shared memory
  float *output_s = shared_output + tid*output_w*output_h;

  // convolution loop
  int xx, yy, kx, ky;
  yy = threadIdx.y;
  float *output_p = output_s + yy * output_w;
  for(xx=0; xx<output_w; xx++) {
    // Dot product in two dimensions... (between input image and kernel)
    float *input_p = input + (ii + batch*input_n)*input_h*input_w + yy*stride_h*input_w + xx*stride_w;
    float *kernel_p = kernel + (kk + batch*kernel_n)*kernel_w*kernel_h;
    float sum = 0;
    for(ky=0; ky<kernel_h; ky++) {
      for(kx=kid; kx<kernel_w; kx+=nkids) {
        sum += input_p[kx]*kernel_p[kx];
      }
      input_p += input_w;
      kernel_p += kernel_w;
    }
    *(output_p++) = sum;
  }
  __syncthreads();

  // reduce and write back
  if (yy == 0) {
    // reduce outputs
    for (int k=1; k<nthreads; k++) {
      for (int i=tid; i<output_w*output_h; i+=nthreads) {
        shared_output[i] += shared_output[k*output_h*output_w + i];
      }
    }
    __syncthreads();

    // add existing output, and write back
    for (int i=tid; i<output_w*output_h; i+=nthreads) {
      output[i] += alpha*shared_output[i];
    }
  }
}

// A helper macro for the common pattern of checking the given
// rows/columns against a small set of values, specializing the kernel
// template parameters when rows and columns are equal and small, and
// otherwise passing zero (arbitrary kernel size) to the kernel.
// An example expansion is sketched after the macro definition.
#define FOR_KERNEL_SPECIALIZED_DIMENSION(ROWS, COLUMNS, KERNEL) \
  if ((ROWS) == (COLUMNS)) {            \
    switch ((ROWS)) {                   \
      case 3: { KERNEL(3); break; }     \
      case 4: { KERNEL(4); break; }     \
      case 5: { KERNEL(5); break; }     \
      case 6: { KERNEL(6); break; }     \
      case 7: { KERNEL(7); break; }     \
      case 8: { KERNEL(8); break; }     \
      case 9: { KERNEL(9); break; }     \
      case 10: { KERNEL(10); break; }   \
      case 11: { KERNEL(11); break; }   \
      case 12: { KERNEL(12); break; }   \
      case 13: { KERNEL(13); break; }   \
      default: { KERNEL(0); break; }    \
    }                                   \
  } else {                              \
    KERNEL(0);                          \
  }
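
/*
 * Example expansion (illustrative): with nKernelRows == nKernelCols == 5,
 * FOR_KERNEL_SPECIALIZED_DIMENSION(nKernelRows, nKernelCols, X_CONV_KERNEL)
 * hits "case 5" and emits X_CONV_KERNEL(5), i.e. a conv2generic<false, 5, 5>
 * launch with fully unrolled inner loops; any other (or non-square) size
 * falls through to KERNEL(0), the generic, non-templated path.
 */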

/*
 * API-compatible with THRealTensor_conv2Dmv
 * 3D input, 4D kernel, 3D output
 * matrix vector product like: y <- Ax + beta*y
 */
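
/*
 * Hedged usage sketch (illustrative only; the tensor constructors below are
 * the standard cutorch ones and are assumptions of this example):
 *
 *   THCudaTensor *in  = THCudaTensor_newWithSize3d(3, 16, 16);   // 3 planes, 16x16
 *   THCudaTensor *ker = THCudaTensor_newWithSize4d(8, 3, 5, 5);  // 8 output planes
 *   THCudaTensor *out = THCudaTensor_new();
 *   THCudaTensor_conv2Dmv(out, 0, in, ker, 1, 1, "vx"); // valid xcorr2, stride 1
 *   // out is resized to 8 x 12 x 12, since (16 - 5) / 1 + 1 = 12
 */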
THC_API void THCudaTensor_conv2Dmv(THCudaTensor *output, float beta, THCudaTensor *input,
                                   THCudaTensor *kernel, long srow, long scol, const char *type)
{
  long nInputPlane, nInputRows, nInputCols;
  long nKernelRows, nKernelCols;
  long nOutputPlane, nOutputRows, nOutputCols;

  THArgCheck(kernel->nDimension == 4 , 4, "kernel: 4D Tensor expected");
  THArgCheck(srow >= 1, 5, "Stride should be a positive integer");
  THArgCheck(scol >= 1, 6, "Stride should be a positive integer");
  THArgCheck(type[0] == 'v' || type[0] == 'f', 7, "type of convolution can be 'v' or 'f'");
  THArgCheck(type[1] == 'c' || type[1] == 'x', 7, "type of convolution can be 'x' or 'c'");

  input = THCudaTensor_newContiguous(input);
  kernel = THCudaTensor_newContiguous(kernel);

  nInputPlane = input->size[0];
  nInputRows = input->size[1];
  nInputCols = input->size[2];
  nKernelRows = kernel->size[2];
  nKernelCols = kernel->size[3];
  nOutputPlane = kernel->size[0];
  THArgCheck(kernel->size[1] == nInputPlane, 2, "invalid number of input planes");
  THArgCheck( (nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *type == 'f', 2,
              "conv2Dmv : Input image is smaller than kernel");

  if (*type == 'f') {
    // output dims
    nOutputRows = (nInputRows - 1) * srow + nKernelRows;
    nOutputCols = (nInputCols - 1) * scol + nKernelCols;

    // use temp buffer
    THCudaTensor *inputP = THCudaTensor_new();

    // create a zero-padded input
    long nInputRowsPadded = (nOutputRows - 1) * srow + nKernelRows;
    long nInputColsPadded = (nOutputCols - 1) * scol + nKernelCols;
    THCudaTensor_resize3d(inputP, nInputPlane, nInputRowsPadded, nInputColsPadded);
    THCudaTensor_zero(inputP);

    THCudaTensor *centered = THCudaTensor_new();
    THCudaTensor_narrow(centered, inputP, 2, nKernelCols-1, nInputCols);
    THCudaTensor_narrow(centered, NULL, 1, nKernelRows-1, nInputRows);
    THCudaTensor_copy(centered, input);
    THCudaTensor_free(centered);

    // remap input to newly created tensor
    THCudaTensor_free(input);
    input = inputP;
    nInputRows = nInputRowsPadded;
    nInputCols = nInputColsPadded;
  } else { // 'v'
    // output dims
    nOutputRows = (nInputRows - nKernelRows) / srow + 1;
    nOutputCols = (nInputCols - nKernelCols) / scol + 1;
  }

  long nelem = THCudaTensor_nElement(output);
  THCudaTensor_resize3d(output, nOutputPlane, nOutputRows, nOutputCols);

  if (beta == 0 || nelem != THCudaTensor_nElement(output)) {
    THCudaTensor_zero(output);
  } else if (beta != 1) {
    THCudaTensor_mul(output, output, beta);
  }

  float *input_data = THCudaTensor_data(input);
  float *weight_data = THCudaTensor_data(kernel);
  float *output_data = THCudaTensor_data(output);

  // cuda blocks & threads:
  int yblocks = (int)(16L / nOutputPlane);
  yblocks = yblocks < 1 ? 1 : yblocks;
  dim3 blocks(nOutputPlane,yblocks);
  dim3 threads(32,8);

  // convolution: xcorr2 or conv2
  if (type[1] == 'x') {
#define X_CONV_KERNEL(dim)                                          \
    conv2generic <false, (dim), (dim)> <<<blocks, threads>>> (      \
        input_data, weight_data, output_data,                       \
        nInputPlane, nInputRows, nInputCols,                        \
        nOutputPlane*nInputPlane, nKernelRows, nKernelCols,         \
        srow, scol);
    FOR_KERNEL_SPECIALIZED_DIMENSION(nKernelRows, nKernelCols, X_CONV_KERNEL);
#undef X_CONV_KERNEL
  } else { // 'c'
#define C_CONV_KERNEL(dim)                                          \
    conv2generic <true, (dim), (dim)> <<<blocks, threads>>> (       \
        input_data, weight_data, output_data,                       \
        nInputPlane, nInputRows, nInputCols,                        \
        nOutputPlane*nInputPlane, nKernelRows, nKernelCols,         \
        srow, scol);
    FOR_KERNEL_SPECIALIZED_DIMENSION(nKernelRows, nKernelCols, C_CONV_KERNEL);
#undef C_CONV_KERNEL
  }

  // clean
  if (*type != 'f') THCudaTensor_free(input);
  THCudaTensor_free(kernel);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in conv2Dmv: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
}

/*
 * API-compatible with THRealTensor_conv2Dmm
 * 4D input, 4D kernel, 4D output
 * batched matrix vector product like: y <- Ax + beta*y
 */
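
/*
 * Hedged usage sketch (illustrative only): same shapes as the conv2Dmv
 * example above, but with a leading batch dimension on input and output:
 *
 *   THCudaTensor *in  = THCudaTensor_newWithSize4d(4, 3, 16, 16); // batch of 4
 *   THCudaTensor *ker = THCudaTensor_newWithSize4d(8, 3, 5, 5);
 *   THCudaTensor *out = THCudaTensor_new();
 *   THCudaTensor_conv2Dmm(out, 0, in, ker, 1, 1, "vx");
 *   // out is resized to 4 x 8 x 12 x 12
 */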
THC_API void THCudaTensor_conv2Dmm(THCudaTensor *output, float beta, THCudaTensor *input,
                                   THCudaTensor *kernel, long srow, long scol, const char *type)
{
  long nbatch, nInputPlane, nInputRows, nInputCols;
  long nKernelRows, nKernelCols;
  long nOutputPlane, nOutputRows, nOutputCols;

  THArgCheck(kernel->nDimension == 4 , 4, "kernel: 4D Tensor expected");
  THArgCheck(srow >= 1, 5, "Stride should be a positive integer");
  THArgCheck(scol >= 1, 6, "Stride should be a positive integer");
  THArgCheck(type[0] == 'v' || type[0] == 'f', 7, "type of convolution can be 'v' or 'f'");
  THArgCheck(type[1] == 'c' || type[1] == 'x', 7, "type of convolution can be 'x' or 'c'");

  input = THCudaTensor_newContiguous(input);
  kernel = THCudaTensor_newContiguous(kernel);

  nbatch = input->size[0];
  nInputPlane = input->size[1];
  nInputRows = input->size[2];
  nInputCols = input->size[3];
  nKernelRows = kernel->size[2];
  nKernelCols = kernel->size[3];
  nOutputPlane = kernel->size[0];
  THArgCheck(kernel->size[1] == nInputPlane, 2, "invalid number of input planes");
  THArgCheck( (nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *type == 'f', 2,
              "conv2Dmm : Input image is smaller than kernel");

  if (*type == 'f') {
    // output dims
    nOutputRows = (nInputRows - 1) * srow + nKernelRows;
    nOutputCols = (nInputCols - 1) * scol + nKernelCols;

    // use temp buffer
    THCudaTensor *inputP = THCudaTensor_new();

    // create a zero-padded input
    long nInputRowsPadded = (nOutputRows - 1) * srow + nKernelRows;
    long nInputColsPadded = (nOutputCols - 1) * scol + nKernelCols;
    THCudaTensor_resize4d(inputP, nbatch, nInputPlane, nInputRowsPadded, nInputColsPadded);
    THCudaTensor_zero(inputP);

    THCudaTensor *centered = THCudaTensor_new();
    THCudaTensor_narrow(centered, inputP, 3, nKernelCols-1, nInputCols);
    THCudaTensor_narrow(centered, NULL, 2, nKernelRows-1, nInputRows);
    THCudaTensor_copy(centered, input);
    THCudaTensor_free(centered);

    // remap input to newly created tensor
    THCudaTensor_free(input);
    input = inputP;
    nInputRows = nInputRowsPadded;
    nInputCols = nInputColsPadded;
  } else { // 'v'
    // output dims
    nOutputRows = (nInputRows - nKernelRows) / srow + 1;
    nOutputCols = (nInputCols - nKernelCols) / scol + 1;
  }

  long nelem = THCudaTensor_nElement(output);
  THCudaTensor_resize4d(output, nbatch, nOutputPlane, nOutputRows, nOutputCols);

  if (beta == 0 || nelem != THCudaTensor_nElement(output)) {
    THCudaTensor_zero(output);
  } else if (beta != 1) {
    THCudaTensor_mul(output, output, beta);
  }

  float *input_data = THCudaTensor_data(input);
  float *weight_data = THCudaTensor_data(kernel);
  float *output_data = THCudaTensor_data(output);

  // cuda blocks & threads:
  int yblocks = (int)(16L / nOutputPlane);
  yblocks = yblocks < 1 ? 1 : yblocks;
  dim3 blocks(nOutputPlane*nbatch,yblocks);
  dim3 threads(32,8);

  // convolution: xcorr2 or conv2
  if (type[1] == 'x') {
#define X_CONV_KERNEL(dim)                                          \
    conv2generic <false, (dim), (dim)> <<<blocks, threads>>> (      \
        input_data, weight_data, output_data,                       \
        nInputPlane, nInputRows, nInputCols,                        \
        nOutputPlane*nInputPlane, nKernelRows, nKernelCols,         \
        srow, scol);
    FOR_KERNEL_SPECIALIZED_DIMENSION(nKernelCols, nKernelRows, X_CONV_KERNEL);
#undef X_CONV_KERNEL
  } else { // 'c'
#define C_CONV_KERNEL(dim)                                          \
    conv2generic <true, (dim), (dim)> <<<blocks, threads>>> (       \
        input_data, weight_data, output_data,                       \
        nInputPlane, nInputRows, nInputCols,                        \
        nOutputPlane*nInputPlane, nKernelRows, nKernelCols,         \
        srow, scol);
    FOR_KERNEL_SPECIALIZED_DIMENSION(nKernelCols, nKernelRows, C_CONV_KERNEL);
#undef C_CONV_KERNEL
  }

  // clean
  if (*type != 'f') THCudaTensor_free(input);
  THCudaTensor_free(kernel);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, 0);
    printf("error in conv2Dmm: %s\n", cudaGetErrorString(err));
    printf("requested grid size: %dx%dx%d, max allowed: %dx%dx%d\n",
           blocks.x, blocks.y, blocks.z,
           deviceProp.maxGridSize[0], deviceProp.maxGridSize[1], deviceProp.maxGridSize[2]);
    printf("requested block size: %dx%dx%d, max allowed: %dx%dx%d\n",
           threads.x, threads.y, threads.z,
           deviceProp.maxThreadsDim[0], deviceProp.maxThreadsDim[1], deviceProp.maxThreadsDim[2]);
    THError("aborting");
  }
}

/*
 * API-compatible with THRealTensor_conv2DRevger
 * 3D input, 3D kernel, 4D output
 * like a rank-1 update:
 *   A <- xx' + beta*A
 * for sr,sc=1 this is equivalent to xcorr2Dger, but otherwise it is useful for
 * calculating derivatives wrt a kernel that is applied with stride sr,sc != 1
 */
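
/*
 * Hedged usage sketch (illustrative only): accumulating a weight gradient for
 * the stride-1 forward pass of the conv2Dmv example above, where "kernel" is
 * the gradient w.r.t. the forward output:
 *
 *   // input: 3 x 16 x 16, gradOutput: 8 x 12 x 12
 *   THCudaTensor_conv2DRevger(gradWeight, 1, 1, input, gradOutput, 1, 1);
 *   // gradWeight is resized to 8 x 3 x 5 x 5, since 16 - (12 - 1) * 1 = 5
 */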
THC_API void THCudaTensor_conv2DRevger(THCudaTensor *output, float beta, float alpha,
                                       THCudaTensor *input, THCudaTensor *kernel,
                                       long srow, long scol)
{
  long nInputPlane, nInputRows, nInputCols;
  long nKernelPlane, nKernelRows, nKernelCols;
  long nOutputRows, nOutputCols;

  THArgCheck(input->nDimension == 3 , 3, "input: 3D Tensor expected");
  THArgCheck(kernel->nDimension == 3 , 4, "kernel: 3D Tensor expected");
  THArgCheck(srow >= 1, 5, "Stride should be a positive integer");
  THArgCheck(scol >= 1, 6, "Stride should be a positive integer");

  input = THCudaTensor_newContiguous(input);
  kernel = THCudaTensor_newContiguous(kernel);

  nInputPlane = input->size[0];
  nInputRows = input->size[1];
  nInputCols = input->size[2];
  nKernelPlane = kernel->size[0];
  nKernelRows = kernel->size[1];
  nKernelCols = kernel->size[2];

  THArgCheck(nInputRows >= nKernelRows && nInputCols >= nKernelCols , 2,
             "conv2DRevger : Input image is smaller than kernel");

  nOutputRows = nInputRows - (nKernelRows - 1) * srow;
  nOutputCols = nInputCols - (nKernelCols - 1) * scol;

  long nelem = THCudaTensor_nElement(output);
  THCudaTensor_resize4d(output, nKernelPlane, nInputPlane, nOutputRows, nOutputCols);

  if (nelem == 0 || beta == 0 || nelem != THCudaTensor_nElement(output)) {
    THCudaTensor_zero(output);
  } else if (beta != 1) {
    THCudaTensor_mul(output, output, beta);
  }

  float *input_data = THCudaTensor_data(input);
  float *kernel_data = THCudaTensor_data(kernel);
  float *output_data = THCudaTensor_data(output);

  // auto compute nb of blocks and threads
  dim3 blocks(nKernelPlane, nInputPlane);
  dim3 threads(128/nOutputRows, nOutputRows);

  // compute rev conv
  conv2genericrev <<<blocks, threads>>> (input_data, kernel_data, output_data,
                                         nInputPlane, nInputRows, nInputCols,
                                         nKernelPlane, nKernelRows, nKernelCols,
                                         alpha, srow, scol);

  // clean
  THCudaTensor_free(input);
  THCudaTensor_free(kernel);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in conv2DRevger: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
}

/*
 * API-compatible with THRealTensor_conv2DRevgerm
 * 4D input, 4D kernel, 4D output
 * conv2DRevgerm does the same thing as conv2DRevger, but on batched inputs
 */
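
/*
 * Illustrative note: for a batch of 4, the input would be 4 x 3 x 16 x 16 and
 * the gradOutput ("kernel") 4 x 8 x 12 x 12; the result is still a single
 * 8 x 3 x 5 x 5 accumulator, with the batch dimension reduced inside
 * conv2genericrev (shared-memory reduction) and across the sub-batch loop
 * below.
 */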
THC_API void THCudaTensor_conv2DRevgerm(THCudaTensor *output, float beta, float alpha,
                                        THCudaTensor *input, THCudaTensor *kernel,
                                        long srow, long scol)
{
  long nInputPlane, nInputRows, nInputCols;
  long nKernelPlane, nKernelRows, nKernelCols;
  long nOutputRows, nOutputCols;
  long nbatch;

  THArgCheck(input->nDimension == 4 , 3, "input: 4D Tensor expected");
  THArgCheck(kernel->nDimension == 4 , 4, "kernel: 4D Tensor expected");
  THArgCheck(srow >= 1, 5, "Stride should be a positive integer");
  THArgCheck(scol >= 1, 6, "Stride should be a positive integer");

  input = THCudaTensor_newContiguous(input);
  kernel = THCudaTensor_newContiguous(kernel);

  nbatch = input->size[0];
  nInputPlane = input->size[1];
  nInputRows = input->size[2];
  nInputCols = input->size[3];
  nKernelPlane = kernel->size[1];
  nKernelRows = kernel->size[2];
  nKernelCols = kernel->size[3];

  THArgCheck(nInputRows >= nKernelRows && nInputCols >= nKernelCols , 2,
             "conv2DRevgerm : Input image is smaller than kernel");

  nOutputRows = nInputRows - (nKernelRows - 1) * srow;
  nOutputCols = nInputCols - (nKernelCols - 1) * scol;

  long nelem = THCudaTensor_nElement(output);
  THCudaTensor_resize4d(output, nKernelPlane, nInputPlane, nOutputRows, nOutputCols);

  if (nelem == 0 || beta == 0 || nelem != THCudaTensor_nElement(output)) {
    THCudaTensor_zero(output);
  } else if (beta != 1) {
    THCudaTensor_mul(output, output, beta);
  }

  float *input_data = THCudaTensor_data(input);
  float *kernel_data = THCudaTensor_data(kernel);
  float *output_data = THCudaTensor_data(output);

  // kernel is called multiple times
  // (the arbitrary split below is just here to make sure we don't go over 256 threads)
  for (int sl=0; sl<nbatch; sl+=6) {
    // auto compute nb of blocks and threads
    dim3 blocks(nKernelPlane, nInputPlane);
    int subbatch = 6;
    if (sl+subbatch > nbatch) subbatch = nbatch - sl;
    int cst = 256 / (subbatch * nOutputRows);
    dim3 threads(cst, nOutputRows, subbatch);

    // compute rev conv
    conv2genericrev <<<blocks, threads>>> (input_data + input->stride[0]*sl,
                                           kernel_data + kernel->stride[0]*sl,
                                           output_data,
                                           nInputPlane, nInputRows, nInputCols,
                                           nKernelPlane, nKernelRows, nKernelCols,
                                           alpha, srow, scol);
  }

  // clean
  THCudaTensor_free(input);
  THCudaTensor_free(kernel);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in conv2DRevgerm: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
}

///////////////////////////////////
///// ConvolutionMap
/*
 * Description:
 *   base conv2D routine over a connection table: 3D input, 3D output,
 *   3D kernel (a stack of 2D kernels)
 *
 *   - all chunks of data should be contiguous
 *   - the swapkernel flag can be used to generate a conv2 instead of xcorr2
 *   - the templated kernel size is useful to generate code that's 2x faster
 *     but can be set to 0 to allow arbitrary kernel sizes
 *   - the table's first dimension indexes the outputs; each output has a
 *     contiguous set of fanin (input plane, kernel index) pairs
 */
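
/*
 * Illustrative table layout (as read by the code below): the table is a flat
 * array of 1-based (input plane, kernel index) pairs, fanin pairs per output
 * plane. With fanin = 2, output plane 0 (blockIdx.x == 0) reads table[0..3]:
 *
 *   table = { in0, k0,  in1, k1,   // pairs for output plane 0
 *             in2, k2,  in3, k3,   // pairs for output plane 1
 *             ... };
 */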
template <bool swapkernel, int T_kernel_h, int T_kernel_w>
__global__ void conv2mapgeneric(float *input, float *kernel, float *output,
                                int input_n, int input_h, int input_w,
                                int kernel_n, int kernel_h, int kernel_w,
                                int stride_w, int stride_h,
                                float *table, int fanin)
{
  // output dimensions
  int output_h = (input_h - kernel_h) / stride_h + 1;
  int output_w = (input_w - kernel_w) / stride_w + 1;

  // xcorr or conv
  int koffset = swapkernel ? kernel_w*kernel_h-1 : 0;

  // nb outputs
  // int output_n = kernel_n / fanin;

  // generate offsets according to block/thread ids
  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y*gridDim.y;

  int oo_start = blockIdx.x;
  int oo_end = oo_start+1;

  int table_start = blockIdx.x * (fanin * 2);
  int table_end = table_start + (fanin * 2);

  // nb threads, unique thread id
  int tid = blockDim.x*blockDim.y*threadIdx.z
    + blockDim.x * threadIdx.y + threadIdx.x;
  int nthreads = blockDim.x * blockDim.y * blockDim.z;

  // iterators
  int oo, ii, xx, yy, kx, ky, kk;

  // do the kernels fit in shared mem ?
  if (kernel_w*kernel_h*kernel_n <= CUDA_SHARED_MEM_SIZE) {
    // put the kernel in shared memory
    __shared__ float shared_kernel[CUDA_SHARED_MEM_SIZE];

    // all threads of the block cooperatively copy the kernels into shared memory
    for (kk = tid; kk < kernel_w*kernel_h*kernel_n; kk += nthreads) {
      shared_kernel[kk] = kernel[kk];
    }
    __syncthreads();

    // templated kernel size
    if ((T_kernel_w > 0) && (T_kernel_h > 0)) {
      // unrolled convolution loop
      for(oo = oo_start; oo < oo_end; oo++) {
        for (ii = table_start; ii < table_end; ii = ii + 2) {
          for(yy = yy_start; yy < yy_end; yy+=yy_step) {
            for(xx = xx_start; xx < xx_end; xx+=xx_step) {
              // Dot product in two dimensions... (between input image and the mask)
              float *input_p = input + ((long)table[ii]-1)*input_h*input_w
                + yy*stride_h*input_w + xx*stride_w;
              float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
              float *kernel_p = shared_kernel
                + ((long)table[ii + 1]-1) *kernel_w*kernel_h + koffset;
              float sum = 0;
              if (swapkernel) {
#pragma unroll
                for(ky = 0; ky < T_kernel_h; ky++) {
#pragma unroll
                  for(kx = 0; kx < T_kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p--);
                  }
                  input_p += input_w;
                }
              } else {
#pragma unroll
                for(ky = 0; ky < T_kernel_h; ky++) {
#pragma unroll
                  for(kx = 0; kx < T_kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p++);
                  }
                  input_p += input_w;
                }
              }
              *output_p += sum;
            }
          }
        }
      }
    } else {
      // default convolution loop
      for(oo = oo_start; oo < oo_end; oo++) {
        for (ii = table_start; ii < table_end; ii = ii + 2) {
          for(yy = yy_start; yy < yy_end; yy+=yy_step) {
            for(xx = xx_start; xx < xx_end; xx+=xx_step) {
              // Dot product in two dims (between input image and the mask)
              float *input_p = input + ((long)table[ii]-1)*input_h*input_w
                + yy*stride_h*input_w + xx*stride_w;
              float *output_p = output + oo*output_h*output_w + yy*output_w
                + xx;
              float *kernel_p = shared_kernel
                + ((long)table[ii + 1]-1) *kernel_w*kernel_h + koffset;
              float sum = 0;
              if (swapkernel) {
                for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                  for(kx = 0; kx < kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p--);
                  }
                  input_p += input_w;
                }
              } else {
                for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                  for(kx = 0; kx < kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p++);
                  }
                  input_p += input_w;
                }
              }
              *output_p += sum;
            }
          }
        }
      }
    }
  } else { // not enough shared mem for kernels, simply stream them
    // convolution loop
    for(oo = oo_start; oo < oo_end; oo++) {
      for (ii = table_start; ii < table_end; ii = ii + 2) {
        for(yy = yy_start; yy < yy_end; yy+=yy_step) {
          for(xx = xx_start; xx < xx_end; xx+=xx_step) {
            // Dot product in two dimensions... (between input image and the mask)
            float *input_p = input + ((long)table[ii]-1)*input_h*input_w
              + yy*stride_h*input_w + xx*stride_w;
            float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
            float *kernel_p = kernel + ((long)table[ii + 1]-1) *kernel_w*kernel_h + koffset;
            float sum = 0;
            if (swapkernel) {
              for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                for(kx = 0; kx < kernel_w; kx++) {
                  sum += input_p[kx]*(*kernel_p--);
                }
                input_p += input_w;
              }
            } else {
              for(ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                for(kx = 0; kx < kernel_w; kx++) {
                  sum += input_p[kx]*(*kernel_p++);
                }
                input_p += input_w;
              }
            }
            *output_p += sum;
          }
        }
      }
    }
  }
}

/*
 * Map (connection table) based convolution
 * 3D input, 3D kernel (nOutputPlane*fanin sub-kernels), 3D output
 */
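
/*
 * Hedged usage sketch (illustrative only; the contents of `table` follow the
 * layout described above conv2mapgeneric, and the constructor call is an
 * assumption of this example):
 *
 *   // 4 input planes of 16x16, fanin 2, 5x5 kernels:
 *   // kernel holds nOutputPlane * fanin = 12 sub-kernels, table has 12 pairs
 *   THCudaTensor *ker = THCudaTensor_newWithSize3d(12, 5, 5);
 *   THCudaTensor_conv2Dmap(out, in, ker, 1, 1, table, 2);
 *   // out is resized to 6 x 12 x 12, since 12 / 2 = 6 output planes
 */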
THC_API void THCudaTensor_conv2Dmap(THCudaTensor *output, THCudaTensor *input,
                                    THCudaTensor *kernel, long stride_x, long stride_y,
                                    THCudaTensor *table, long fanin)
{
  long nInputPlane, nInputRows, nInputCols;
  long nKernelRows, nKernelCols;
  long nOutputPlane, nOutputRows, nOutputCols;

  THArgCheck(kernel->nDimension == 3 , 4, "kernel: 3D Tensor expected");
  THArgCheck(stride_x >= 1, 5, "Stride should be a positive integer");
  THArgCheck(stride_y >= 1, 6, "Stride should be a positive integer");

  input = THCudaTensor_newContiguous(input);
  kernel = THCudaTensor_newContiguous(kernel);
  table = THCudaTensor_newContiguous(table);

  nInputPlane = input->size[0];
  nInputRows = input->size[1];
  nInputCols = input->size[2];
  nKernelRows = kernel->size[1];
  nKernelCols = kernel->size[2];
  nOutputPlane = kernel->size[0] / fanin;
  // THArgCheck(kernel->size[1] == nInputPlane, 2, "invalid number of input planes");

  THArgCheck( (nInputRows >= nKernelRows && nInputCols >= nKernelCols), 2,
              "conv2Dmap : Input image is smaller than kernel");

  // output dims
  nOutputRows = (nInputRows - nKernelRows) / stride_y + 1;
  nOutputCols = (nInputCols - nKernelCols) / stride_x + 1;

  // long nelem = THCudaTensor_nElement(output);
  THCudaTensor_resize3d(output, nOutputPlane, nOutputRows, nOutputCols);

  float *input_data = THCudaTensor_data(input);
  float *kernel_data = THCudaTensor_data(kernel);
  float *output_data = THCudaTensor_data(output);
  float *table_data = THCudaTensor_data(table);

  // set the number of blocks and threads
  int nthreads_x = 32;
  int nthreads_y = 8;
  int block_height = (int)(16L / nOutputPlane);
  if (block_height < 1)
    block_height = 1;
  dim3 blocks(nOutputPlane,block_height);
  dim3 threads(nthreads_x,nthreads_y);

#define GENERIC_MAP_KERNEL(dim)                                        \
  conv2mapgeneric <false, (dim), (dim)> <<<blocks, threads>>> (        \
      input_data, kernel_data, output_data, nInputPlane, nInputRows,   \
      nInputCols, nOutputPlane*fanin, nKernelRows, nKernelCols,        \
      stride_x, stride_y, table_data, fanin);

  FOR_KERNEL_SPECIALIZED_DIMENSION(nKernelCols, nKernelRows, GENERIC_MAP_KERNEL);
#undef GENERIC_MAP_KERNEL

  // clean
  THCudaTensor_free(input);
  THCudaTensor_free(kernel);
  THCudaTensor_free(table);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in conv2Dmap: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
}

#undef FOR_KERNEL_SPECIALIZED_DIMENSION