| ## @package net_printer | 
# Module caffe2.python.net_printer
 |  | 
 | from caffe2.proto.caffe2_pb2 import OperatorDef, NetDef | 
 | from caffe2.python.checkpoint import Job | 
 | from caffe2.python.core import Net, ExecutionStep, Plan | 
 | from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput | 
 | from collections import defaultdict | 
 | from contextlib import contextmanager | 
 | from copy import copy | 
 | from itertools import chain | 
 |  | 
 |  | 
 | class Visitor: | 
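    """
    Type-based dispatcher: subclasses register one handler per visited type
    via register() and invoke the visitor instance as a callable.

    Illustrative sketch (CountOps and an_op are hypothetical, not part of
    this module):

        class CountOps(Visitor):
            pass

        @CountOps.register(OperatorDef)
        def count_op(visitor, op):
            print('visiting op of type', op.type)

        CountOps()(an_op)  # dispatches to count_op
    """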
 |     @classmethod | 
 |     def register(cls, Type): | 
        if not hasattr(cls, 'visitors'):
 |             cls.visitors = {} | 
 |         else: | 
 |             assert Type not in cls.visitors, \ | 
 |                 '{} already registered!'.format(Type) | 
 |  | 
 |         def _register(func): | 
 |             cls.visitors[Type] = func | 
 |             return func | 
 |  | 
 |         return _register | 
 |  | 
 |     def __call__(self, obj, *args, **kwargs): | 
 |         if obj is None: | 
 |             return | 
 |  | 
 |         Type = type(obj) | 
 |         if Type not in self.__class__.visitors: | 
 |             raise TypeError('%s: unsupported object type: %s' % ( | 
 |                 self.__class__.__name__, Type)) | 
 |  | 
 |         func = self.__class__.visitors[Type] | 
 |         return func(self, obj, *args, **kwargs) | 
 |  | 
 |  | 
 | class Analyzer(Visitor): | 
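    """
    Visitor that tracks, per workspace, which blobs have been defined,
    asserting that every blob is defined before it is used and that no blob
    is defined by more than one concurrent substep.
    """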
 |     PREFIXES_TO_IGNORE = {'distributed_ctx_init'} | 
 |  | 
 |     def __init__(self): | 
 |         self.workspaces = defaultdict(lambda: defaultdict(lambda: 0)) | 
 |         self.workspace_ctx = [] | 
 |  | 
 |     @property | 
 |     def workspace(self): | 
 |         return self.workspace_ctx[-1] | 
 |  | 
 |     @contextmanager | 
 |     def set_workspace(self, node=None, ws=None, do_copy=False): | 
        if ws is None:
            ws = (self.workspaces[str(node)] if node is not None
                  else self.workspace)
 |         if do_copy: | 
 |             ws = copy(ws) | 
 |         self.workspace_ctx.append(ws) | 
 |         try: | 
 |             yield ws | 
 |         finally: | 
 |             del self.workspace_ctx[-1] | 
 |  | 
 |     def define_blob(self, blob): | 
 |         self.workspace[blob] += 1 | 
 |  | 
 |     def need_blob(self, blob): | 
 |         if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE): | 
 |             return | 
 |         assert blob in self.workspace, 'Blob undefined: %s' % blob | 
 |  | 
 |  | 
 | @Analyzer.register(OperatorDef) | 
 | def analyze_op(analyzer, op): | 
 |     for x in op.input: | 
 |         analyzer.need_blob(x) | 
 |     for x in op.output: | 
 |         analyzer.define_blob(x) | 
 |  | 
 |  | 
 | @Analyzer.register(Net) | 
 | def analyze_net(analyzer, net): | 
 |     for x in net.Proto().op: | 
 |         analyzer(x) | 
 |  | 
 |  | 
 | @Analyzer.register(ExecutionStep) | 
 | def analyze_step(analyzer, step): | 
 |     proto = step.Proto() | 
 |     with analyzer.set_workspace(do_copy=proto.create_workspace): | 
 |         if proto.report_net: | 
 |             with analyzer.set_workspace(do_copy=True): | 
 |                 analyzer(step.get_net(proto.report_net)) | 
 |         all_new_blobs = set() | 
 |         substeps = step.Substeps() + [step.get_net(n) for n in proto.network] | 
 |         for substep in substeps: | 
 |             with analyzer.set_workspace( | 
 |                     do_copy=proto.concurrent_substeps) as ws_in: | 
 |                 analyzer(substep) | 
 |                 if proto.should_stop_blob: | 
 |                     analyzer.need_blob(proto.should_stop_blob) | 
 |             if proto.concurrent_substeps: | 
 |                 new_blobs = set(ws_in.keys()) - set(analyzer.workspace.keys()) | 
 |                 assert len(all_new_blobs & new_blobs) == 0, ( | 
 |                     'Error: Blobs created by multiple parallel steps: %s' % ( | 
 |                         ', '.join(all_new_blobs & new_blobs))) | 
 |                 all_new_blobs |= new_blobs | 
 |     for x in all_new_blobs: | 
 |         analyzer.define_blob(x) | 
 |  | 
 |  | 
 | @Analyzer.register(Task) | 
 | def analyze_task(analyzer, task): | 
    # check that our plan protobuf is not too large (protobuf limit of 64 MB)
 |     step = task.get_step() | 
 |     plan = Plan(task.node) | 
 |     plan.AddStep(step) | 
 |     proto_len = len(plan.Proto().SerializeToString()) | 
    assert proto_len < 2 ** 26, (
        'Due to a protobuf limitation, serialized tasks must be smaller '
        'than 64 MB, but this task has {} bytes.'.format(proto_len))
 |  | 
 |     is_private = task.workspace_type() != WorkspaceType.GLOBAL | 
 |     with analyzer.set_workspace(do_copy=is_private): | 
 |         analyzer(step) | 
 |  | 
 |  | 
 | @Analyzer.register(TaskGroup) | 
 | def analyze_task_group(analyzer, tg): | 
 |     for task in tg.tasks_by_node().tasks(): | 
 |         with analyzer.set_workspace(node=task.node): | 
 |             analyzer(task) | 
 |  | 
 |  | 
 | @Analyzer.register(Job) | 
 | def analyze_job(analyzer, job): | 
 |     analyzer(job.init_group) | 
 |     analyzer(job.epoch_group) | 
 |  | 
 |  | 
 | def analyze(obj): | 
 |     """ | 
 |     Given a Job, visits all the execution steps making sure that: | 
 |       - no undefined blobs will be found during execution | 
 |       - no blob with same name is defined in concurrent steps | 
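
    Illustrative usage (assuming job is an already-built Job):
        analyze(job)  # raises AssertionError on undefined or conflicting blobs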
 |     """ | 
 |     Analyzer()(obj) | 
 |  | 
 |  | 
 | class Text: | 
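    """
    Accumulates indented pseudo-code lines; context() opens a nested
    'with ...:' block and restores the previous indentation on exit.

    Illustrative sketch:

        t = Text()
        with t.context('loop(10)'):
            t.add('DoSomething()')
        print(t)
        # with loop(10):
        #     DoSomething()
    """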
 |     def __init__(self): | 
 |         self._indent = 0 | 
 |         self._lines_in_context = [0] | 
 |         self.lines = [] | 
 |  | 
 |     @contextmanager | 
 |     def context(self, text): | 
 |         if text is not None: | 
 |             self.add('with %s:' % text) | 
 |             self._indent += 4 | 
 |             self._lines_in_context.append(0) | 
 |         try: | 
 |             yield | 
 |         finally: | 
 |             if text is not None: | 
 |                 if self._lines_in_context[-1] == 0: | 
 |                     self.add('pass') | 
 |                 self._indent -= 4 | 
 |                 del self._lines_in_context[-1] | 
 |  | 
 |     def add(self, text): | 
 |         self._lines_in_context[-1] += 1 | 
 |         self.lines.append((' ' * self._indent) + text) | 
 |  | 
 |     def __str__(self): | 
 |         return '\n'.join(self.lines) | 
 |  | 
 |  | 
 | class Printer(Visitor, Text): | 
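    """
    Visitor that renders Nets, ExecutionSteps, Tasks, TaskGroups and Jobs as
    readable pseudo-code. With c2_syntax=True operators are printed as
    net_name.OpType(inputs, outputs, args); otherwise as
    outputs = OpType(inputs, args).
    """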
 |     def __init__(self, factor_prefixes=False, c2_syntax=True): | 
        super().__init__()
 |         self.factor_prefixes = factor_prefixes | 
 |         self.c2_syntax = c2_syntax | 
 |         self.c2_net_name = None | 
 |  | 
 |  | 
 | def _sanitize_str(s): | 
 |     if isinstance(s, str): | 
 |         sanitized = s | 
 |     elif isinstance(s, bytes): | 
 |         sanitized = s.decode('ascii', errors='ignore') | 
 |     else: | 
 |         sanitized = str(s) | 
 |     if len(sanitized) < 64: | 
 |         return "'%s'" % sanitized | 
 |     else: | 
 |         return "'%s'" % sanitized[:64] + '...<+len=%d>' % (len(sanitized) - 64) | 
 |  | 
 |  | 
 | def _arg_val(arg): | 
 |     if arg.HasField('f'): | 
 |         return str(arg.f) | 
 |     if arg.HasField('i'): | 
 |         return str(arg.i) | 
 |     if arg.HasField('s'): | 
 |         return _sanitize_str(arg.s) | 
 |     if arg.floats: | 
 |         return str(list(arg.floats)) | 
 |     if arg.ints: | 
 |         return str(list(arg.ints)) | 
 |     if arg.strings: | 
 |         return str([_sanitize_str(s) for s in arg.strings]) | 
 |     return '[]' | 
 |  | 
 |  | 
 | def commonprefix(m): | 
 |     "Given a list of strings, returns the longest common prefix" | 
 |     if not m: | 
 |         return '' | 
 |     s1 = min(m) | 
 |     s2 = max(m) | 
 |     for i, c in enumerate(s1): | 
 |         if c != s2[i]: | 
 |             return s1[:i] | 
 |     return s1 | 
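
# e.g. commonprefix(['model/fc1_w', 'model/fc1_b']) == 'model/fc1_'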
 |  | 
 |  | 
 | def format_value(val): | 
 |     if isinstance(val, list): | 
 |         return '[%s]' % ', '.join("'%s'" % str(v) for v in val) | 
 |     else: | 
 |         return str(val) | 
 |  | 
 |  | 
 | def factor_prefix(vals, do_it): | 
 |     vals = [format_value(v) for v in vals] | 
 |     prefix = commonprefix(vals) if len(vals) > 1 and do_it else '' | 
 |     joined = ', '.join(v[len(prefix):] for v in vals) | 
 |     return '%s[%s]' % (prefix, joined) if prefix else joined | 
 |  | 
 |  | 
 | def call(op, inputs=None, outputs=None, factor_prefixes=False): | 
 |     if not inputs: | 
 |         inputs = '' | 
 |     else: | 
 |         inputs_v = [a for a in inputs if not isinstance(a, tuple)] | 
 |         inputs_kv = [a for a in inputs if isinstance(a, tuple)] | 
 |         inputs = ', '.join( | 
 |             x | 
 |             for x in chain( | 
 |                 [factor_prefix(inputs_v, factor_prefixes)], | 
 |                 ('%s=%s' % kv for kv in inputs_kv), | 
 |             ) | 
 |             if x | 
 |         ) | 
 |     call = '%s(%s)' % (op, inputs) | 
 |     return call if not outputs else '%s = %s' % ( | 
 |         factor_prefix(outputs, factor_prefixes), call) | 
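
# For example, call('FC', ['data', 'w', 'b'], ['out']) returns
# "out = FC(data, w, b)"; with factor_prefixes=True, shared blob-name prefixes
# are factored, e.g. factor_prefix(['model/w', 'model/b'], True) == 'model/[w, b]'.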
 |  | 
 |  | 
 | def format_device_option(dev_opt): | 
 |     if not dev_opt or not ( | 
 |             dev_opt.device_type or dev_opt.device_id or dev_opt.node_name): | 
 |         return None | 
 |     return call( | 
 |         'DeviceOption', | 
 |         [dev_opt.device_type, dev_opt.device_id, "'%s'" % dev_opt.node_name]) | 
 |  | 
 |  | 
 | @Printer.register(OperatorDef) | 
 | def print_op(text, op): | 
 |     args = [(a.name, _arg_val(a)) for a in op.arg] | 
 |     dev_opt_txt = format_device_option(op.device_option) | 
 |     if dev_opt_txt: | 
 |         args.append(('device_option', dev_opt_txt)) | 
 |  | 
 |     if text.c2_net_name: | 
 |         text.add(call( | 
 |             text.c2_net_name + '.' + op.type, | 
 |             [list(op.input), list(op.output)] + args)) | 
 |     else: | 
 |         text.add(call( | 
 |             op.type, | 
 |             list(op.input) + args, | 
 |             op.output, | 
 |             factor_prefixes=text.factor_prefixes)) | 
 |     for arg in op.arg: | 
 |         if arg.HasField('n'): | 
 |             with text.context('arg: %s' % arg.name): | 
 |                 text(arg.n) | 
 |  | 
 |  | 
 | @Printer.register(NetDef) | 
 | def print_net_def(text, net_def): | 
 |     if text.c2_syntax: | 
 |         text.add(call('core.Net', ["'%s'" % net_def.name], [net_def.name])) | 
 |         text.c2_net_name = net_def.name | 
 |     else: | 
 |         text.add('# net: %s' % net_def.name) | 
 |     for op in net_def.op: | 
 |         text(op) | 
 |     if text.c2_syntax: | 
 |         text.c2_net_name = None | 
 |  | 
 |  | 
 | @Printer.register(Net) | 
 | def print_net(text, net): | 
 |     text(net.Proto()) | 
 |  | 
 |  | 
 | def _get_step_context(step): | 
 |     proto = step.Proto() | 
 |     if proto.should_stop_blob: | 
 |         return call('loop'), False | 
 |     if proto.num_iter and proto.num_iter != 1: | 
 |         return call('loop', [proto.num_iter]), False | 
 |     if proto.num_concurrent_instances > 1: | 
 |         return ( | 
 |             call('parallel', | 
 |                  [('num_instances', proto.num_concurrent_instances)]), | 
 |             len(step.Substeps()) > 1) | 
 |     concurrent = proto.concurrent_substeps and len(step.Substeps()) > 1 | 
 |     if concurrent: | 
 |         return call('parallel'), True | 
 |     if proto.report_net: | 
 |         return call('run_once'), False | 
 |     return None, False | 
 |  | 
 |  | 
 | @Printer.register(ExecutionStep) | 
 | def print_step(text, step): | 
 |     proto = step.Proto() | 
 |     step_ctx, do_substep = _get_step_context(step) | 
 |     with text.context(step_ctx): | 
 |         if proto.report_net: | 
 |             with text.context(call('report_net', [proto.report_interval])): | 
 |                 text(step.get_net(proto.report_net)) | 
 |         substeps = step.Substeps() + [step.get_net(n) for n in proto.network] | 
 |         for substep in substeps: | 
 |             sub_proto = ( | 
 |                 substep.Proto() if isinstance(substep, ExecutionStep) else None) | 
 |             if sub_proto is not None and sub_proto.run_every_ms: | 
 |                 substep_ctx = call( | 
 |                     'reporter', | 
 |                     [str(substep), ('interval_ms', sub_proto.run_every_ms)]) | 
 |             elif do_substep: | 
 |                 title = ( | 
 |                     'workspace' | 
 |                     if sub_proto is not None and sub_proto.create_workspace else | 
 |                     'step') | 
 |                 substep_ctx = call(title, [str(substep)]) | 
 |             else: | 
 |                 substep_ctx = None | 
 |             with text.context(substep_ctx): | 
 |                 text(substep) | 
 |                 if proto.should_stop_blob: | 
 |                     text.add(call('yield stop_if', [proto.should_stop_blob])) | 
 |  | 
 |  | 
 | def _print_task_output(x): | 
 |     assert isinstance(x, TaskOutput) | 
    return 'Output[' + ', '.join(str(name) for name in x.names) + ']'
 |  | 
 |  | 
 | @Printer.register(Task) | 
 | def print_task(text, task): | 
 |     outs = ', '.join(_print_task_output(o) for o in task.outputs()) | 
 |     context = [('node', task.node), ('name', task.name), ('outputs', outs)] | 
 |     with text.context(call('Task', context)): | 
 |         text(task.get_step()) | 
 |  | 
 |  | 
 | @Printer.register(TaskGroup) | 
 | def print_task_group(text, tg, header=None): | 
 |     with text.context(header or call('TaskGroup')): | 
 |         for task in tg.tasks_by_node().tasks(): | 
 |             text(task) | 
 |  | 
 |  | 
 | @Printer.register(Job) | 
 | def print_job(text, job): | 
 |     text(job.init_group, 'Job.current().init_group') | 
 |     text(job.epoch_group, 'Job.current().epoch_group') | 
 |     with text.context('Job.current().stop_conditions'): | 
 |         for out in job.stop_conditions: | 
 |             text.add(_print_task_output(out)) | 
 |     text(job.download_group, 'Job.current().download_group') | 
 |     text(job.exit_group, 'Job.current().exit_group') | 
 |  | 
 |  | 
 | def to_string(obj, **kwargs): | 
 |     """ | 
 |     Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string | 
 |     with detailed description of the execution steps. | 
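
    Illustrative usage (assuming net is an existing core.Net):
        print(to_string(net, factor_prefixes=True))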
 |     """ | 
 |     printer = Printer(**kwargs) | 
 |     printer(obj) | 
 |     return str(printer) | 
 |  | 
 |  | 
 | def debug_net(net): | 
 |     """ | 
 |     Given a Net, produce another net that logs info about the operator call | 
 |     before each operator execution. Use for debugging purposes. | 
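
    Illustrative usage (assuming net is an existing core.Net):
        logged_net = debug_net(net)  # run logged_net in place of net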
 |     """ | 
    assert isinstance(net, Net)
    new_net = Net(str(net))
    for op in net.Proto().op:
        # Printer (a Text subclass) provides the c2_net_name/factor_prefixes
        # attributes that print_op expects; note print_op takes (text, op).
        text = Printer()
        print_op(text, op)
        new_net.LogInfo(str(text))
        new_net.Proto().op.extend([op])
    return new_net