blob: bfe039d6ddc93916c1a17fbed14d013b16b05043 [file] [log] [blame]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from typing import Tuple, List
import re
def num_leading_spaces(line: str) -> int:
    """Count the whitespace characters at the start of ``line``.

    Any character for which ``str.isspace`` is true counts, not just ' '.
    A line made only of whitespace returns its full length.
    """
    first_content = next(
        (pos for pos, ch in enumerate(line) if not ch.isspace()), None)
    return len(line) if first_content is None else first_content
def min_leading_spaces(lines: List[str]):
    """Return the smallest indentation among the lines that have content.

    Whitespace-only lines are ignored when computing the margin, as in
    ``textwrap.dedent``. Their length says nothing about indentation, and
    counting them could shrink the margin removed from real content lines.

    Args:
        lines: the lines to inspect.

    Returns:
        The minimum ``num_leading_spaces`` over lines that contain a
        non-whitespace character, or ``None`` if there are no such lines.
    """
    indents = [num_leading_spaces(line) for line in lines if line.strip()]
    if not indents:
        return None
    return min(indents)
def deindent(code: str) -> str:
    """Strip the common leading indentation from every line of ``code``.

    The margin comes from ``min_leading_spaces``. When it reports no margin
    (``None``), slicing with ``None`` leaves every line unchanged.
    """
    rows = code.split('\n')
    margin = min_leading_spaces(rows)
    return '\n'.join(row[margin:] for row in rows)
def indent(code: str, num) -> str:
    """Prefix every line of ``code`` except the first with ``num`` spaces.

    The first line is left alone so the result can be dropped into a
    position that is already indented.
    """
    first, *rest = code.split('\n')
    pad = ' ' * num
    return '\n'.join([first] + [pad + line for line in rest])
def is_tensor(typ: str) -> bool:
    """Return True if ``typ`` is a by-value or const-reference Tensor type.

    Only the exact spellings ``Tensor`` and ``const Tensor &`` match. A
    mutable ``Tensor &`` or an optional tensor does not.
    """
    return typ in ('Tensor', 'const Tensor &')
def is_optional_tensor(typ: str) -> bool:
    """Return True if ``typ`` is a by-value or const-reference optional Tensor.

    Only ``c10::optional<Tensor>`` and ``const c10::optional<Tensor> &``
    match.
    """
    return typ in ('c10::optional<Tensor>', 'const c10::optional<Tensor> &')
def is_vector_tensor(typ: str) -> bool:
    """Return True if ``typ`` is exactly ``::std::vector<Tensor>``.

    The leading ``::`` is C++'s global-namespace qualifier. Presumably the
    code generator that writes RegistrationDeclarations.h spells the type
    that way (TODO: confirm). The unqualified ``std::vector<Tensor>``
    spelling does not match.
    """
    vector_tensor_spelling = '::std::vector<Tensor>'
    return typ == vector_tensor_spelling
def add_bdim_after_tensor(types: Tuple[str, ...]) -> Tuple[str, ...]:
    """Insert a ``c10::optional<int64_t>`` slot after every tensor-like type.

    A type is tensor-like if ``is_tensor``, ``is_optional_tensor`` or
    ``is_vector_tensor`` accepts it. This gives the flattened type list of a
    batch rule, where each tensor-like value is followed by its batch-dim
    entry.

    Args:
        types: C++ type strings, of any length.

    Returns:
        A new tuple with the extra slots inserted. The input order is kept.
    """
    result = []
    for typ in types:
        result.append(typ)
        if is_tensor(typ) or is_optional_tensor(typ) or is_vector_tensor(typ):
            # Batch-dim slot paired with the tensor-like entry just appended.
            result.append('c10::optional<int64_t>')
    return tuple(result)
