blob: b98e386e3e8b7b2368dce8a6086c0f6ae9aa169f [file] [log] [blame]
#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/autograd/function.h"
using namespace at;
namespace torch { namespace autograd {
SavedVariable::SavedVariable(const Variable& variable, Function* saved_for)
: SavedVariable() {
if (!variable.defined()) {
return;
}
data = variable.data();
requires_grad = variable.requires_grad();
is_volatile = variable.is_volatile();
expected_version = variable.current_version();
version = variable.get()->version_counter->new_saved_ref();
has_grad_fn = variable.grad_fn() != nullptr;
if (!has_grad_fn) {
grad_accumulator = variable.grad_accumulator();
}
if (variable.grad_fn().get() != saved_for) {
grad_fn = variable.grad_fn();
}
if (variable.tracing_state()) {
tracing_state.reset(new jit::tracer::ValueTracingState(*variable.tracing_state()));
}
}
auto SavedVariable::unpack(std::shared_ptr<Function> saved_for) const -> Variable {
if (!data.defined()) {
if (version) {
throw std::runtime_error(ERR_BACKWARD_TWICE);
}
return Variable();
}
int current_version = **version;
if (expected_version != current_version) {
throw std::runtime_error("one of the variables "
"needed for gradient computation has been modified by an "
"inplace operation");
}
Variable var = make_variable(data, requires_grad, is_volatile);
if (has_grad_fn && !grad_fn) {
if (!saved_for) {
// If saving the grad_fn would create a circular reference, then it must
// be passed in to the unpack function.
throw std::runtime_error("No grad_fn for non-leaf saved variable");
}
var.grad_fn() = saved_for;
} else {
var.grad_fn() = grad_fn;
}
var.get()->version_counter->join_with(*version);
// If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
// should have saved the grad accumulator. Even if the Variable no longer
// alive, the accumulator should be kept alive by the references in the graph).
if (requires_grad && !var.grad_fn() && grad_accumulator.expired())
throw std::logic_error("No grad accumulator for a saved leaf!");
var.get()->grad_accumulator = grad_accumulator;
if (tracing_state)
var.tracing_state().reset(new jit::tracer::ValueTracingState(*tracing_state));
return var;
}
auto SavedVariable::unpack_data(std::shared_ptr<Function> saved_for) const -> Tensor {
auto var = unpack(saved_for);
if (var.defined()) {
return var.data();
}
return Tensor();
}
const char* ERR_BACKWARD_TWICE =
"Trying to backward through the graph a second time, but the buffers have "
"already been freed. Specify retain_graph=True when calling backward "
"the first time.";
}} // namespace torch::autograd