Improve autograd memory usage (#859)
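
When a Function's output will never need gradients, any tensors stashed on the
Function instance for backward only waste memory: the graph keeps them alive even
though backward will never run. The changes below apply the same pattern in several
places (the C++ Function wrapper, BatchNorm, Conv, and the generated THNN functions):
after forward, check self.requires_grad and drop the saved buffers immediately when
it is false. A minimal sketch of the pattern, using an old-style autograd Function
and hypothetical names (ScaleFn, _scratch), not code from this patch:

    import torch
    from torch.autograd import Function

    class ScaleFn(Function):
        def forward(self, input):
            # temporary buffer that would normally be kept around for backward
            self._scratch = input.new(input.size()).copy_(input)
            if not self.requires_grad:
                # backward will never be called on this instance, so free the
                # buffer now instead of keeping it alive as long as the graph
                del self._scratch
            return input.mul(2)

        def backward(self, grad_output):
            # gradient of y = 2 * x
            return grad_output.mul(2)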

diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 8c0a52d..b66d91c 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -486,6 +486,12 @@
       if (self->cdata.requires_grad || self->cdata.is_stochastic) {
         _save_variables(self, t2var);
         _mark_non_differentiable(self, t2var);
+      } else {
+        // Remove unnecessary attributes
+        Py_XDECREF(self->to_save);
+        self->to_save = NULL;
+        Py_XDECREF(self->non_differentiable);
+        self->non_differentiable = NULL;
       }
     }
 
diff --git a/torch/nn/_functions/batchnorm.py b/torch/nn/_functions/batchnorm.py
index 46da184..db8f7a0 100644
--- a/torch/nn/_functions/batchnorm.py
+++ b/torch/nn/_functions/batchnorm.py
@@ -25,22 +25,26 @@
 
         # temporary buffers used in forward and backward
         num_features = input.size(1)
-        self._save_mean = input.new(num_features)
-        self._save_std = input.new(num_features)
+        _save_mean = input.new(num_features)
+        _save_std = input.new(num_features)
 
         output = input.new(input.size())
 
         if self.use_cudnn:
             torch._C._cudnn_batch_norm_forward(
                 input, output, weight, bias,
-                self.running_mean, self.running_var, self._save_mean,
-                self._save_std, self.training, self.momentum, self.eps)
+                self.running_mean, self.running_var, _save_mean,
+                _save_std, self.training, self.momentum, self.eps)
         else:
             backend = type2backend[type(input)]
             backend.BatchNormalization_updateOutput(
                 backend.library_state, input, output, weight, bias,
-                self.running_mean, self.running_var, self._save_mean,
-                self._save_std, self.training, self.momentum, self.eps)
+                self.running_mean, self.running_var, _save_mean,
+                _save_std, self.training, self.momentum, self.eps)
+
+        if self.requires_grad:
+            self._save_mean = _save_mean
+            self._save_std = _save_std
 
         return output
 
diff --git a/torch/nn/_functions/conv.py b/torch/nn/_functions/conv.py
index 9e8bb63..a3f44ec 100644
--- a/torch/nn/_functions/conv.py
+++ b/torch/nn/_functions/conv.py
@@ -91,10 +91,15 @@
                 self._cudnn_info = torch._C._cudnn_convolution_full_forward(
                     input, weight, bias, output, self.padding, self.stride, self.dilation,
                     self.groups, cudnn.benchmark)
+            if not self.requires_grad:
+                del self._cudnn_info
             return output
 
         self._bufs = [[] for g in range(self.groups)]
-        return self._thnn('update_output', input, weight, bias)
+        output = self._thnn('update_output', input, weight, bias)
+        if not self.requires_grad:
+            del self._bufs
+        return output
 
     def _grad_input(self, input, weight, grad_output):
         if self.use_cudnn:
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index d84ce94..c80cd33 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -143,6 +143,9 @@
         else:
             self.save_for_backward(input, *params)
 
+        if not self.requires_grad:
+            del self.buffers
+
         getattr(self._backend, update_output.name)(self._backend.library_state, input, output, *args)
         return output