def batch_rule_type(
        op_returns: Tuple[str, ...],
        op_args: Tuple[str, ...],
        unique_count: int) -> Tuple[str, str]:
    """Build a C++ function-pointer typedef for an op's batch rule.

    The batch rule returns a ``std::tuple`` of ``op_returns`` and takes
    ``op_args``. In both lists every tensor-like type is followed by a
    ``c10::optional<int64_t>`` slot (see ``add_bdim_after_tensor``).

    Args:
        op_returns: C++ return types of the original op.
        op_args: C++ argument types of the original op.
        unique_count: suffix that makes the typedef name unique.

    Returns:
        ``(typedef_line, typedef_name)``, where typedef_name is
        ``batch_rule_<unique_count>_t``.
    """
    returns = add_bdim_after_tensor(op_returns)
    args = add_bdim_after_tensor(op_args)
    br_t = f'batch_rule_{unique_count}_t'
    result = f"typedef std::tuple<{','.join(returns)}> (*{br_t})({', '.join(args)});"
    return result, br_t
def unwrap_tensor(name: str) -> List[str]:
    """Return the C++ lines that unwrap tensor argument ``name``.

    The generated code declares ``<name>_value`` and ``<name>_bdim`` and
    fills them from ``unwrapTensorAtLevel(<name>, cur_level)``.
    """
    value = f'{name}_value'
    bdim = f'{name}_bdim'
    return [
        f'Tensor {value};',
        f'optional<int64_t> {bdim};',
        f'std::tie({value}, {bdim}) = unwrapTensorAtLevel({name}, cur_level);',
    ]
def unwrap_optional_tensor(name: str) -> List[str]:
    """Return the C++ lines that unwrap optional tensor argument ``name``.

    ``<name>_value`` and ``<name>_bdim`` are declared unconditionally. They
    are filled from ``unwrapTensorAtLevel(<name>.value(), cur_level)`` only
    when the optional holds a value.
    """
    value = f'{name}_value'
    bdim = f'{name}_bdim'
    return [
        f'optional<Tensor> {value};',
        f'optional<int64_t> {bdim};',
        f'if ({name}) {{',
        f'  std::tie({value}, {bdim}) = unwrapTensorAtLevel({name}.value(), cur_level);',
        '}',
    ]
def gen_unwraps(arg_types, arg_names):
    """Generate unwrapping code and the batch-rule call arguments.

    Args:
        arg_types: C++ types of the op's arguments.
        arg_names: matching argument names.

    Returns:
        ``(unwraps, unwrapped_arg_list)``:

        * ``unwraps`` is the C++ unwrap statements as one string. Every line
          after the first is prefixed with 6 spaces, presumably to line up
          with the template slot it is substituted into. All plain Tensor
          arguments are unwrapped first, then all optional Tensor arguments.
        * ``unwrapped_arg_list`` is the argument names in their original
          order. Each Tensor or optional Tensor argument is replaced by the
          pair ``<name>_value``, ``<name>_bdim``.

    NOTE(review): ``::std::vector<Tensor>`` arguments are not unwrapped here,
    although ``add_bdim_after_tensor`` gives them a bdim slot. Confirm that
    such arguments never reach this generator.
    """
    typed_args = list(zip(arg_types, arg_names))
    plain = [n for t, n in typed_args if is_tensor(t)]
    optional = [n for t, n in typed_args if is_optional_tensor(t)]

    statements = [s for n in plain for s in unwrap_tensor(n)]
    statements += [s for n in optional for s in unwrap_optional_tensor(n)]
    unwraps = ('\n' + ' ' * 6).join(statements)

    to_unwrap = set(plain) | set(optional)
    unwrapped_arg_list = []
    for arg in arg_names:
        if arg in to_unwrap:
            unwrapped_arg_list.extend([f'{arg}_value', f'{arg}_bdim'])
        else:
            unwrapped_arg_list.append(arg)
    return unwraps, unwrapped_arg_list
