blob: c592388d417ebf8531b51ca8058d140dd1ab2f98 [file] [log] [blame]
#include "function.h"
#include <string>
#include "variable.h"
namespace torch { namespace autograd {
/// Computes the FunctionFlags describing a call that consumes `inputs`.
///
/// For the returned flags:
///  - `is_volatile` is true if any non-null input has `is_volatile` set.
///  - `is_executable` is true if any non-null input has `requires_grad` set
///    and no input is volatile (volatility overrides requires_grad).
///  - `next_functions` has exactly one entry per input, in input order.
///    A non-null input with a `grad_fn` contributes (grad_fn, output_nr);
///    a non-null input without one contributes (get_grad_accumulator(), 0).
///    Entries for null inputs are left exactly as `resize()` created them.
///
/// @param inputs Variables passed to the function; elements may be null.
/// @return The aggregated flags.
auto Function::flags(const variable_list& inputs) -> FunctionFlags {
  // Use the container's own size type instead of `int`, so the element count
  // is never narrowed and the index is not a signed/unsigned mix.
  using size_type = decltype(inputs.size());
  const size_type num_inputs = inputs.size();

  FunctionFlags f;
  f.is_executable = false;
  f.is_volatile = false;
  f.next_functions.resize(num_inputs);

  for (size_type i = 0; i != num_inputs; ++i) {
    const auto& var = inputs[i];
    if (!var) {
      // Null input: contributes nothing, keep the slot resize() produced.
      continue;
    }
    f.is_executable |= var->requires_grad;
    f.is_volatile |= var->is_volatile;
    if (var->grad_fn) {
      f.next_functions[i] = std::make_pair(var->grad_fn, var->output_nr);
    } else {
      // No grad_fn: point at the variable's gradient accumulator, index 0.
      f.next_functions[i] = std::make_pair(var->get_grad_accumulator(), 0);
    }
  }

  // A single volatile input makes the whole call non-executable.
  f.is_executable &= !f.is_volatile;
  return f;
}
/// Returns the type name reported by `typeid` for this object.
///
/// The string comes straight from `std::type_info::name()`, so its exact
/// spelling is implementation-defined (it may be a mangled name).
/// NOTE(review): this yields the most-derived type only if Function is
/// polymorphic — confirm against the class declaration.
auto Function::name() -> std::string {
  const char* type_name = typeid(*this).name();
  return std::string(type_name);
}
}} // namespace torch::autograd