Fix leaf Variable handling in autograd
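
When a custom Function returns one of its input Variables from forward(),
there are now two cases. If the input was declared dirty via mark_dirty(),
the Variable itself is moved after the function in the graph so its gradient
is computed correctly; if it didn't require grad (but the function does), a
dummy non-requires-grad leaf is installed as its previous function so the
backward engine never calls _do_backward on the original leaf. If the input
was returned unmodified, we still hand back a copy that shares the input's
version counter, so hooks only fire with gradients coming from uses of that
output. THPVariable_New now accepts NULL as the creator for constructing
leaf Variables.

A minimal sketch of the in-place pattern this enables, mirroring the new
test_return_leaf_inplace test (import locations are assumed, not part of
this patch):

    import torch
    from torch.autograd import Variable
    from torch.autograd.function import InplaceFunction  # assumed import path

    class Inplace(InplaceFunction):
        def forward(self, a, b):
            self.mark_dirty(a)        # `a` is modified in-place
            return a.add_(b), b + 2

        def backward(self, grad_a, grad_b):
            return grad_a, grad_a + grad_b

    x = Variable(torch.randn(5, 5))                      # leaf without grad
    y = Variable(torch.randn(5, 5), requires_grad=True)

    q, p = Inplace(True)(x, y)
    # q is x itself, now a non-leaf whose creator is the function,
    # so gradients flow through it back to y.
    q.sum().backward()
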
diff --git a/test/test_autograd.py b/test/test_autograd.py
index a84f1cc..d1e983c 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -399,6 +399,67 @@
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 1)
+ def test_return_leaf(self):
+ class Identity(Function):
+ def forward(self, a, b):
+ return a, a + b
+
+ def backward(self, grad_a, grad_b):
+ return grad_a + grad_b, grad_b
+
+ class Inplace(InplaceFunction):
+ def forward(self, a, b):
+ self.mark_dirty(a)
+ return a.add_(b), b + 2
+
+ def backward(self, grad_a, grad_b):
+ return grad_a, grad_a + grad_b
+
+ x = Variable(torch.randn(5, 5), requires_grad=True)
+ y = Variable(torch.randn(5, 5), requires_grad=True)
+
+ q, p = Identity()(x, y)
+ # Make sure hooks only receive grad from usage of q, not x.
+ q.register_hook(
+ 'test', lambda grad: self.assertEqual(grad, torch.ones(5, 5)))
+ (q + p + x).sum().backward()
+ self.assertEqual(x.grad, torch.ones(5, 5) * 3)
+ self.assertEqual(y.grad, torch.ones(5, 5))
+ del q, p # these need to be freed, or next part will raise an error
+
+ def test_return_leaf_inplace(self):
+ class Inplace(InplaceFunction):
+ def forward(self, a, b):
+ self.mark_dirty(a)
+ return a.add_(b), b + 2
+
+ def backward(self, grad_a, grad_b):
+ return grad_a, grad_a + grad_b
+
+ x = Variable(torch.randn(5, 5))
+ y = Variable(torch.randn(5, 5), requires_grad=True)
+
+ fn = Inplace(True)
+ q, p = fn(x, y)
+ self.assertIs(q, x)
+ self.assertIs(q.creator, fn)
+ self.assertTrue(q.requires_grad)
+ q.sum().backward()
+ self.assertEqual(y.grad, torch.ones(5, 5))
+
+ def test_leaf_assignment(self):
+ x = Variable(torch.randn(5, 5))
+ y = Variable(torch.randn(5), requires_grad=True)
+ z = Variable(torch.randn(5), requires_grad=True)
+
+ x[0] = y
+ x[1] = 2 * z
+ self.assertTrue(x.requires_grad)
+ self.assertIsNot(x.creator, None)
+ x.sum().backward()
+ self.assertEqual(y.grad, torch.ones(5))
+ self.assertEqual(z.grad, torch.ones(5) * 2)
+
def test_backward_copy(self):
# This test checks the backward engine for a very subtle bug that appeared
# in one of the initial versions of autograd. Gradient tensors were
@@ -480,18 +541,18 @@
class MyFn(Function):
def forward(self, input):
self.save_for_backward(None, input, None)
- return input
+ return input * input
def backward(self, grad_output):
n1, input, n2 = self.saved_tensors
test_case.assertIsNone(n1)
test_case.assertIsNone(n2)
- return input * grad_output
+ return 2 * input * grad_output
x = Variable(torch.randn(5, 5), requires_grad=True)
y = MyFn()(x)
y.sum().backward()
- self.assertEqual(x.grad, x.data)
+ self.assertEqual(x.grad, 2 * x.data)
def test_too_many_grads(self):
class MyFn(Function):
diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp
index 269bd7f..30fdbf2 100644
--- a/torch/csrc/autograd/function.cpp
+++ b/torch/csrc/autograd/function.cpp
@@ -2,6 +2,7 @@
#include <structmember.h>
#include <unordered_map>
+#include <unordered_set>
#include <exception>
#include "THP.h"
@@ -101,7 +102,8 @@
using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
-static void _mark_dirty(THPFunction *self, t2var_type &t2var)
+static void _mark_dirty(THPFunction *self, t2var_type &t2var,
+ std::unordered_set<PyObject *> &dirty_inputs)
{
// Increase versions of modified tensors
if (!self->dirty_tensors) return;
@@ -112,6 +114,7 @@
Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
for (int i = 0; i < num_dirty; i++) {
PyObject *tensor = PyTuple_GET_ITEM(self->dirty_tensors, i);
+ dirty_inputs.insert(tensor);
THPVariable *variable;
try {
variable = t2var.at(tensor);
@@ -135,7 +138,8 @@
}
static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
- PyObject *raw_output, PyObject *outputs)
+ std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
+ PyObject *outputs)
{
// Wrap outputs in Variables
Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
@@ -161,16 +165,48 @@
Py_INCREF(self);
input_var->creator = (PyObject*)self;
} else {
- // If it's a leaf it's not as simple. Leaves will raise an error in
- // backward if they've been changed, or they're no longer leaves. In
- // some cases (e.g. broadcast) it's perfectly valid to return the same
- // tensor untouched, so instead of moving it we're going to create a
- // copy and join their version counters. This works for broadcast,
- // and if the use wasn't valid we'll still detect an error, because
- // the leaf will have a version != 0.
- output_var = (THPVariable*)THPVariable_New(output, (PyObject*)self, self->requires_grad);
- if (!output_var) throw python_error();
- output_var->version_counter->join_with(*input_var->version_counter);
+ // If the Variable has been changed in-place, we have to move it after
+ // the current function in the graph, so that its gradient is computed
+ // correctly. There are two cases now:
+ // 1. If it requires grad, this is an error, and it will be caught when
+ // its _do_backward is called, because it won't be a leaf anymore.
+ // Its version will also have changed.
+ // 2. If it doesn't require grad, we can safely move it in the graph,
+ // because its _do_backward will never be called.
+ if (dirty_inputs.count(output) > 0) {
+ Py_INCREF(input_var);
+ output_var = input_var;
+ Py_INCREF(self);
+ output_var->creator = (PyObject*)self;
+ if (!output_var->requires_grad && self->requires_grad) {
+ // Now, there's another subtlety. We move the input in the graph and
+ // change its requires_grad to True. However, remember that we're
+ // still holding a reference to it as a previous function. The
+ // backward engine would think it was really a leaf that initially
+ // required grad, call its _do_backward, and that would throw.
+ // Because of this, we need to allocate a dummy leaf that doesn't
+ // require grad and use it as our previous function instead.
+ output_var->requires_grad = self->requires_grad;
+ PyObject* dummy_prev_fn = THPVariable_New(output, NULL, false);
+ if (!dummy_prev_fn) throw python_error();
+ self->previous_functions[i] = THPFunctionPtr(dummy_prev_fn, 0);
+ }
+ } else {
+ // An input has been returned, but it wasn't modified. It's better
+ // not to move the Variable, because there are some legitimate cases
+ // where making it a non-leaf would break things (e.g. broadcast).
+ // Returning the input Variable directly isn't a good option either,
+ // because if someone has registered hooks on it, they would fire
+ // with grads from all of its usages, not only from usages of this
+ // output. This is why we return a copy and join their version
+ // counters. This has a side-effect of making in-place ops on any of
+ // these Variables an immediate error, but it would be raised anyway,
+ // once someone calls backward.
+ output_var = (THPVariable*)THPVariable_New(output, (PyObject*)self, self->requires_grad);
+ if (!output_var) throw python_error();
+ output_var->version_counter->join_with(*input_var->version_counter);
+ }
}
}
if (!output_var) throw python_error();
@@ -390,8 +426,9 @@
self->previous_functions[i] = THPFunctionPtr(prev_fn, input_var->output_nr);
}
- _mark_dirty(self, t2var);
- _wrap_outputs(self, t2var, raw_output, outputs);
+ std::unordered_set<PyObject *> dirty_inputs;
+ _mark_dirty(self, t2var, dirty_inputs);
+ _wrap_outputs(self, t2var, dirty_inputs, raw_output, outputs);
_join_version_counters(self, t2var);
if (self->requires_grad ||
PyObject_IsInstance((PyObject*)self, THPStochasticFunctionClass)) {
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index b858203..a4606d8 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -42,13 +42,15 @@
}
// This function DOES NOT steal a reference to data and creator
+// To create a leaf Variable, pass NULL as the creator.
PyObject * THPVariable_New(PyObject *data, PyObject *creator, char requires_grad)
{
if (num_cached > 0) {
Py_INCREF(data);
- Py_INCREF(creator);
+ Py_XINCREF(creator);
return (PyObject*)pop_cache(data, creator, requires_grad);
}
+ creator = creator ? creator : Py_None;
return PyObject_CallFunction(THPVariableClass, "OObb", data, creator, (char)0, requires_grad);
}