def _wrap_returns(returns: Tuple[str, ...]) -> str:
    """Build the C++ ``return`` statement that rewraps a batch rule's results.

    ``results`` is the flat tuple returned by the batch rule. Entries that
    ``add_bdim_after_tensor`` treats as tensor-like are followed by their
    bdim, so the tuple index advances by 2 for them and by 1 otherwise.

    Raises:
        NotImplementedError: for a ``c10::optional<Tensor>`` return.
    """
    idx = 0
    wrapped = []
    for ret in returns:
        if is_tensor(ret):
            wrapped.append(f'makeBatched(std::get<{idx}>(results), std::get<{idx + 1}>(results), cur_level)')
            idx += 2
        elif is_vector_tensor(ret):
            wrapped.append(f'makeBatchedVector(std::get<{idx}>(results), std::get<{idx + 1}>(results), cur_level)')
            idx += 2
        elif is_optional_tensor(ret):
            # add_bdim_after_tensor reserves a bdim slot after optional tensor
            # returns in the batch-rule typedef. This generator has no
            # rewrapping for them, and treating them as plain values would
            # point every later std::get index at the wrong tuple element.
            raise NotImplementedError(
                f'optional tensor return type {ret!r} is not supported by the plumbing codegen')
        else:
            wrapped.append(f'std::get<{idx}>(results)')
            idx += 1
    if len(wrapped) == 1:
        return f'return {wrapped[0]};'
    return f'return std::make_tuple({", ".join(wrapped)});'


def lower(returns: Tuple[str, ...], args: List[Tuple[str, str]], unique_count: int) -> str:
    """Generate the C++ ``lowerToNextLayer`` specialization for one schema.

    Args:
        returns: C++ return types of the op, as produced by parse_return.
        args: ``(type, name)`` pairs for the op's arguments.
        unique_count: index that gives the batch-rule typedef a unique name.

    Returns:
        C++ source: a batch-rule typedef, then a template specialization.
        The specialization unwraps tensor arguments at the current dynamic
        layer, calls the batch rule, and rewraps tensor results with
        ``makeBatched`` / ``makeBatchedVector``.

    Raises:
        NotImplementedError: if a return type is ``c10::optional<Tensor>``.
    """
    arg_types, arg_names = zip(*args)
    batch_rule_typedef, batch_rule_t = batch_rule_type(returns, arg_types, unique_count)
    return_t = returns[0] if len(returns) == 1 else f'std::tuple<{",".join(returns)}>'
    args_t = ', '.join(arg_types)
    arg_list = ', '.join(f'{typ} {name}' for typ, name in args)
    unwraps, unwrapped_arg_list = gen_unwraps(arg_types, arg_names)
    return_stmt = _wrap_returns(returns)
    # The template is indented uniformly and deindent() removes the common
    # margin. {unwraps} sits at column 6 to match the 6-space join used by
    # gen_unwraps.
    result = f"""\
    {batch_rule_typedef}
    template <>
    {return_t} lowerToNextLayer<{batch_rule_t},{return_t},{args_t}>(
      {batch_rule_t} batch_rule,
      {arg_list}
    ) {{
      c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
      auto maybe_layer = maybeCurrentDynamicLayer();
      TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
      int64_t cur_level = maybe_layer->layerId();
      {unwraps}
      auto results = batch_rule({', '.join(unwrapped_arg_list)});
      {return_stmt}
    }}"""
    return deindent(result)
def _split_top_level_commas(text: str) -> List[str]:
    """Split ``text`` on commas that are not nested inside <>, () or [].

    Unlike ``str.split(',')``, this keeps template types such as
    ``std::tuple<int64_t,int64_t>`` in one piece. Like ``str.split``, an
    empty string yields ``['']``.
    """
    parts = []
    depth = 0
    start = 0
    for i, ch in enumerate(text):
        if ch in '<([':
            depth += 1
        elif ch in '>)]':
            # Clamp so malformed input cannot push depth below zero.
            depth = max(depth - 1, 0)
        elif ch == ',' and depth == 0:
            parts.append(text[start:i])
            start = i + 1
    parts.append(text[start:])
    return parts


def parse_return(return_t):
    """Split a C++ return type into a tuple of its component types.

    A non-tuple type becomes a 1-tuple. ``std::tuple<...>`` or
    ``::std::tuple<...>`` is split on its top-level commas, so nested
    template arguments stay intact.

    Raises:
        ValueError: if ``return_t`` mentions ``std::tuple`` but does not
            start with a ``(::)std::tuple<...>`` form.
    """
    if 'std::tuple' not in return_t:
        return (return_t,)
    m = re.match(r'(?:::)?std::tuple<(.*)>', return_t)
    if m is None:
        raise ValueError(f'unrecognized tuple return type: {return_t!r}')
    return tuple(part.strip() for part in _split_top_level_commas(m.group(1)))


