# blob: 673e4125a686d7006a11ba40de824efa5f1c0154 [file] [log] [blame]
"""
This util takes the op dependency graph of ATen and the list of root ops, and
outputs all transitive dependencies of the root ops. It is invoked from cmake
for custom mobile build.
"""
import argparse
import yaml
from collections import defaultdict
def canonical_name(opname):
    """Return the operator name with any overload suffix removed.

    Everything from the first '.' onward is dropped, because overload names
    are not supported by the code analyzer yet.
    """
    base, _, _ = opname.partition('.')
    return base
def load_op_dep_graph(fname):
    """Load the op dependency graph from a YAML file.

    The document is read as a sequence of mappings, each with a 'name' key
    and an optional 'depends' key holding a sequence of mappings that carry
    their own 'name'. Every name is reduced with canonical_name(), so all
    overloads of one op are merged into a single entry.

    Args:
        fname: path of the YAML file to read.

    Returns:
        A defaultdict(set) mapping each op name to the set of op names it
        depends on. Ops that list no dependencies get no entry.
    """
    result = defaultdict(set)
    with open(fname, 'r') as stream:
        # safe_load() yields None for an empty document; treat it as no ops
        # instead of failing on iteration.
        ops = yaml.safe_load(stream) or []
    for op in ops:
        op_name = canonical_name(op['name'])
        # A 'depends' key that is present but has no value loads as None.
        for dep in op.get('depends') or []:
            result[op_name].add(canonical_name(dep['name']))
    return result
def load_root_ops(fname):
    """Load the list of root (directly used) operators from a YAML file.

    The document is read as a sequence of operator names; each one is
    reduced with canonical_name(). Order and duplicates are kept as they
    appear in the file.

    Args:
        fname: path of the YAML file to read.

    Returns:
        A list of canonical operator names; empty when the file is empty.
    """
    with open(fname, 'r') as stream:
        # safe_load() yields None for an empty document; treat it as no ops
        # instead of failing on iteration.
        ops = yaml.safe_load(stream) or []
    return [canonical_name(op) for op in ops]
def gen_transitive_closure(dep_graph, root_ops):
    """Compute every op reachable from the root ops in the dependency graph.

    Args:
        dep_graph: mapping from an op name to an iterable of the op names it
            depends on. Ops missing from the mapping are treated as having
            no dependencies.
        root_ops: iterable of root op names (list, tuple, set, generator...).
            It is not modified.

    Returns:
        A single string holding the root ops and all their transitive
        dependencies, de-duplicated, sorted and separated by spaces.
    """
    # Materialise once so one-shot iterables and containers without
    # slicing/pop() support work, and the caller's sequence is left intact.
    pending = list(root_ops)
    reachable = set(pending)
    while pending:
        cur = pending.pop()
        # .get() rather than indexing so a defaultdict graph is not mutated.
        for dep in dep_graph.get(cur, ()):
            if dep not in reachable:
                reachable.add(dep)
                pending.append(dep)
    return ' '.join(sorted(reachable))
if __name__ == "__main__":
    # Command-line entry point: read the dependency graph and the root op
    # list, then print the transitive closure as one space-separated line.
    parser = argparse.ArgumentParser(
        description='Util to produce transitive dependencies for custom build')
    # Both inputs are required: each path is handed straight to open(), so a
    # missing option would otherwise fail with a TypeError traceback instead
    # of a usage message.
    parser.add_argument(
        '--op-dependency',
        required=True,
        help='input yaml file of op dependency graph')
    parser.add_argument(
        '--root-ops',
        required=True,
        help='input yaml file of root (directly used) operators')
    args = parser.parse_args()

    deps = load_op_dep_graph(args.op_dependency)
    root_ops = load_root_ops(args.root_ops)
    print(gen_transitive_closure(deps, root_ops))