Rename variables in SpatialAdaptiveAveragePooling (spatial version) for clarity and consistency: use sizeD/isizeH/isizeW/osizeH/osizeW and istrideD/istrideH/istrideW naming across the CUDA kernels, THCUNN/THNN C implementations, and headers. No functional change.
diff --git a/torch/lib/THCUNN/SpatialAdaptiveAveragePooling.cu b/torch/lib/THCUNN/SpatialAdaptiveAveragePooling.cu
index b1e5e5c..45b5654 100644
--- a/torch/lib/THCUNN/SpatialAdaptiveAveragePooling.cu
+++ b/torch/lib/THCUNN/SpatialAdaptiveAveragePooling.cu
@@ -18,54 +18,54 @@
  */
  template <typename T>
 __global__ void adaptiveaveragepool(T *input, T *output,
-                        int input_n, int input_h, int input_w,
-                        int output_h, int output_w,
-                        int strideh, int stridew,
-                        int strided)
+                        int sizeD, int isizeH, int isizeW,
+                        int osizeH, int osizeW,
+                        int istrideH, int istrideW,
+                        int istrideD)
 {
   // iterators
-  int xx, yy;
+  int ow, oh;
 
   // compute offsets based on thread/block ID
-  int o = blockIdx.x;
-  int i = o;
-  //int k = blockIdx.x % input_n;
+  int o_plane = blockIdx.x;
+  int i_plane = o_plane;
+  //int k = blockIdx.x % sizeD;
 
-  int xx_start = threadIdx.x;
-  int xx_end = output_w;
-  const int xx_step = blockDim.x;
+  int ostartW = threadIdx.x;
+  int oendW = osizeW;
+  const int ostepW = blockDim.x;
 
-  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
-  int yy_end = output_h;
-  const int yy_step = blockDim.y*gridDim.y;
+  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
+  int oendH = osizeH;
+  const int ostepH = blockDim.y*gridDim.y;
   // select input/output plane
-  output = output + o*output_w*output_h;
-  input = input + i*strided;
+  output = output + o_plane*osizeW*osizeH;
+  input = input + i_plane*istrideD;
 
   // For all output pixels...
-  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
+  for(oh = ostartH; oh < oendH; oh+=ostepH) {
 
-    int y_start = START_IND(yy, output_h, input_h);
-    int y_end   = END_IND(yy, output_h, input_h);
-    int kH = y_end-y_start;
+    int istartH = START_IND(oh, osizeH, isizeH);
+    int iendH   = END_IND(oh, osizeH, isizeH);
+    int kH = iendH-istartH;
 
-    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
+    for(ow = ostartW; ow < oendW; ow+=ostepW) {
 
-      int x_start = START_IND(xx, output_w, input_w);
-      int x_end   = END_IND(xx, output_w, input_w);
-      int kW = x_end-x_start;
+      int istartW = START_IND(ow, osizeW, isizeW);
+      int iendW   = END_IND(ow, osizeW, isizeW);
+      int kW = iendW-istartW;
 
       // Compute the average pooling
-      T *ptr_input = input + y_start*strideh + x_start*stridew;
-      T *ptr_output = output + yy*output_w + xx;
+      T *ptr_input = input + istartH*istrideH + istartW*istrideW;
+      T *ptr_output = output + oh*osizeW + ow;
       T sum = ScalarConvert<int, T>::to(0);
-      int kx, ky;
-      for(ky = 0; ky < kH; ++ky) {
-        for(kx = 0; kx < kW; ++kx) {
-          T val = ptr_input[kx*stridew];
+      int iw, ih;
+      for(ih = 0; ih < kH; ++ih) {
+        for(iw = 0; iw < kW; ++iw) {
+          T val = ptr_input[iw*istrideW];
           sum += val;
         }
-        ptr_input += strideh; // next input line
+        ptr_input += istrideH; // next input line
       }
       // Update output
       *ptr_output = sum / kH / kW;
@@ -80,54 +80,54 @@
  template <typename T>
 __global__ void adaptiveaveragegradinput(
   T *gradInput, T *gradOutput,
-  int input_n, int input_h, int input_w, int output_h, int output_w
+  int sizeD, int isizeH, int isizeW, int osizeH, int osizeW
 )
 {
   // iterators
-  int x, y;
+  int iw, ih;
 
   // compute offsets based on thread/block ID
-  int o = blockIdx.x;
-  int i = o;
+  int o_plane = blockIdx.x;
+  int i_plane = o_plane;
 
-  int x_start = threadIdx.x;
-  int x_end = input_w;
-  int x_step = blockDim.x;
+  int istartW = threadIdx.x;
+  int iendW = isizeW;
+  int istepW = blockDim.x;
 
-  int y_start = blockDim.y*blockIdx.y + threadIdx.y;
-  int y_end = input_h;
-  int y_step = blockDim.y*gridDim.y;
+  int istartH = blockDim.y*blockIdx.y + threadIdx.y;
+  int iendH = isizeH;
+  int istepH = blockDim.y*gridDim.y;
 
   // select input/output plane
-  gradOutput = gradOutput + o*output_w*output_h;
-  gradInput = gradInput + i*input_w*input_h;
+  gradOutput = gradOutput + o_plane*osizeW*osizeH;
+  gradInput = gradInput + i_plane*isizeW*isizeH;
 
   // compute gradInput
-  for(y = y_start; y < y_end; y+=y_step) {
+  for(ih = istartH; ih < iendH; ih+=istepH) {
 
-    int yy_start = START_IND(y, input_h, output_h);
-    int yy_end   = END_IND(y, input_h, output_h);
-    int kH = yy_end-yy_start;
+    int ostartH = START_IND(ih, isizeH, osizeH);
+    int oendH   = END_IND(ih, isizeH, osizeH);
+    int kH = oendH-ostartH;
 
-    for(x = x_start; x < x_end; x+=x_step) {
+    for(iw = istartW; iw < iendW; iw+=istepW) {
 
-      int xx_start = START_IND(x, input_w, output_w);
-      int xx_end   = END_IND(x, input_w, output_w);
-      int kW = xx_end-xx_start;
+      int ostartW = START_IND(iw, isizeW, osizeW);
+      int oendW   = END_IND(iw, isizeW, osizeW);
+      int kW = oendW-ostartW;
 
       // Compute the gradients
-      T *ptr_gradInput = gradInput + y*input_w + x;
-      T *ptr_gradOutput = gradOutput + yy_start*output_w + xx_start;
-      
-      int kx, ky;
-      for(ky = 0; ky < kH; ++ky) {
-        int yy = yy_start + ky;
-        int kkH = START_IND(yy, output_h, input_h) - END_IND(yy, output_h, input_h);
-        for(kx = 0; kx < kW; ++kx) {
-          int xx = xx_start + kx;
-          int kkW = START_IND(xx, output_w, input_w) - END_IND(xx, output_w, input_w);
-          T z = ptr_gradOutput[kx + ky*output_w] / kkW / kkH;
-          *ptr_gradInput += z;
+      T *ptr_gradInput = gradInput + ih*isizeW + iw;
+      T *ptr_gradOutput = gradOutput + ostartH*osizeW + ostartW;
+
+      int ow, oh;
+      for(oh = 0; oh < kH; ++oh) {
+        int orealH = ostartH + oh;
+        int kkH = START_IND(orealH, osizeH, isizeH) - END_IND(orealH, osizeH, isizeH);
+        for(ow = 0; ow < kW; ++ow) {
+          int orealW = ostartW + ow;
+          int kkW = START_IND(orealW, osizeW, isizeW) - END_IND(orealW, osizeW, isizeW);
+          T grad_delta = ptr_gradOutput[ow + oh*osizeW] / kkW / kkH;
+          *ptr_gradInput += grad_delta;
         }
       }
     }
@@ -142,50 +142,50 @@
  template <typename T>
 __global__ void atomicadaptiveaveragegradinput(
   T *gradInput, T *gradOutput,
-  int input_n, int input_h, int input_w, int output_h, int output_w
+  int sizeD, int isizeH, int isizeW, int osizeH, int osizeW
 )
 {
   // iterators
-  int xx, yy;
+  int ow, oh;
 
   // compute offsets based on thread/block ID
-  int o = blockIdx.x;
-  int i = o;
+  int o_plane = blockIdx.x;
+  int i_plane = o_plane;
 
-  int xx_start = threadIdx.x;
-  int xx_end = output_w;
-  int xx_step = blockDim.x;
+  int ostartW = threadIdx.x;
+  int oendW = osizeW;
+  int ostepW = blockDim.x;
 
-  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
-  int yy_end = output_h;
-  int yy_step = blockDim.y*gridDim.y;
+  int ostartH = blockDim.y*blockIdx.y + threadIdx.y;
+  int oendH = osizeH;
+  int ostepH = blockDim.y*gridDim.y;
 
   // select input/output plane
-  gradOutput = gradOutput + o*output_w*output_h;
-  gradInput = gradInput + i*input_w*input_h;
+  gradOutput = gradOutput + o_plane*osizeW*osizeH;
+  gradInput = gradInput + i_plane*isizeW*isizeH;
 
   // compute gradInput
-  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
+  for(oh = ostartH; oh < oendH; oh+=ostepH) {
 
-    int y_start = START_IND(yy, output_h, input_h);
-    int y_end   = END_IND(yy, output_h, input_h);
-    int kH = y_end-y_start;
+    int istartH = START_IND(oh, osizeH, isizeH);
+    int iendH   = END_IND(oh, osizeH, isizeH);
+    int kH = iendH-istartH;
 
-    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
+    for(ow = ostartW; ow < oendW; ow+=ostepW) {
 
-      int x_start = START_IND(xx, output_w, input_w);
-      int x_end   = END_IND(xx, output_w, input_w);
-      int kW = x_end-x_start;
+      int istartW = START_IND(ow, osizeW, isizeW);
+      int iendW   = END_IND(ow, osizeW, isizeW);
+      int kW = iendW-istartW;
 
       // Compute the gradients
-      T *ptr_gradInput = gradInput + y_start*input_w + x_start;
-      T *ptr_gradOutput = gradOutput + yy*output_w + xx;
-      T z = *ptr_gradOutput / kW / kH;
-      int kx, ky;
-      for(ky = 0; ky < kH; ++ky) {
-        for(kx = 0; kx < kW; ++kx) {
+      T *ptr_gradInput = gradInput + istartH*isizeW + istartW;
+      T *ptr_gradOutput = gradOutput + oh*osizeW + ow;
+      T grad_delta = *ptr_gradOutput / kW / kH;
+      int iw, ih;
+      for(ih = 0; ih < kH; ++ih) {
+        for(iw = 0; iw < kW; ++iw) {
           // atomic add since different threads could update same variable
-          atomicAdd(&(ptr_gradInput[kx + ky*input_w]), z);
+          atomicAdd(&(ptr_gradInput[iw + ih*isizeW]), grad_delta);
         }
       }
     }
diff --git a/torch/lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu b/torch/lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu
index 2b472a6..996a9e9 100644
--- a/torch/lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu
+++ b/torch/lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu
@@ -8,8 +8,8 @@
            THCState *state,
            THCTensor *input,
            THCTensor *output,
-           int nOutputCols,
-           int nOutputRows)
+           int osizeW,
+           int osizeH)
 {
   THCUNN_assertSameGPU(state, 2, input, output);
 
@@ -20,59 +20,59 @@
                   "3D or 4D (batch mode) tensor expected for input, but got: %s");
 
   if (input->nDimension == 3) {
-    int64_t nInputCols = input->size[2];
-    int64_t nInputRows = input->size[1];
-    int64_t nInputPlane = input->size[0];
+    int64_t isizeW = input->size[2];
+    int64_t isizeH = input->size[1];
+    int64_t sizeD = input->size[0];
 
-    int64_t istride_d = input->stride[0];
-    int64_t istride_h = input->stride[1];
-    int64_t istride_w = input->stride[2];
+    int64_t istrideD = input->stride[0];
+    int64_t istrideH = input->stride[1];
+    int64_t istrideW = input->stride[2];
 
     input_data = THCTensor_(data)(state, input);
 
-    THCTensor_(resize3d)(state, output, nInputPlane, nOutputRows, nOutputCols);
+    THCTensor_(resize3d)(state, output, sizeD, osizeH, osizeW);
 
     output_data = THCTensor_(data)(state, output);
 
     // cuda blocks & threads:
-    int yblocks = (int)(16L / nInputPlane);
-    yblocks = yblocks < 1 ? 1 : yblocks;
-    dim3 blocks(nInputPlane,yblocks);
-    dim3 threads(32,8);
+    int blocksH = (int)(16L / sizeD);
+    blocksH = blocksH < 1 ? 1 : blocksH;
+    dim3 blocks(sizeD, blocksH);
+    dim3 threads(32, 8);
 
     // run averagepool kernel
     adaptiveaveragepool <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data,
-                                   nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
-                                   istride_h, istride_w, istride_d);
+                                   sizeD, isizeH, isizeW, osizeH, osizeW,
+                                   istrideH, istrideW, istrideD);
     THCudaCheck(cudaGetLastError());
 
   } else {
     input = THCTensor_(newContiguous)(state, input);
-    int64_t nInputCols = input->size[3];
-    int64_t nInputRows = input->size[2];
-    int64_t nInputPlane = input->size[1];
-    int64_t nbatch = input->size[0];
+    int64_t isizeW = input->size[3];
+    int64_t isizeH = input->size[2];
+    int64_t sizeD = input->size[1];
+    int64_t sizeB = input->size[0];
 
-    int64_t istride_d = input->stride[1];
-    int64_t istride_h = input->stride[2];
-    int64_t istride_w = input->stride[3];
+    int64_t istrideD = input->stride[1];
+    int64_t istrideH = input->stride[2];
+    int64_t istrideW = input->stride[3];
 
     input_data = THCTensor_(data)(state, input);
 
-    THCTensor_(resize4d)(state, output, nbatch, nInputPlane, nOutputRows, nOutputCols);
+    THCTensor_(resize4d)(state, output, sizeB, sizeD, osizeH, osizeW);
 
     output_data = THCTensor_(data)(state, output);
 
     // cuda blocks & threads:
-    int yblocks = (int)(16L / nInputPlane);
-    yblocks = yblocks < 1 ? 1 : yblocks;
-    dim3 blocks(nInputPlane*nbatch,yblocks);
-    dim3 threads(32,8);
+    int blocksH = (int)(16L / sizeD);
+    blocksH = blocksH < 1 ? 1 : blocksH;
+    dim3 blocks(sizeD * sizeB, blocksH);
+    dim3 threads(32, 8);
 
     // run averagepool kernel
     adaptiveaveragepool <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data,
-                                   nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
-                                   istride_h, istride_w, istride_d);
+                                   sizeD, isizeH, isizeW, osizeH, osizeW,
+                                   istrideH, istrideW, istrideD);
     THCudaCheck(cudaGetLastError());
     // clean
     THCTensor_(free)(state, input);
@@ -95,13 +95,13 @@
   gradOutput = THCTensor_(newContiguous)(state, gradOutput);
 
   if (input->nDimension == 3) {
-    int64_t nInputCols = input->size[2];
-    int64_t nInputRows = input->size[1];
-    int64_t nInputPlane = input->size[0];
-    int64_t nOutputCols = gradOutput->size[2];
-    int64_t nOutputRows = gradOutput->size[1];
+    int64_t isizeW = input->size[2];
+    int64_t isizeH = input->size[1];
+    int64_t sizeD = input->size[0];
+    int64_t osizeW = gradOutput->size[2];
+    int64_t osizeH = gradOutput->size[1];
 
-    //bool atomic = (nInputCols%nOutputCols != 0) || (nInputRows%nOutputRows != 0);
+    //bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0);
 
     THCTensor_(resizeAs)(state, gradInput, input);
     THCTensor_(zero)(state, gradInput);
@@ -110,33 +110,33 @@
     gradInput_data = THCTensor_(data)(state, gradInput);
 
     // cuda blocks & threads:
-    int yblocks = (int)(16L / nInputPlane);
-    yblocks = yblocks < 1 ? 1 : yblocks;
-    dim3 blocks(nInputPlane,yblocks);
-    dim3 threads(32,8);
+    int blocksH = (int)(16L / sizeD);
+    blocksH = blocksH < 1 ? 1 : blocksH;
+    dim3 blocks(sizeD, blocksH);
+    dim3 threads(32, 8);
 
     if(atomic)
     {
       // run updateGradInput kernel, accumulate gradients atomically
       atomicadaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
-                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+                                          sizeD, isizeH, isizeW, osizeH, osizeW);
     }
     else
     {
       // run updateGradInput kernel
       adaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
-                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+                                          sizeD, isizeH, isizeW, osizeH, osizeW);
     }
     THCudaCheck(cudaGetLastError());
   } else {
-    int64_t nInputCols = input->size[3];
-    int64_t nInputRows = input->size[2];
-    int64_t nInputPlane = input->size[1];
-    int64_t nbatch = input->size[0];
-    int64_t nOutputCols = gradOutput->size[3];
-    int64_t nOutputRows = gradOutput->size[2];
+    int64_t isizeW = input->size[3];
+    int64_t isizeH = input->size[2];
+    int64_t sizeD = input->size[1];
+    int64_t sizeB = input->size[0];
+    int64_t osizeW = gradOutput->size[3];
+    int64_t osizeH = gradOutput->size[2];
 
-    //bool atomic = //(nInputCols%nOutputCols != 0) || (nInputRows%nOutputRows != 0);
+    //bool atomic = //(isizeW%osizeW != 0) || (isizeH%osizeH != 0);
 
     THCTensor_(resizeAs)(state, gradInput, input);
     THCTensor_(zero)(state, gradInput);
@@ -145,22 +145,22 @@
     gradInput_data = THCTensor_(data)(state, gradInput);
 
     // cuda blocks & threads:
-    int yblocks = (int)(16L / nInputPlane);
-    yblocks = yblocks < 1 ? 1 : yblocks;
-    dim3 blocks(nInputPlane*nbatch,yblocks);
-    dim3 threads(32,8);
+    int blocksH = (int)(16L / sizeD);
+    blocksH = blocksH < 1 ? 1 : blocksH;
+    dim3 blocks(sizeD * sizeB, blocksH);
+    dim3 threads(32, 8);
 
     if(atomic)
     {
       // run updateGradInput kernel, accumulate gradients atomically
       atomicadaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
-                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+                                          sizeD, isizeH, isizeW, osizeH, osizeW);
     }
     else
     {
       // run updateGradInput kernel, accumulate gradients atomically
       adaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data,
-                                          nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+                                          sizeD, isizeH, isizeW, osizeH, osizeW);
     }
     THCudaCheck(cudaGetLastError());
   }
diff --git a/torch/lib/THCUNN/generic/THCUNN.h b/torch/lib/THCUNN/generic/THCUNN.h
index 9963706..6aab9af 100644
--- a/torch/lib/THCUNN/generic/THCUNN.h
+++ b/torch/lib/THCUNN/generic/THCUNN.h
@@ -536,8 +536,8 @@
                   THCState *state,
                   THCTensor *input,
                   THCTensor *output,
-                  int nOutputCols,
-                  int nOutputRows);
+                  int osizeW,
+                  int osizeH);
 
 TH_API void THNN_(SpatialAdaptiveAveragePooling_updateGradInput)(
                   THCState *state,
diff --git a/torch/lib/THNN/generic/SpatialAdaptiveAveragePooling.c b/torch/lib/THNN/generic/SpatialAdaptiveAveragePooling.c
index 59c765c..dc5de3c 100644
--- a/torch/lib/THNN/generic/SpatialAdaptiveAveragePooling.c
+++ b/torch/lib/THNN/generic/SpatialAdaptiveAveragePooling.c
@@ -10,46 +10,46 @@
 static void THNN_(SpatialAdaptiveAveragePooling_updateOutput_frame)(
           real *input_p,
           real *output_p,
-          int64_t nslices,
-          int64_t iwidth,
-          int64_t iheight,
-          int64_t owidth,
-          int64_t oheight,
-          int64_t stridew,
-          int64_t strideh,
-          int64_t strided)
+          int64_t sizeD,
+          int64_t isizeW,
+          int64_t isizeH,
+          int64_t osizeW,
+          int64_t osizeH,
+          int64_t istrideW,
+          int64_t istrideH,
+          int64_t istrideD)
 {
-  int64_t k;
-#pragma omp parallel for private(k)
-  for (k = 0; k < nslices; k++)
+  int64_t d;
+#pragma omp parallel for private(d)
+  for (d = 0; d < sizeD; d++)
   {
     /* loop over output */
-    int64_t i, j;
-    for(i = 0; i < oheight; i++)
+    int64_t oh, ow;
+    for(oh = 0; oh < osizeH; oh++)
     {
-      int y_start = START_IND(i, oheight, iheight);
-      int y_end   = END_IND(i, oheight, iheight);
-      int kH = y_end-y_start;
+      int istartH = START_IND(oh, osizeH, isizeH);
+      int iendH   = END_IND(oh, osizeH, isizeH);
+      int kH = iendH-istartH;
 
-      for(j = 0; j < owidth; j++)
+      for(ow = 0; ow < osizeW; ow++)
       {
 
-        int x_start = START_IND(j, owidth, iwidth);
-        int x_end   = END_IND(j, owidth, iwidth);
-        int kW = x_end-x_start;
+        int istartW = START_IND(ow, osizeW, isizeW);
+        int iendW   = END_IND(ow, osizeW, isizeW);
+        int kW = iendW-istartW;
 
         /* local pointers */
-        real *ip = input_p   + k*strided + y_start*strideh + x_start*stridew;
-        real *op = output_p  + k*owidth*oheight + i*owidth + j;
+        real *ip = input_p   + d*istrideD + istartH*istrideH + istartW*istrideW;
+        real *op = output_p  + d*osizeW*osizeH + oh*osizeW + ow;
 
         /* compute local average: */
         real sum = 0;
-        int x,y;
-        for(y = 0; y < kH; y++)
+        int iw,ih;
+        for(ih = 0; ih < kH; ih++)
         {
-          for(x = 0; x < kW; x++)
+          for(iw = 0; iw < kW; iw++)
           {
-            real val = *(ip + y*strideh + x*stridew);
+            real val = *(ip + ih*istrideH + iw*istrideW);
             sum += val;
           }
         }
@@ -65,20 +65,20 @@
           THNNState *state,
           THTensor *input,
           THTensor *output,
-          int owidth,
-          int oheight)
+          int osizeW,
+          int osizeH)
 {
-  int dimw = 2;
-  int dimh = 1;
-  int64_t nbatch = 1;
-  int64_t nslices;
-  int64_t iheight;
-  int64_t iwidth;
+  int dimW = 2;
+  int dimH = 1;
+  int64_t sizeB = 1;
+  int64_t sizeD;
+  int64_t isizeH;
+  int64_t isizeW;
 
-  int64_t istride_d;
-  int64_t istride_h;
-  int64_t istride_w;
-  int64_t istride_b;
+  int64_t istrideD;
+  int64_t istrideH;
+  int64_t istrideW;
+  int64_t istrideB;
 
   real *input_data;
   real *output_data;
@@ -89,54 +89,54 @@
 
   if (input->nDimension == 4)
   {
-    istride_b = input->stride[0];
-    nbatch = input->size[0];
-    dimw++;
-    dimh++;
+    istrideB = input->stride[0];
+    sizeB = input->size[0];
+    dimW++;
+    dimH++;
   }
 
   /* sizes */
-  nslices = input->size[dimh-1];
-  iheight = input->size[dimh];
-  iwidth = input->size[dimw];
+  sizeD = input->size[dimH-1];
+  isizeH = input->size[dimH];
+  isizeW = input->size[dimW];
   /* strides */
-  istride_d = input->stride[dimh-1];
-  istride_h = input->stride[dimh];
-  istride_w = input->stride[dimw];
+  istrideD = input->stride[dimH-1];
+  istrideH = input->stride[dimH];
+  istrideW = input->stride[dimW];
 
   /* resize output */
   if (input->nDimension == 3)
   {
-    THTensor_(resize3d)(output, nslices, oheight, owidth);
+    THTensor_(resize3d)(output, sizeD, osizeH, osizeW);
 
     input_data = THTensor_(data)(input);
     output_data = THTensor_(data)(output);
 
     THNN_(SpatialAdaptiveAveragePooling_updateOutput_frame)(input_data, output_data,
-                                                      nslices,
-                                                      iwidth, iheight,
-                                                      owidth, oheight,
-                                                      istride_w,istride_h,
-                                                      istride_d);
+                                                      sizeD,
+                                                      isizeW, isizeH,
+                                                      osizeW, osizeH,
+                                                      istrideW, istrideH,
+                                                      istrideD);
   }
   else
   {
-    int64_t p;
+    int64_t b;
 
-    THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth);
+    THTensor_(resize4d)(output, sizeB, sizeD, osizeH, osizeW);
 
     input_data = THTensor_(data)(input);
     output_data = THTensor_(data)(output);
 
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++)
+#pragma omp parallel for private(b)
+    for (b = 0; b < sizeB; b++)
     {
-      THNN_(SpatialAdaptiveAveragePooling_updateOutput_frame)(input_data+p*istride_b, output_data+p*nslices*owidth*oheight,
-                                                        nslices,
-                                                        iwidth, iheight,
-                                                        owidth, oheight,
-                                                        istride_w,istride_h,
-                                                        istride_d);
+      THNN_(SpatialAdaptiveAveragePooling_updateOutput_frame)(input_data+b*istrideB, output_data+b*sizeD*osizeW*osizeH,
+                                                        sizeD,
+                                                        isizeW, isizeH,
+                                                        osizeW, osizeH,
+                                                        istrideW, istrideH,
+                                                        istrideD);
     }
   }
 }
@@ -144,41 +144,41 @@
 static void THNN_(SpatialAdaptiveAveragePooling_updateGradInput_frame)(
           real *gradInput_p,
           real *gradOutput_p,
-          int64_t nslices,
-          int64_t iwidth,
-          int64_t iheight,
-          int64_t owidth,
-          int64_t oheight)
+          int64_t sizeD,
+          int64_t isizeW,
+          int64_t isizeH,
+          int64_t osizeW,
+          int64_t osizeH)
 {
-  int64_t k;
-#pragma omp parallel for private(k)
-  for (k = 0; k < nslices; k++)
+  int64_t d;
+#pragma omp parallel for private(d)
+  for (d = 0; d < sizeD; d++)
   {
-    real *gradInput_p_k = gradInput_p + k*iwidth*iheight;
-    real *gradOutput_p_k = gradOutput_p + k*owidth*oheight;
+    real *gradInput_p_d = gradInput_p + d*isizeW*isizeH;
+    real *gradOutput_p_d = gradOutput_p + d*osizeW*osizeH;
 
     /* calculate average */
-    int64_t i, j;
-    for(i = 0; i < oheight; i++)
+    int64_t oh, ow;
+    for(oh = 0; oh < osizeH; oh++)
     {
-      int y_start = START_IND(i, oheight, iheight);
-      int y_end   = END_IND(i, oheight, iheight);
-      int kH = y_end-y_start;
+      int istartH = START_IND(oh, osizeH, isizeH);
+      int iendH   = END_IND(oh, osizeH, isizeH);
+      int kH = iendH-istartH;
 
-      for(j = 0; j < owidth; j++)
+      for(ow = 0; ow < osizeW; ow++)
       {
 
-        int x_start = START_IND(j, owidth, iwidth);
-        int x_end   = END_IND(j, owidth, iwidth);
-        int kW = x_end-x_start;
+        int istartW = START_IND(ow, osizeW, isizeW);
+        int iendW   = END_IND(ow, osizeW, isizeW);
+        int kW = iendW-istartW;
 
-        int x,y;
-        for(y = y_start; y < y_end; y++)
+        int iw,ih;
+        for(ih = istartH; ih < iendH; ih++)
         {
-          for(x = x_start; x < x_end; x++)
+          for(iw = istartW; iw < iendW; iw++)
           {
             /* update gradient */
-            gradInput_p_k[y*iwidth + x] += gradOutput_p_k[i*owidth + j] / kW / kH;
+            gradInput_p_d[ih*isizeW + iw] += gradOutput_p_d[oh*osizeW + ow] / kW / kH;
           }
         }
       }
@@ -192,14 +192,14 @@
           THTensor *gradOutput,
           THTensor *gradInput)
 {
-  int dimw = 2;
-  int dimh = 1;
-  int64_t nbatch = 1;
-  int nslices;
-  int iheight;
-  int iwidth;
-  int oheight;
-  int owidth;
+  int dimW = 2;
+  int dimH = 1;
+  int64_t sizeB = 1;
+  int sizeD;
+  int isizeH;
+  int isizeW;
+  int osizeH;
+  int osizeW;
   real *gradInput_data;
   real *gradOutput_data;
 
@@ -211,17 +211,17 @@
   THTensor_(zero)(gradInput);
 
   if (input->nDimension == 4) {
-    nbatch = input->size[0];
-    dimw++;
-    dimh++;
+    sizeB = input->size[0];
+    dimW++;
+    dimH++;
   }
 
   /* sizes */
-  nslices = input->size[dimh-1];
-  iheight = input->size[dimh];
-  iwidth = input->size[dimw];
-  oheight = gradOutput->size[dimh];
-  owidth = gradOutput->size[dimw];
+  sizeD = input->size[dimH-1];
+  isizeH = input->size[dimH];
+  isizeW = input->size[dimW];
+  osizeH = gradOutput->size[dimH];
+  osizeW = gradOutput->size[dimW];
 
   /* get raw pointers */
   gradInput_data = THTensor_(data)(gradInput);
@@ -231,20 +231,20 @@
   if (input->nDimension == 3)
   {
     THNN_(SpatialAdaptiveAveragePooling_updateGradInput_frame)(gradInput_data, gradOutput_data,
-                                                         nslices,
-                                                         iwidth, iheight,
-                                                         owidth, oheight);
+                                                         sizeD,
+                                                         isizeW, isizeH,
+                                                         osizeW, osizeH);
   }
   else
   {
-    int64_t p;
-#pragma omp parallel for private(p)
-    for (p = 0; p < nbatch; p++)
+    int64_t b;
+#pragma omp parallel for private(b)
+    for (b = 0; b < sizeB; b++)
     {
-      THNN_(SpatialAdaptiveAveragePooling_updateGradInput_frame)(gradInput_data+p*nslices*iwidth*iheight, gradOutput_data+p*nslices*owidth*oheight,
-                                                           nslices,
-                                                           iwidth, iheight,
-                                                           owidth, oheight);
+      THNN_(SpatialAdaptiveAveragePooling_updateGradInput_frame)(gradInput_data+b*sizeD*isizeW*isizeH, gradOutput_data+b*sizeD*osizeW*osizeH,
+                                                           sizeD,
+                                                           isizeW, isizeH,
+                                                           osizeW, osizeH);
     }
   }
 
diff --git a/torch/lib/THNN/generic/THNN.h b/torch/lib/THNN/generic/THNN.h
index d7ee0a1..5a7b306 100644
--- a/torch/lib/THNN/generic/THNN.h
+++ b/torch/lib/THNN/generic/THNN.h
@@ -936,7 +936,7 @@
           THNNState *state,
           THTensor *input,
           THTensor *output,
-          int owidth, int oheight);
+          int osizeW, int osizeH);
 TH_API void THNN_(SpatialAdaptiveAveragePooling_updateGradInput)(
           THNNState *state,
           THTensor *input,