blob: b0b45245717fa65bab4910fce8817199f1b6b36b [file] [log] [blame]
#pragma once
// Function is an abstract class that represents a single operation from one or
// more variables to one or more variables.
//
// Subclasses may represent "forward" or "backward" operations (i.e. functions
// and their derivatives). Some functions may be used as both.
// Python.h must be included before any standard headers.
#include <Python.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <THPP/THPP.h>

#include "torch/csrc/autograd/function_hook.h"
namespace torch { namespace autograd {
// Forward declarations, so the list aliases below can name these types
// without their full definitions.
struct Variable;
struct Function;

// Owning list of tensors (each element is a unique_ptr).
using tensor_list =
    std::vector<std::unique_ptr<thpp::Tensor>>;

// List of variables, held by shared ownership.
using variable_list =
    std::vector<std::shared_ptr<Variable>>;

// List of (function, int) pairs. The int presumably identifies an
// input/output slot of the paired function -- TODO confirm with users.
using function_list =
    std::vector<std::pair<std::shared_ptr<Function>, int>>;
// Bundle of graph state that Function::flags() computes from a list of input
// variables, and from which a "backward" Function is constructed.
struct FunctionFlags {
  // is_executable and is_volatile both decide whether backward gradient
  // computation will be performed for a function; they differ only in their
  // precedence. Roughly, is_executable corresponds to requires_grad.
  // Details: http://pytorch.org/docs/notes/autograd.html
  bool is_executable{false};
  bool is_volatile{false};

  // Functions that take this function's outputs as their inputs:
  // one entry for each output of this function.
  function_list next_functions;
};
struct Function {
Function()
: num_inputs(0)
, next_functions()
, is_executable(false)
, is_stochastic(false)
, pre_hooks()
, post_hooks()
{}
Function(FunctionFlags&& flags)
: num_inputs(0)
, next_functions(std::move(flags.next_functions))
, is_executable(flags.is_executable)
, is_stochastic(false)
, pre_hooks()
, post_hooks()
{}
Function(const Function& other) = delete;
Function(Function&& other) = delete;
virtual ~Function() {}
// Implements the operation
virtual variable_list apply(const variable_list& inputs) = 0;
// Computes is_executable, is_volatile, and next_functions from a list
// of input variables
static FunctionFlags flags(const variable_list& inputs);
// Releases saved variables if the operation won't be reused
virtual inline void releaseVariables() {}
// Function name for debugging
virtual std::string name();
inline bool should_compute_output(int i) const {
auto& fn = next_functions[i].first;
return fn && fn->is_executable;
}
inline void set_flags(FunctionFlags&& flags) {
is_executable = flags.is_executable;
next_functions = std::move(flags.next_functions);
}
int num_inputs;
function_list next_functions;
bool is_executable;
bool is_stochastic;
std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
};
}} // namespace torch::autograd