blob: bbf1a510a8f416ed9802023afb0bae88cf8419f5 [file] [log] [blame]
import argparse
from codegen_outofplacebatching import deindent, get_signatures, gen_unwraps
def get_signature(op, path):
    """Return the single signature for ``op`` found in ``path``.

    Args:
        op: Operator name including its overload, e.g. ``"add.Tensor"``.
        path: Path to a RegistrationDeclarations.h file, parsed via
            ``get_signatures(path, include_op=True)``.

    Returns:
        The one signature whose first element equals ``op``.

    Raises:
        ValueError: if no signature, or more than one, matches ``op``.
    """
    signatures = get_signatures(path, include_op=True)
    result = [sig for sig in signatures if sig[0] == op]
    if len(result) != 1:
        # Previously this raised with an empty message, leaving the CLI user
        # no hint about which op failed or why.
        raise ValueError(
            f"Expected exactly one signature for '{op}' in {path}, "
            f"found {len(result)}")
    return result[0]
def gen_return_sig(return_t):
    """Build the C++ return type for an op's list of return types.

    A single return type is used verbatim; multiple return types are
    wrapped in ``std::tuple<...>``.

    Args:
        return_t: Sequence of C++ type strings, e.g. ``['Tensor', 'Tensor']``.

    Returns:
        The C++ type string, e.g. ``'std::tuple<Tensor, Tensor>'``.
    """
    if len(return_t) == 1:
        return return_t[0]
    # C++ template arguments are comma-separated; joining with '.' (as before)
    # produced invalid code such as std::tuple<Tensor.Tensor>.
    return f'std::tuple<{", ".join(return_t)}>'
def gen_args_sig(args_t):
    """Render ``(type, name)`` pairs as a C++ parameter list.

    Example: ``[('Tensor', 'self'), ('int64_t', 'dim')]`` becomes
    ``'Tensor self, int64_t dim'``.
    """
    pieces = []
    for arg_type, arg_name in args_t:
        pieces.append(f'{arg_type} {arg_name}')
    return ', '.join(pieces)
def gen_args_list(args_t):
    """Render only the names from ``(type, name)`` pairs as a C++ argument list.

    Example: ``[('Tensor', 'self'), ('Scalar', 'alpha')]`` becomes
    ``'self, alpha'``.
    """
    return ', '.join(f'{name}' for _, name in args_t)
def gen_plumbing(signature):
    """Generate the C++ batch-rule plumbing function for one operator.

    Args:
        signature: ``(op, return_t, args_t)`` where ``op`` is the operator
            name with an optional overload (e.g. ``"add.Tensor"``),
            ``return_t`` is a sequence of C++ return type strings, and
            ``args_t`` is a sequence of ``(type, name)`` pairs.

    Returns:
        The generated C++ source for ``<opname>_plumbing``, passed through
        ``deindent``.

    Raises:
        ValueError: if ``op`` contains more than one ``.``, or if the base
            op name ends in ``_`` (treated as in-place, which is unsupported).
    """
    op, return_t, args_t = signature

    # Split e.g. "add.Tensor" into base op "add" and overload "Tensor".
    maybe_op_and_variant = op.split('.')
    if len(maybe_op_and_variant) == 1:
        op = maybe_op_and_variant[0]
        variant = ''
        opname = op
    elif len(maybe_op_and_variant) == 2:
        op, variant = maybe_op_and_variant
        opname = f'{op}_{variant}'
    else:
        raise ValueError(
            f"Expected '<op>' or '<op>.<overload>', got '{op}'")

    if op.endswith('_'):
        raise ValueError('Codegen doesn\'t handle in-place ops')

    # `arg_types, arg_names = zip(*args_t)` cannot unpack when args_t is
    # empty (zero-argument ops), so build the two columns explicitly.
    arg_types = tuple(typ for typ, _ in args_t)
    arg_names = tuple(name for _, name in args_t)
    unwraps, _ = gen_unwraps(arg_types, arg_names)

    # NOTE(review): `unwraps` is interpolated at the template's base indent;
    # whether its own lines line up depends on gen_unwraps/deindent, which
    # are defined elsewhere — confirm against their output format.
    result = deindent(f"""\
    {gen_return_sig(return_t)} {opname}_plumbing({gen_args_sig(args_t)}) {{
      auto maybe_layer = maybeCurrentDynamicLayer();
      TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
      int64_t cur_level = maybe_layer->layerId();
    {unwraps}
      // Your logic here
      static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{op}", "{variant}");
      return slow_fallback<{','.join(return_t)}>(op, {{ {gen_args_list(args_t)} }});
    }}
    """)
    return result
if __name__ == '__main__':
    # Example invocation:
    #   gen_plumbing.py add.Tensor ~/pytorch/build/aten/src/ATen/RegistrationDeclarations.h
    cli = argparse.ArgumentParser(
        description='Generate the batch rule plumbing for an op')
    cli.add_argument('op', help='the operator name (with overload name)')
    cli.add_argument('path', help='link to RegistrationDeclarations.h')
    cli_args = cli.parse_args()
    # Look up the op's signature, render its plumbing, and emit it on stdout.
    print(gen_plumbing(get_signature(cli_args.op, cli_args.path)))