blob: 3bebcb0e9576d94ececa707e5faf40fa66b7a4d5 [file] [log] [blame]
#include "Python.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/assertions.h"
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/functions/accumulate_grad.h"
#include "torch/csrc/autograd/functions/tensor.h"
using namespace at;
namespace torch { namespace autograd {
// Wraps `data` in a new Variable registered as the next output of `grad_fn`.
Variable make_variable(at::Tensor data, std::shared_ptr<Function> grad_fn) {
  // TODO: If you ever want to support returning an undefined tensor from
  // a function, you'll have to uncomment the line below. Not sure if
  // we actually want to support this.
  // if (!data.defined()) return Variable();
  TORCH_ASSERT(grad_fn);
  // Claim the next input slot on grad_fn; that index becomes the new
  // Variable's output_nr.
  const int slot = grad_fn->num_inputs++;
  return make_variable(std::move(data), slot, std::move(grad_fn));
}
// Backing implementation of a Variable: the wrapped tensor plus autograd
// metadata (grad, grad_fn, version counter, requires_grad flag).
VariableImpl::VariableImpl(Tensor data_, bool requires_grad, int output_nr, std::shared_ptr<Function> grad_fn)
  // NOTE: getType(data_) runs before data_ is moved from. It throws for an
  // undefined tensor, so the defined() check in the body below is only a
  // second line of defense.
  : TensorImpl(getType(data_))
  , data(std::move(data_))
  , grad()                           // populated lazily during backward
  , _grad_fn(std::move(grad_fn))
  , version_counter()
  , _requires_grad(requires_grad)
  , is_view(false)                   // flipped to true by VariableViewImpl
  , output_nr(output_nr)
  , pyobj(nullptr) {
  // A Variable produced by a Function derives its requires_grad-ness from
  // the graph; the explicit flag is only for leaves, never both.
  TORCH_ASSERTM(!_grad_fn || !_requires_grad, "_requires_grad should be false if grad_fn is set");
  if (!data.defined()) {
    throw std::runtime_error("data is undefined");
  }
}
// No manual cleanup needed: all members are RAII types. `= default` states
// that explicitly while keeping the out-of-line definition (and thus the
// emitted destructor) in this translation unit.
VariableImpl::~VariableImpl() = default;
// Human-readable tag identifying this TensorImpl subclass.
const char * VariableImpl::toString() const {
  return "Variable";
}
// Forwards to the wrapped tensor's sizes.
IntList VariableImpl::sizes() const {
  return data.sizes();
}
// Forwards to the wrapped tensor's strides.
IntList VariableImpl::strides() const {
  return data.strides();
}
// Forwards to the wrapped tensor's dimensionality.
int64_t VariableImpl::dim() const {
  return data.dim();
}
// Static name of the Type subclass paired with this impl.
const char * VariableImpl::typeString() {
  return "VariableType";
}
// Exposes the wrapped tensor's raw TH handle; `retain` is passed through
// unchanged (presumably controls a refcount bump — confirm in ATen).
void * VariableImpl::unsafeGetTH(bool retain) {
  return data.unsafeGetTH(retain);
}
// Forwards to the wrapped tensor's storage.
std::unique_ptr<at::Storage> VariableImpl::storage() {
  return data.storage();
}
// Extracts a Scalar from the wrapped tensor. Reaches through pImpl
// directly rather than via a public Tensor accessor.
Scalar VariableImpl::localScalar() {
  return data.pImpl->localScalar();
}
std::shared_ptr<Function> VariableImpl::get_grad_accumulator() {
if (_grad_fn) {
throw std::logic_error("get_grad_accumulator() should be only called on leaf Variables");
}
if (!_requires_grad) {
return nullptr;
}
std::lock_guard<std::mutex> lock(mutex);
auto result = grad_accumulator.lock();
if (result) return result;
result = std::make_shared<AccumulateGrad>(Variable(this, true));
grad_accumulator = result;
return result;
}
// A view shares storage with its base. Constructed with requires_grad
// forced to false — grad-ness is derived from the base (see get_grad_fn).
VariableViewImpl::VariableViewImpl(Variable base_, at::Tensor data_, int output_nr,
    std::shared_ptr<Function> grad_fn)
  : VariableImpl(std::move(data_), false, output_nr, std::move(grad_fn))
  , base(std::move(base_))
  , attr_version(0) {
  TORCH_ASSERTM(base.defined(), "base is undefined");
  // Collapse chains of views: `base` always points at the root Variable.
  if (base.is_view()) {
    base = base.base();
  }
  is_view = true;
  // Share the version counter with the base so in-place modification of
  // either tensor is visible to both; record the version this view's
  // grad_fn corresponds to (checked for staleness in get_grad_fn()).
  version_counter = base.version_counter();
  attr_version = version_counter.current_version();
}
// Returns the view's grad_fn, lazily rebuilding it when the shared version
// counter shows the view (or its base) was modified in place since the
// grad_fn was last computed. The rebuilt node is an AsStridedBackward that
// maps gradients from this view's geometry back onto the base.
std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
  std::lock_guard<std::mutex> lock(mutex);
  // No existing grad_fn and the base doesn't require grad: nothing to build.
  if (!_grad_fn && !base.requires_grad()) {
    return _grad_fn;
  }
  auto current_version = version_counter.current_version();
  if (attr_version != current_version) {
    // Stale grad_fn: rebuild from the view's current geometry.
    TORCH_ASSERT(output_nr == 0);
    auto fn = std::make_shared<generated::AsStridedBackward>();
    fn->self_geometry = TensorGeometry(base);
    fn->size = sizes();
    fn->stride = strides();
    fn->storage_offset = data.storage_offset();
    fn->set_flags(Function::flags(base));
    fn->num_inputs = 1;
    _grad_fn = std::move(fn);
    attr_version = current_version;  // mark the new grad_fn as fresh
  }
  return _grad_fn;
}
// Called after an in-place operation on a view: re-points the *base*'s
// autograd history at a CopySlices node wrapping `grad_fn`, then refreshes
// this view's own grad_fn via get_grad_fn().
void VariableViewImpl::rebase_history(int output_nr, std::shared_ptr<Function> grad_fn) {
  TORCH_ASSERT(output_nr == 0);
  TORCH_ASSERT(grad_fn);
  TORCH_ASSERTM(grad_fn->num_inputs == 1, "Functions which modify views in-place must return a single Variable");
  this->output_nr = output_nr;
  base.output_nr() = 0;
  // CopySlices routes the incoming gradient through the view's geometry
  // into the base's gradient.
  base.get()->_grad_fn = std::make_shared<CopySlices>(
      base, TensorGeometry(data), std::move(grad_fn));
  get_grad_fn(); // trigger an update to the view's grad_fn
}
namespace {
struct VariableTypes {
VariableTypes() {
auto& context = at::globalContext();
for (int p = 0; p < static_cast<int>(Backend::NumOptions); ++p) {
for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); s++) {
auto baseType = context.type_registry[p][s].get();
if (baseType && baseType->backend() != Backend::Undefined) {
auto id = static_cast<int>(baseType->ID());
types[id].reset(new VariableType(&context, baseType));
}
}
}
}
std::unique_ptr<Type> types[static_cast<int>(TypeID::NumOptions)];
};
} // anonymous namespace
// Resolves the VariableType wrapper for `tensor`'s base type.
// Throws for an undefined tensor, which has no type to look up.
Type* VariableImpl::getType(const Tensor& tensor)
{
  if (tensor.defined()) {
    return getType(tensor.type());
  }
  throw std::runtime_error("tensor is undefined");
}
// Eagerly-initialized file-scope registry of VariableType wrappers.
// NOTE(review): construction order relative to statics in other
// translation units is unspecified — confirm no cross-TU use at startup.
static VariableTypes vt;
// Looks up the VariableType wrapper registered for `baseType`'s TypeID;
// yields nullptr when no wrapper was registered for that slot.
Type* VariableImpl::getType(const Type& baseType)
{
  const auto index = static_cast<int>(baseType.ID());
  return vt.types[index].get();
}
std::vector<Type*> VariableImpl::allTypes() {
std::vector<Type*> types;
for (int i = 0; i < static_cast<int>(TypeID::NumOptions); i++) {
if (vt.types[i]) {
types.push_back(vt.types[i].get());
}
}
return types;
}
// Returns a new Variable sharing this one's data and version counter but
// carrying no autograd history.
Variable Variable::detach() const {
  auto result = make_variable(data());
  // Keep in-place-modification tracking consistent with the original.
  result.version_counter() = version_counter();
  return result;
}
// Strips this Variable's autograd history in place, turning it into a
// leaf that does not require grad. Disallowed on views, whose history is
// tied to their base.
void Variable::detach_() {
  if (is_view()) {
    throw std::runtime_error("Can't detach views in-place. Use detach() instead");
  }
  auto impl = get();
  impl->_requires_grad = false;
  output_nr() = 0;
  impl->_grad_fn = nullptr;
}
}} // namespace torch::autograd