blob: 32d505cda7bf4eea7f34c1ca6559634d0a4e11e8 [file] [log] [blame]
from typing import Union
from tools.codegen.model import (NativeFunction, NativeFunctionsGroup)
from tools.codegen.api.lazy import LazyIrSchema, isValueType
from tools.codegen.api.types import OptionalCType
def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
# for now, we just want one IR class decl and soon after also the method defs
# and we use the functional version not out/inplace.
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
schema = LazyIrSchema(func)
emplace_arguments = []
for value in schema.positional_arg_types:
if isValueType(value.type):
if isinstance(value.type, OptionalCType):
emplace_arguments.append(f"has_{value.name} ? loctx->GetOutputOp(operand(i++)) : nullptr")
continue
emplace_arguments.append('loctx->GetOutputOp(operand(i++))')
continue
emplace_arguments.append(f'"{value.name}", {value.name}_')
emplace_arguments_str = "\n ".join(
[f"arguments.emplace_back({a});" for a in emplace_arguments])
emplace_kwarg_values = [f'loctx->GetOutputOp(operand({i}))' for i in range(len(schema.keyword_values))]
emplace_kwarg_scalars = [f'"{t.name}", {t.name}_' for t in schema.keyword_scalars]
assert len(schema.keyword_values) == 0, "TODO the logic for operand(i) is broken if there are kw values"
emplace_kwarguments = "\n ".join(
[f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])
return f"""\
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(emplace_arguments)});
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
// TODO: need to call GenerateClone sometimes? Or else return LowerBuiltIn() directly
return {schema.aten_name}_out;
"""