Use long (THCudaLongTensor) indices for SpatialAdaptiveMaxPooling instead of float tensors.
diff --git a/SpatialAdaptiveMaxPooling.cu b/SpatialAdaptiveMaxPooling.cu
index 5dd8659..b67203b 100644
--- a/SpatialAdaptiveMaxPooling.cu
+++ b/SpatialAdaptiveMaxPooling.cu
@@ -8,7 +8,7 @@
  *    this function adaptively maxpools an input 4D tensor along dimensions 2 and 3
  *    4D input, 4D output, 4D argmax x and y
  */
-__global__ void adaptivemaxpool(float *input, float *output, float *indices_x, float *indices_y,
+__global__ void adaptivemaxpool(float *input, float *output, long *indices_x, long *indices_y,
                         int input_n, int input_h, int input_w,
                         int output_h, int output_w,
                         int strideh, int stridew,
@@ -29,7 +29,6 @@
   int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
   int yy_end = output_h;
   const int yy_step = blockDim.y*gridDim.y;
-
   // select input/output plane
   output = output + o*output_w*output_h;
   input = input + i*strided;
@@ -52,8 +51,8 @@
       // Compute the mean of the input image...
       float *ptr_input = input + y_start*strideh + x_start*stridew;
       float *ptr_output = output + yy*output_w + xx;
-      float *ptr_ind_x = indices_x + yy*output_w + xx;
-      float *ptr_ind_y = indices_y + yy*output_w + xx;
+      long *ptr_ind_x = indices_x + yy*output_w + xx;
+      long *ptr_ind_y = indices_y + yy*output_w + xx;
       int argmax_x = -1;
       int argmax_y = -1;
       float max = -FLT_MAX;
@@ -81,7 +80,7 @@
  * Description:
  *    this function computes the gradInput from weight and gradOutput
  */
-__global__ void adaptivemaxgradinput(float *gradInput, float *gradOutput, float *indices_x, float *indices_y,
+__global__ void adaptivemaxgradinput(float *gradInput, float *gradOutput, long *indices_x, long *indices_y,
                              int input_n, int input_h, int input_w,
                              int output_h, int output_w)
 {
@@ -118,8 +117,8 @@
 
       float *ptr_gradInput = gradInput + y_start*input_w + x_start;
       float *ptr_gradOutput = gradOutput + yy*output_w + xx;
-      float *ptr_ind_x = indices_x + yy*output_w + xx;
-      float *ptr_ind_y = indices_y + yy*output_w + xx;
+      long *ptr_ind_x = indices_x + yy*output_w + xx;
+      long *ptr_ind_y = indices_y + yy*output_w + xx;
       float z = *ptr_gradOutput;
 
       int argmax_x = (*ptr_ind_x) - TH_INDEX_BASE;
@@ -136,7 +135,7 @@
  *    when kH != dH or kW != dW (uses atomic add)
  */
 __global__ void atomicadaptivemaxgradinput(
-  float *gradInput, float *gradOutput, float *indices_x, float *indices_y,
+  float *gradInput, float *gradOutput, long *indices_x, long *indices_y,
   int input_n, int input_h, int input_w, int output_h, int output_w
 )
 {
@@ -172,8 +171,8 @@
 
       float *ptr_gradInput = gradInput + y_start*input_w + x_start;
       float *ptr_gradOutput = gradOutput + yy*output_w + xx;
-      float *ptr_ind_x = indices_x + yy*output_w + xx;
-      float *ptr_ind_y = indices_y + yy*output_w + xx;
+      long *ptr_ind_x = indices_x + yy*output_w + xx;
+      long *ptr_ind_y = indices_y + yy*output_w + xx;
       float z = *ptr_gradOutput;
 
       int argmax_x = (*ptr_ind_x) - TH_INDEX_BASE;
@@ -185,11 +184,11 @@
   }
 }
 
-void THNN_CudaSpatialAdaptiveMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaTensor *indices, int nOutputCols, int nOutputRows)
+void THNN_CudaSpatialAdaptiveMaxPooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, THCudaLongTensor *indices, int nOutputCols, int nOutputRows)
 {
   THCUNN_assertSameGPU(state, 3, input, output, indices);
 
-  float *indices_data;
+  long *indices_data;
   float *output_data;
   float *input_data;
 
@@ -207,9 +206,9 @@
     input_data = THCudaTensor_data(state, input);
 
     THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);
-    THCudaTensor_resize4d(state, indices, 2, nInputPlane, nOutputRows, nOutputCols);
+    THCudaLongTensor_resize4d(state, indices, 2, nInputPlane, nOutputRows, nOutputCols);
 
-    indices_data = THCudaTensor_data(state, indices);
+    indices_data = THCudaLongTensor_data(state, indices);
     output_data = THCudaTensor_data(state, output);
 
     // cuda blocks & threads:
@@ -239,9 +238,9 @@
     input_data = THCudaTensor_data(state, input);
 
     THCudaTensor_resize4d(state, output, nbatch, nInputPlane, nOutputRows, nOutputCols);
-    THCudaTensor_resize5d(state, indices, 2, nbatch, nInputPlane, nOutputRows, nOutputCols);
+    THCudaLongTensor_resize5d(state, indices, 2, nbatch, nInputPlane, nOutputRows, nOutputCols);
 
-    indices_data = THCudaTensor_data(state, indices);
+    indices_data = THCudaLongTensor_data(state, indices);
     output_data = THCudaTensor_data(state, output);
 
     // cuda blocks & threads:
@@ -261,13 +260,13 @@
   }
 }
 
-void THNN_CudaSpatialAdaptiveMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaTensor *indices)
+void THNN_CudaSpatialAdaptiveMaxPooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, THCudaLongTensor *indices)
 {
   bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests
 
   THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput);
 
-  float *indices_data;
+  long *indices_data;
   float *gradInput_data;
   float *gradOutput_data;
 
@@ -285,7 +284,7 @@
     THCudaTensor_resizeAs(state, gradInput, input);
     THCudaTensor_zero(state, gradInput);
 
-    indices_data = THCudaTensor_data(state, indices);
+    indices_data = THCudaLongTensor_data(state, indices);
     gradOutput_data = THCudaTensor_data(state, gradOutput);
     gradInput_data = THCudaTensor_data(state, gradInput);
 
@@ -323,7 +322,7 @@
     THCudaTensor_resizeAs(state, gradInput, input);
     THCudaTensor_zero(state, gradInput);
 
-    indices_data = THCudaTensor_data(state, indices);
+    indices_data = THCudaLongTensor_data(state, indices);
     gradOutput_data = THCudaTensor_data(state, gradOutput);
     gradInput_data = THCudaTensor_data(state, gradInput);
 
diff --git a/THCUNN.h b/THCUNN.h
index 553fd72..ba7cdee 100644
--- a/THCUNN.h
+++ b/THCUNN.h
@@ -476,7 +476,7 @@
           THCState *state,
           THCudaTensor *input,
           THCudaTensor *output,
-          THCudaTensor *indices,
+          THCudaLongTensor *indices,
           int nOutputCols,
           int nOutputRows);
 TH_API void THNN_CudaSpatialAdaptiveMaxPooling_updateGradInput(
@@ -484,7 +484,7 @@
           THCudaTensor *input,
           THCudaTensor *gradOutput,
           THCudaTensor *gradInput,
-          THCudaTensor *indices);
+          THCudaLongTensor *indices);
 
 TH_API void THNN_CudaSpatialAveragePooling_updateOutput(
           THCState *state,