| #pragma once |
| |
| // Function is an abstract class that represents a single operation from one or |
| // more variables to one or more variables. |
| // |
| // Subclasses may represent "forward" or "backward" operations (i.e. functions |
| // and their derivatives). Some functions may be used as both. |
| |
| #include <memory> |
| #include <THPP/THPP.h> |
| #include <vector> |
| |
| #include "torch/csrc/autograd/saved_variable.h" |
| |
| namespace torch { namespace autograd { |
| |
| struct Function; |
| struct Variable; |
| |
| using tensor_list = std::vector<std::unique_ptr<thpp::Tensor>>; |
| using variable_list = std::vector<std::shared_ptr<Variable>>; |
| using function_list = std::vector<std::pair<std::shared_ptr<Function>, int>>; |
| |
| // State used to create "backward" functions |
| struct FunctionFlags { |
| bool requires_grad; |
| bool is_volatile; |
| function_list previous_functions; |
| }; |
| |
| struct Function { |
| Function() |
| : num_outputs(0) |
| , previous_functions() |
| , requires_grad(false) |
| , is_volatile(false) |
| , is_stochastic(false) |
| {} |
| |
| Function(FunctionFlags flags) |
| : num_outputs(0) |
| , previous_functions(std::move(flags.previous_functions)) |
| , requires_grad(flags.requires_grad) |
| , is_volatile(flags.is_volatile) |
| , is_stochastic(false) |
| {} |
| |
| Function(const Function& other) = delete; |
| Function(Function&& other) = delete; |
| virtual ~Function() {} |
| |
| // Implements the operation |
| virtual variable_list apply(const variable_list& inputs) = 0; |
| |
| // Computes requires_grad, is_volatile, and previous_functions from a list |
| // of input variables |
| static FunctionFlags flags(const variable_list& inputs); |
| |
| // Releases saved variables if the operation won't be reused |
| virtual inline void releaseVariables() {} |
| |
| // These variables are usually only meaningful for "backward" functions. |
| // num_outputs is the number of outputs of corresponding "forward" function; |
| // it's actually the number of inputs of this function. |
| int num_outputs; |
| function_list previous_functions; |
| bool requires_grad; |
| bool is_volatile; |
| bool is_stochastic; |
| }; |
| |
| |
| }} // namespace torch::autograd |