# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two files:
# Functions.h/cpp: subclasses of autograd::Node
# python_functions.h/cpp: Python bindings for the above classes
#
from .gen_inplace_or_view_type import VIEW_FUNCTIONS
from typing import List, Sequence, Tuple
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
                                        SavedAttribute, uses_retain_variables,
                                        uses_single_grad)
from tools.codegen.api.types import (Binding, BaseCType, OptionalCType, tensorT, intT,
                                     doubleT, scalarT, stringT, boolT, intArrayRefT,
                                     tensorListT, MutRefCType, ListCType, ArrayRefCType)
from tools.codegen.code_template import CodeTemplate
from tools.codegen.gen import FileManager
from tools.codegen.model import Argument
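
# For illustration only (the op name, saved field, and formula are made up, not taken
# from any real derivative definition): for an op with one differentiable Tensor input
# `self` and derivative formula `grad * self`, the templates defined below fill in to
# roughly the following declaration (FUNCTION_DECLARATION) and definition
# (FUNCTION_DEFINITION + DERIVATIVE_SINGLE), ignoring the undefined-gradient guard that
# emit_derivative() adds:
#
#   struct TORCH_API MyOpBackward : public TraceableFunction {
#     using TraceableFunction::TraceableFunction;
#     variable_list apply(variable_list&& grads) override;
#     std::string name() const override { return "MyOpBackward"; }
#     void release_variables() override {
#       std::lock_guard<std::mutex> lock(mutex_);
#       self_.reset_data();
#     }
#     SavedVariable self_;
#   };
#
#   variable_list MyOpBackward::apply(variable_list&& grads) {
#     std::lock_guard<std::mutex> lock(mutex_);
#     IndexRangeGenerator gen;
#     auto self_ix = gen.range(1);
#     variable_list grad_inputs(gen.size());
#     auto& grad = grads[0];
#     auto self = self_.unpack();
#     if (should_compute_output({ self_ix })) {
#       auto grad_result = grad * self;
#       copy_range(grad_inputs, self_ix, grad_result);
#     }
#     return grad_inputs;
#   }
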
FUNCTION_DECLARATION = CodeTemplate("""\
struct TORCH_API ${op} : public ${superclass} {
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${thread_lock}
    ${release_variables}
  }
  ${will_release_variables}
  ${saved_variables}
  ${saved_list_sizes}
};
""")
WILL_RELEASE_VARIABLES = CodeTemplate("""\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
""")
FUNCTION_DEFINITION = CodeTemplate("""\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
""")
GRAD_INPUT_MASK = CodeTemplate("""\
auto grad_input_mask = std::array<bool, ${n}>{
  ${masks}
};\
""")
DERIVATIVE_SINGLE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
""")
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
  copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
}
""")
DERIVATIVE_MULTI = CodeTemplate("""\
if (should_compute_output({ ${idx_ranges} })) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
""")
# Generates Python bindings
#
# This generates the definitions for:
# (1) The PyTypeObject for each backward grad_fn subclassing Node
# (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs)
#     We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs.
#     Each PyGetSetDef has a function pointer to a getter, also defined here (3).
# (3) Getters for each of grad_fn's saved inputs and outputs.
#
PY_FUNCTION_DEFINITION = CodeTemplate("""\
static PyTypeObject ${op}Class;
addClass<${op}>(${op}Class, "${op}", ${op}_properties);
""")
PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate("""\
${all_getter_definitions}
static struct PyGetSetDef ${op}_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  ${all_getsetdef_structs}
  {nullptr} /* sentinel */
};
""")
PY_GETSETDEF_STRUCT = CodeTemplate("""\
{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}""")
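# For illustration only: for a hypothetical op MyOpBackward with a saved input `self`,
# the struct above expands to
#   {(char*)"_saved_self", (getter)THPMyOpBackward_self_getter, nullptr, nullptr, nullptr}
# which is the entry that exposes `grad_fn._saved_self` to Python.
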
# Getter templates
GETTER_DEFINITION = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
  ${body}
  END_HANDLE_TH_ERRORS
}
""")
GETTER_DEFINITION_SAVEDVAR = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
""")
GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
""")
GETTER_DEFINITION_OPT = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
""")
GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate("""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.list.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.list.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
""")
# Getter body
GETTER_BODY_SAVEDVAR = """\
return THPVariable_Wrap(prop.unpack(self->cdata));
"""
GETTER_BODY_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (int i = 0; i < prop.size(); i++) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
}
return tup;
"""
GETTER_BODY_ARRAYREF_LONG = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (int i = 0; i < prop.size(); i++) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i]));
}
return tup;
"""
GETTER_BODY_ARRAYREF_DOUBLE = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (int i = 0; i < prop.size(); i++) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
}
return tup;
"""
GETTER_BODY_INT64_T = """\
return PyLong_FromUnsignedLong((int64_t) prop);
"""
GETTER_BODY_DOUBLE = """\
return PyFloat_FromDouble((double) prop);
"""
GETTER_BODY_BOOL = """\
if (prop) {
  Py_RETURN_TRUE;
} else {
  Py_RETURN_FALSE;
}
"""
GETTER_BODY_STRING = """\
return PyUnicode_FromStringAndSize(prop.data(), prop.size());
"""
GETTER_BODY_SCALAR = """\
if (prop.isComplex()) {
  auto cprop = prop.to<c10::complex<double>>();
  return PyComplex_FromDoubles(cprop.real(), cprop.imag());
} else if (prop.isFloatingPoint()) {
  return PyFloat_FromDouble(prop.to<double>());
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""
MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(intT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}
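
# MISC_GETTER_DEFS (above) maps saved-attribute C++ types that need no special handling
# in process_function below to the (getter template, getter body) pair used to generate
# their Python getter. Types that are neither listed there nor special-cased in save_var
# get no Python binding.
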
# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    gen_autograd_functions(out, differentiability_infos, template_path, "Functions")


def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    gen_autograd_functions(out, differentiability_infos, template_path, "python_functions")


def gen_autograd_functions(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
    file_basename: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """

    # only create an autograd function if we are actually going to calculate a derivative
    infos = list(filter(lambda info: info.args_with_derivatives, differentiability_infos))
    declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos))
    definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos))
    py_function_initializers = list(map(lambda f: process_function(f, PY_FUNCTION_DEFINITION), infos))
    py_function_props_and_getters = list(map(lambda f: process_function(f, PY_FUNCTION_PROPS_AND_GETTERS), infos))

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in ['.h', '.cpp']:
        fname = file_basename + suffix
        fm.write_with_template(fname, fname, lambda: {
            'generated_comment': '@' + f'generated from {fm.template_dir}/' + fname,
            'autograd_function_declarations': declarations,
            'autograd_function_definitions': definitions,
            'py_function_initializers': py_function_initializers,
            'py_function_props_and_getters': py_function_props_and_getters
        })


def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
    saved_variables: List[str] = []
    release_variables: List[str] = []
    saved_list_sizes: List[str] = []
    unpack: List[str] = []
    asserts: List[str] = []
    compute_index_ranges: List[str] = []
    getter_definitions: List[str] = []
    py_getsetdef_structs: List[str] = []

    for arg in info.args_with_derivatives:
        if arg.type == 'at::TensorList' or arg.type == 'const c10::List<c10::optional<at::Tensor>> &':
            size = f'{arg.name}_size_'
            saved_list_sizes.append(f'size_t {arg.name}_size_;')
        else:
            size = '1'
        compute_index_ranges.append(f'auto {arg.name}_ix = gen.range({size});')
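
    # For illustration only: for a hypothetical op that is differentiable in a Tensor
    # `self` and a TensorList `tensors`, the lines collected above would read
    #   auto self_ix = gen.range(1);
    #   auto tensors_ix = gen.range(tensors_size_);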

    def save_var(var: SavedAttribute, is_output: bool) -> None:
        name = var.nctype.name
        type = var.nctype.type
        should_append_getsetdef = True

        if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
                type == MutRefCType(OptionalCType(BaseCType(tensorT))) or \
                (type == BaseCType(scalarT) and is_output):
            saved_variables.append(f'SavedVariable {name}_;')
            release_variables.append(f'{name}_.reset_data();')
            ptr = 'shared_from_this()' if is_output else ''
            unpack.append(f'auto {name} = {name}_.unpack({ptr});')
            getter_definitions.append(GETTER_DEFINITION_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_SAVEDVAR))
        elif type == BaseCType(tensorListT):
            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
            saved_variables.append(f'bool {name}_released_ = false;')
            # Just clear() is sufficient; we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f'{name}_.clear();')
            release_variables.append(f'{name}_released_ = true;')
            unpack.append(f'auto {name} = unpack_list({name}_);')
            asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
            getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
            saved_variables.append(f'bool {name}_released_ = false;')
            # Just clear() is sufficient; we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f'{name}_.clear();')
            release_variables.append(f'{name}_released_ = true;')
            unpack.append(f'auto {name} = unpack_opt_list({name}_);')
            asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
            getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f'std::vector<int64_t> {name};')
            getter_definitions.append(GETTER_DEFINITION.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
        elif type == OptionalCType(BaseCType(intArrayRefT)):
            saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
            getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f'c10::OptionalArray<double> {name};')
            getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
        elif type == BaseCType(intT):
            saved_variables.append(f'{type.cpp_type()} {name} = 0;')
            getter_definitions.append(GETTER_DEFINITION.substitute(
                op=info.op, name=name, body=GETTER_BODY_INT64_T))
        elif type == BaseCType(stringT):
            saved_variables.append(f'std::string {name};')
            getter_definitions.append(GETTER_DEFINITION.substitute(
                op=info.op, name=name, body=GETTER_BODY_STRING))
        elif type == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f'c10::optional<std::string> {name};')
            getter_definitions.append(GETTER_DEFINITION_OPT.substitute(
                op=info.op, name=name, body=GETTER_BODY_STRING))
        else:
            saved_variables.append(f'{type.cpp_type()} {name};')

            if type in MISC_GETTER_DEFS:
                getter_def, body = MISC_GETTER_DEFS[type]
                getter_definitions.append(getter_def.substitute(op=info.op, name=name, body=body))
            else:
                # Types we don't expose python bindings to yet:
                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                should_append_getsetdef = False

        if should_append_getsetdef:
            py_getsetdef_structs.append(PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
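
    # Depending on its type, each saved attribute thus contributes some combination of:
    # a member declaration (saved_variables), a statement for release_variables(),
    # an unpack line for apply(), an ERR_BACKWARD_TWICE assert, a Python getter
    # definition, and a PyGetSetDef entry for the grad_fn's `_saved_<name>` attribute.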

    for var in info.all_saved_inputs:
        save_var(var, is_output=False)
    for var in info.all_saved_outputs:
        save_var(var, is_output=True)

    # lock the mutex when we release variables and in Node::apply to protect thread safety
    # see Note [Thread Safety on Autograd Node]
    if len(release_variables) > 0:
        thread_lock = 'std::lock_guard<std::mutex> lock(mutex_);'
    else:
        thread_lock = ''

    if uses_retain_variables(info):
        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
    else:
        will_release_variables = ''

    body: List[str] = []

    if uses_single_grad(info):
        body.append('auto& grad = grads[0];')

    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
    ) -> Tuple[bool, str]:
        formula = derivative.formula
        var_names = derivative.var_names
        if len(var_names) == 1:
            checks_any_grad_defined = False
            if 'not_implemented' not in formula:
                matching_args = [
                    arg for arg in args_with_derivatives
                    if arg.name == var_names[0]]
                if len(matching_args) == 1:
                    # We can add undefined grad support if the input variable is a Tensor
                    arg = matching_args[0]
                    if isinstance(arg.argument, Argument) and str(arg.argument.type) in ('Tensor', 'Tensor?'):
                        formula = 'any_grad_defined ? (' + formula + ') : Tensor()'
                        checks_any_grad_defined = True
            return (checks_any_grad_defined,
                    DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula))
        else:
            if 'grad_input_mask' in formula:
                masks = [f'should_compute_output({{ {n}_ix }}),' for n in var_names]
                grad_input_mask = GRAD_INPUT_MASK.substitute(masks=masks, n=len(var_names))
            else:
                grad_input_mask = ''
            idx_ranges = ', '.join(f'{n}_ix' for n in var_names)
            copy_ranges: List[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges, copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask)
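
    # For illustration only (op and formula names are made up): a multi-output formula
    # for inputs `self` and `other` that mentions `grad_input_mask` would expand to roughly
    #   if (should_compute_output({ self_ix, other_ix })) {
    #     auto grad_input_mask = std::array<bool, 2>{
    #       should_compute_output({ self_ix }),
    #       should_compute_output({ other_ix }),
    #     };
    #     auto grad_result = my_op_backward(grad, self, other, grad_input_mask);
    #     if (should_compute_output({ self_ix })) {
    #       copy_range(grad_inputs, self_ix, std::get<0>(grad_result));
    #     }
    #     if (should_compute_output({ other_ix })) {
    #       copy_range(grad_inputs, other_ix, std::get<1>(grad_result));
    #     }
    #   }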

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in info.derivatives:
        checks_any_grad_defined, derivative_text = emit_derivative(derivative, info.args_with_derivatives)
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check if grads are
    # defined, only perform the check once, before all the formulas
    if need_any_grad_defined_var:
        body.insert(-len(info.derivatives),
                    'bool any_grad_defined = any_variable_defined(grads);')

    if info.name in UNTRACEABLE_FUNCTIONS:
        superclass = 'Node'
    else:
        superclass = 'TraceableFunction'

    all_getsetdef_structs = ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
    all_getter_definitions = "\n".join(getter_definitions)

    return template.substitute(
        op=info.op,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
        saved_list_sizes=saved_list_sizes,
        asserts=asserts,
        thread_lock=thread_lock,
        will_release_variables=will_release_variables,
        body=body,
        superclass=superclass,
        all_getter_definitions=all_getter_definitions,
        all_getsetdef_structs=all_getsetdef_structs
    )