| #include "function.h" |
| |
| #include <string> |
| |
| #include "variable.h" |
| |
| namespace torch { namespace autograd { |
| |
| auto Function::flags(const variable_list& inputs) -> FunctionFlags { |
| int num_inputs = inputs.size(); |
| FunctionFlags f; |
| f.is_executable = false; |
| f.is_volatile = false; |
| f.next_functions.resize(num_inputs); |
| for (int i = 0; i != num_inputs; ++i) { |
| auto& var = inputs[i]; |
| if (var) { |
| f.is_executable |= var->requires_grad; |
| f.is_volatile |= var->is_volatile; |
| if (var->grad_fn) { |
| f.next_functions[i] = std::make_pair<>(var->grad_fn, var->output_nr); |
| } else { |
| f.next_functions[i] = std::make_pair<>(var->get_grad_accumulator(), 0); |
| } |
| } |
| } |
| f.is_executable &= !f.is_volatile; |
| return f; |
| } |
| |
| auto Function::name() -> std::string { |
| return std::string(typeid(*this).name()); |
| } |
| |
| }} // namespace torch::autograd |