Make requires_grad read-only (except for leaf Variables)
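
After this change, assigning to Variable.requires_grad is only permitted on
leaf variables (those with no creator); variables produced by an operation
raise a RuntimeError. The constructor also rejects non-tensor arguments.
A minimal sketch of the intended behavior, assuming the Variable constructor
shown in this patch and the usual arithmetic overloads on Variable:

    import torch
    from torch.autograd import Variable

    x = Variable(torch.ones(2, 2), requires_grad=True)
    x.requires_grad = False        # OK: x is a leaf, x.creator is None
    x.requires_grad = True

    y = x + x                      # y.creator is set by the Add function
    y.requires_grad = False        # RuntimeError: you can only change
                                   #   requires_grad flags of leaf variables

    Variable([1.0, 2.0, 3.0])      # ValueError: Variable objects can only
                                   #   wrap tensors
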
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 109fadd..be9b0f8 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -43,7 +43,7 @@
# Save the input, so _save_for_backward can access it
self.input = input
if not is_volatile:
- self.needs_input_grad = tuple(arg.requires_grad for arg in input)
+ self.needs_input_grad = tuple(arg._requires_grad for arg in input)
self.requires_grad = any(self.needs_input_grad)
self.previous_functions = [(arg.creator or arg, id(arg)) for arg in input]
@@ -73,7 +73,7 @@
if self.non_differentiable is not None:
for var in output:
if var.data in self.non_differentiable:
- var.requires_grad = False
+ var._requires_grad = False
del self.input # Remove unnecessary references to input
del self.non_differentiable # and output
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index b053503..1045f0b 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -24,10 +24,13 @@
self.creator = creator
self.volatile = volatile
self.dirty = False
- self.requires_grad = (not volatile) and requires_grad
+ self._requires_grad = (not volatile) and requires_grad
self._data = tensor
self._grad = None
self.backward_hooks = OrderedDict()
+ if not torch.is_tensor(tensor):
+ raise ValueError("Variable objects can only wrap tensors but got " +
+ torch.typename(tensor))
@property
def grad(self):
@@ -42,6 +45,17 @@
raise RuntimeError('Accessing data of a dirty variable!')
return self._data
+ @property
+ def requires_grad(self):
+ return self._requires_grad
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ if self.creator is not None:
+ raise RuntimeError("you can only change requires_grad flags of "
+ "leaf variables")
+ self._requires_grad = value
+
def mark_dirty(self):
self.dirty = True
self._data = None