def parse_args(args_t):
    """Parse a C++ parameter list into a tuple of ``(type, name)`` pairs.

    Parameters are split on top-level commas only, so template types that
    contain commas stay whole. Within each parameter, the name is the text
    after the last space and the type is everything before it. An empty
    ``args_t`` yields ``(('', ''),)``, which existing callers can still
    unzip into types and names.
    """
    result = []
    for arg in _split_top_level_commas(args_t):
        split_idx = arg.rfind(' ')
        result.append((arg[:split_idx].strip(), arg[split_idx:].strip()))
    return tuple(result)
def get_signatures(path='build/aten/src/ATen/RegistrationDeclarations.h', include_op=False):
    """Read op signatures from a RegistrationDeclarations.h file.

    Only lines of the form
    ``<ret> <fn>(<args>); // {"schema": "aten::<op>(...`` are used. Lines
    containing the substring ``void`` or ``std::array`` anywhere are
    skipped.

    Args:
        path: file to read.
        include_op: if true, prefix each entry with the aten op name.

    Returns:
        A tuple of ``(returns, args)``, or ``(op, returns, args)`` when
        ``include_op`` is set. ``returns`` comes from parse_return and
        ``args`` from parse_args.
    """
    signature_re = re.compile(r'(.*) \w+\((.*)\); // {"schema": "aten::(\w+\.?\w*)\(.*')
    with open(path, 'r') as f:
        all_lines = f.read().split('\n')
    schemas = []
    for line in all_lines:
        # Substring filters over the whole line, presumably meant to drop
        # void-returning ops and std::array arguments.
        if 'void' in line or 'std::array' in line:
            continue
        match = signature_re.match(line)
        if match is None:
            continue
        return_t, args_t, op = match.groups()
        # TODO: some namedtuple return. Also, testing for codegen
        parsed = (parse_return(return_t), parse_args(args_t))
        schemas.append((op,) + parsed if include_op else parsed)
    return tuple(schemas)
def is_schema_outplace(schema):
    """Decide whether a ``(returns, args)`` schema should get plumbing.

    The schema is rejected when:
    * any argument is ``Tensor &`` or ``TensorList``;
    * no argument passes ``is_tensor``;
    * any return is ``std::vector<Tensor>``, ``const Tensor &`` or
      ``Tensor &``.
    Presumably this keeps only out-of-place ops; mutable-reference types
    likely indicate in-place/out= variants (TODO: confirm).
    """
    returns, args = schema
    if any(typ in ('Tensor &', 'TensorList') for typ, _ in args):
        return False
    arg_types, _ = zip(*args)
    if not any(is_tensor(typ) for typ in arg_types):
        return False
    rejected_returns = ('std::vector<Tensor>', 'const Tensor &', 'Tensor &')
    return not any(ret in rejected_returns for ret in returns)
def get_hash(schema):
    """Return a dedup key for a schema: its returns plus its argument types.

    Argument names are dropped, so schemas that differ only in parameter
    names produce the same key.
    """
    returns, args = schema
    arg_types, _arg_names = zip(*args)
    return returns, arg_types
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional. Without it argparse never
    # applies `default` and always requires the path on the command line.
    parser.add_argument('path',
                        nargs='?',
                        default='build/aten/src/ATen/RegistrationDeclarations.h',
                        help='path to RegistrationDeclarations.h (default: %(default)s)')
    args = parser.parse_args()
    schemas = get_signatures(args.path)
    schemas = [schema for schema in schemas if is_schema_outplace(schema)]
    # Schemas with identical return/argument types share one lowering. For
    # each key the last schema seen is kept; dict order follows first
    # insertion.
    unique_schemas = {}
    for schema in schemas:
        unique_schemas[get_hash(schema)] = schema
    schemas = list(unique_schemas.values())
    codes = [lower(*schema, i) for i, schema in enumerate(schemas)]
    print("#include <functorch/csrc/OutOfPlacePlumbing.h>")
    print("#include <functorch/csrc/PlumbingHelper.h>")
    print("#include <functorch/csrc/Constants.h>")
    print("")
    print("namespace at { namespace functorch {")
    for code in codes:
        print(code)
        print('')
    print("}}")