from collections import deque, defaultdict
from torch._C import _ImperativeEngine as ImperativeEngine
from .variable import Variable
class BasicEngine(object):
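    """Pure-Python implementation of the autograd backward pass.

    Walks the graph of Functions rooted at a Variable's creator and
    accumulates a gradient for each of them in topological order.
    """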
def _compute_dependencies(self, function):
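        """Count how many gradient contributions each Function reachable
        from ``function`` will receive during backward, i.e. how many
        times it produced an input of another Function in the graph."""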
dependencies = defaultdict(int)
seen = {function}
queue = [function]
        while queue:
fn = queue.pop()
for prev_fn, output_nr in fn.previous_functions:
if not prev_fn.requires_grad or isinstance(prev_fn, Variable):
continue
dependencies[prev_fn] += 1
if prev_fn not in seen:
queue.append(prev_fn)
seen.add(prev_fn)
return dependencies
def _free_backward_dependency(self, dependencies, prev_fn):
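        """Mark one gradient contribution to ``prev_fn`` as delivered and
        return True once the last one is in, meaning ``prev_fn`` is ready
        to have its own backward executed."""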
dependencies[prev_fn] -= 1
if dependencies[prev_fn] == 0:
del dependencies[prev_fn]
return True
return False
def _add_grad(self, need_copy, prev_grad, output_nr, d_prev_fn):
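        """Accumulate ``d_prev_fn`` into slot ``output_nr`` of the buffer
        ``prev_grad``, copying on write: the first gradient for a slot is
        stored by reference (and remembered in ``need_copy``), so it has
        to be cloned before the in-place ``add_`` once a second arrives."""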
copy_id = (id(prev_grad), output_nr)
        if prev_grad[output_nr] is None:  # first contribution for this slot
prev_grad[output_nr] = d_prev_fn
need_copy.add(copy_id)
else:
grad_tensor = prev_grad[output_nr]
if copy_id in need_copy:
need_copy.remove(copy_id)
grad_tensor = grad_tensor.clone()
prev_grad[output_nr] = grad_tensor
grad_tensor.add_(d_prev_fn)
def run_backward(self, variable, grad, retain_variables):
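        """Run the backward pass starting at ``variable``, seeded with
        ``grad`` as the gradient of ``variable`` itself."""
        # A Variable without a creator is a leaf - hand it the gradient directly.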
if variable.creator is None:
variable._do_backward((grad,), retain_variables)
return
initial_grad = [None for _ in range(variable.creator.num_outputs)]
initial_grad[variable.output_nr] = grad
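        # 'ready' queues functions whose complete gradient is known,
        # 'not_ready' maps functions to partially accumulated gradient
        # buffers, and 'need_copy' marks slots that still alias a tensor
        # owned by someone else (see _add_grad).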
ready = deque([(variable.creator, initial_grad)])
not_ready = {}
need_copy = set()
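        # Count how many gradient contributions each Function still expects.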
dependencies = self._compute_dependencies(variable.creator)
        while ready:
fn, grad = ready.pop()
grad_input = fn._do_backward(tuple(grad), retain_variables)
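            # Route each gradient to the Function (or leaf Variable) that
            # produced the corresponding input.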
for (prev_fn, output_nr), d_prev_fn in zip(fn.previous_functions, grad_input):
if not prev_fn.requires_grad:
# TODO: check that d_prev_fn is None and warn otherwise
continue
if isinstance(prev_fn, Variable):
prev_fn._do_backward((d_prev_fn,), retain_variables)
continue
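                # One contribution to prev_fn has been delivered; check
                # whether it has now received all of them.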
is_ready = self._free_backward_dependency(dependencies, prev_fn)
if is_ready:
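                    # prev_fn has its complete gradient: schedule it.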
                    if prev_fn in not_ready:
                        prev_grad = not_ready[prev_fn]
                        self._add_grad(need_copy, prev_grad, output_nr, d_prev_fn)
                        # prev_fn cannot receive any further contributions,
                        # so the waiting buffer can be dropped.
                        del not_ready[prev_fn]
else:
if prev_fn.num_outputs != 1:
                            raise RuntimeError("one of the function outputs "
                                               "wasn't used - this is an error "
                                               "now, but it's going to be "
                                               "fixed soon")
prev_grad = (d_prev_fn,)
ready.appendleft((prev_fn, prev_grad))
else:
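                    # Still waiting on more contributions: stash the
                    # partial gradient buffer until prev_fn becomes ready.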
if prev_fn in not_ready:
prev_grad = not_ready[prev_fn]
else:
prev_grad = [None for _ in range(prev_fn.num_outputs)]
self._add_grad(need_copy, prev_grad, output_nr, d_prev_fn)
not_ready[prev_fn] = prev_grad
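
# A minimal sketch of how this engine is driven, assuming the legacy
# Variable API in which Variable.backward delegates here:
#
#     out = some_function(inputs)   # out.creator is the producing Function
#     BasicEngine().run_backward(out, grad, retain_variables=False)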