Improve autograd memory usage (#859)
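
When a Function will never run a backward pass (requires_grad is false and the
function is not stochastic), drop the attributes and temporary buffers that are
only kept alive for backward:

- python_function.cpp: release to_save and non_differentiable instead of saving variables.
- nn/_functions/batchnorm.py: keep _save_mean / _save_std on the Function only when requires_grad is set.
- nn/_functions/conv.py: delete _cudnn_info / _bufs after the forward pass when no gradient is required.
- nn/_functions/thnn/auto.py: delete self.buffers when no gradient is required.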
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 8c0a52d..b66d91c 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -486,6 +486,12 @@
if (self->cdata.requires_grad || self->cdata.is_stochastic) {
_save_variables(self, t2var);
_mark_non_differentiable(self, t2var);
+ } else {
+ // Remove unnecessary attributes
+ Py_XDECREF(self->to_save);
+ self->to_save = NULL;
+ Py_XDECREF(self->non_differentiable);
+ self->non_differentiable = NULL;
}
}
diff --git a/torch/nn/_functions/batchnorm.py b/torch/nn/_functions/batchnorm.py
index 46da184..db8f7a0 100644
--- a/torch/nn/_functions/batchnorm.py
+++ b/torch/nn/_functions/batchnorm.py
@@ -25,22 +25,26 @@
# temporary buffers used in forward and backward
num_features = input.size(1)
- self._save_mean = input.new(num_features)
- self._save_std = input.new(num_features)
+ _save_mean = input.new(num_features)
+ _save_std = input.new(num_features)
output = input.new(input.size())
if self.use_cudnn:
torch._C._cudnn_batch_norm_forward(
input, output, weight, bias,
- self.running_mean, self.running_var, self._save_mean,
- self._save_std, self.training, self.momentum, self.eps)
+ self.running_mean, self.running_var, _save_mean,
+ _save_std, self.training, self.momentum, self.eps)
else:
backend = type2backend[type(input)]
backend.BatchNormalization_updateOutput(
backend.library_state, input, output, weight, bias,
- self.running_mean, self.running_var, self._save_mean,
- self._save_std, self.training, self.momentum, self.eps)
+ self.running_mean, self.running_var, _save_mean,
+ _save_std, self.training, self.momentum, self.eps)
+
+ if self.requires_grad:
+ self._save_mean = _save_mean
+ self._save_std = _save_std
return output
diff --git a/torch/nn/_functions/conv.py b/torch/nn/_functions/conv.py
index 9e8bb63..a3f44ec 100644
--- a/torch/nn/_functions/conv.py
+++ b/torch/nn/_functions/conv.py
@@ -91,10 +91,15 @@
self._cudnn_info = torch._C._cudnn_convolution_full_forward(
input, weight, bias, output, self.padding, self.stride, self.dilation,
self.groups, cudnn.benchmark)
+ if not self.requires_grad:
+ del self._cudnn_info
return output
self._bufs = [[] for g in range(self.groups)]
- return self._thnn('update_output', input, weight, bias)
+ output = self._thnn('update_output', input, weight, bias)
+ if not self.requires_grad:
+ del self._bufs
+ return output
def _grad_input(self, input, weight, grad_output):
if self.use_cudnn:
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index d84ce94..c80cd33 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -143,6 +143,9 @@
else:
self.save_for_backward(input, *params)
+ if not self.requires_grad:
+ del self.buffers
+
getattr(self._backend, update_output.name)(self._backend.library_state, input, output, *args)
return output
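
Each hunk applies the same pattern: temporaries needed only by backward() are
retained on the Function object only when a gradient can actually be requested.
A minimal sketch of that pattern against the legacy autograd Function API used
by this patch follows; MyScale and its buffer are hypothetical and not part of
the change itself.

    import torch
    from torch.autograd import Function, Variable

    class MyScale(Function):
        """Hypothetical function illustrating the buffer-dropping pattern."""

        def forward(self, input):
            # Temporary that only backward() would need.
            buf = input.new(input.size()).copy_(input)
            output = input * 2
            # Same check the patch adds: only retain the buffer when a
            # backward pass can actually happen.
            if self.requires_grad:
                self._buf = buf
            return output

        def backward(self, grad_output):
            # backward only runs when requires_grad was set during forward,
            # so any retained buffers are still available here.
            return grad_output * 2

    x = Variable(torch.randn(5))   # requires_grad defaults to False
    y = MyScale()(x)               # forward runs; no buffer is kept on the Function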