|  | ## @package session | 
|  | # Module caffe2.python.session | 
|  | from __future__ import absolute_import | 
|  | from __future__ import division | 
|  | from __future__ import print_function | 
|  | from __future__ import unicode_literals | 
|  |  | 
|  |  | 
|  | from caffe2.python import core, workspace | 
|  | from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType | 
|  |  | 
|  |  | 
class CompiledRunnable(object):
    """Pairs the output of Session.compile() with the session class that made it.

    Attributes:
        obj: the compiled payload, as produced by the owning session class's
            _compile_task_group().
        session_class: the Session subclass that produced `obj`; compile()
            checks it so a runnable compiled for one session type is not
            handed to another.
    """

    def __init__(self, obj, session_class):
        # Tuple assignment binds left to right: obj first, then session_class.
        self.obj, self.session_class = obj, session_class
|  |  | 
|  |  | 
class Session(object):
    """
    Allows running Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
    A session can potentially run in multiple nodes concurrently.


    Example:
        from caffe2.python import core
        from caffe2.python.session import LocalSession
        from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType

        net = core.Net('test1')
        net.Add([net.Const(1), net.Const(2)])

        net2 = net.Clone('test2')
        step = core.execution_step('step1', [net2])

        with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
            with Node('node1'):
                n1setup = core.Net('n1setup')
                n1msg = n1setup.Const('Hello from node 1.')
                Task(step=n1setup)

        with TaskGroup() as private_tg:
            with Node('node1'):
                n1 = core.Net('n1')
                n1.Print(n1msg, 0)
                Task(step=n1)
            with Node('node2'):
                n2 = core.Net('n2')
                n2.Print(n2.Const('Hello from node 2.'), 0)
                Task(step=n2)

        session = LocalSession()
        session.run(net)
        session.run(step)
        session.run(init_tg)
        session.run(private_tg)


    Global Workspace:
        At the beginning of the session, a global workspace is created and
        kept alive for the duration of the session.


    Private Workspace:
        Tasks can be run either directly on the global workspace, or they can
        instantiate a private child workspace that is released after each run.

    Blob visibility:
        Tasks running in different nodes in parallel will always run under
        different workspaces, so it must be assumed that they won't be able to
        access each other's blobs. Tasks running on the same node will follow
        Workspace hierarchy rules: tasks running on separate private workspaces
        will only be able to share blobs defined on a common parent Workspace.
    """

    # Compiled runnables keyed by (session class, runnable). This dict is a
    # class attribute of the base class, so it is shared by every subclass;
    # the session class must be part of the key because each subclass
    # compiles to its own representation (see _compile_task_group).
    _compiled_cache = {}

    def __init__(self):
        self._open = True

    def is_open(self):
        """Return True until close() has been called."""
        return self._open

    @classmethod
    def compile(cls, runnable, workspace_type=None, setup_net_list=None):
        """Compile `runnable` into the form executed by this session class.

        Args:
            runnable: a CompiledRunnable, TaskGroup, Task, ExecutionStep,
                Plan, or anything core.execution_step accepts (e.g. a Net).
            workspace_type: optional WorkspaceType value. For a TaskGroup it
                is checked against, or assigned to, the group's own workspace
                type; for every other runnable it defaults to
                WorkspaceType.GLOBAL.
            setup_net_list: forwarded unchanged to cls._compile_task_group.

        Returns:
            A CompiledRunnable whose `obj` is cls._compile_task_group(...).
            Results are cached per (cls, runnable); on a cache hit
            `workspace_type` and `setup_net_list` are not consulted.

        Raises:
            AssertionError: if `runnable` is a CompiledRunnable built for a
                different session class, if a TaskGroup already has a
                different workspace type, or if a Plan has no steps.
        """
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        # Include the session class so one subclass never receives a result
        # compiled by (and shaped for) another subclass.
        cache_key = (cls, runnable)
        if cache_key in cls._compiled_cache:
            return cls._compiled_cache[cache_key]

        if isinstance(runnable, TaskGroup):
            if workspace_type:
                if runnable.workspace_type():
                    assert runnable.workspace_type() == workspace_type, \
                        "Require {} but already have {}".format(
                            workspace_type, runnable.workspace_type())
                else:
                    runnable._workspace_type = workspace_type
            tg = runnable
        else:
            if workspace_type is None:
                workspace_type = WorkspaceType.GLOBAL
            tg = TaskGroup(workspace_type=workspace_type)
            if isinstance(runnable, Task):
                tg.add(runnable)
            elif isinstance(runnable, core.ExecutionStep):
                tg.add(Task(step=runnable))
            elif isinstance(runnable, core.Plan):
                # ExecutionSteps in a Plan run sequentially, while tasks in a
                # TaskGroup run in parallel, so several steps are wrapped in a
                # single root ExecutionStep to keep them sequential.
                steps = runnable.Steps()
                assert len(steps) > 0
                if len(steps) == 1:
                    tg.add(Task(step=steps[0]))
                else:
                    # Task accepts a list of ExecutionSteps and wraps them
                    # into a root ExecutionStep itself.
                    tg.add(Task(step=steps))
            else:
                step = core.execution_step('runnable', runnable)
                tg.add(Task(step=step))
        compiled = CompiledRunnable(
            cls._compile_task_group(tg, setup_net_list), session_class=cls)
        cls._compiled_cache[cache_key] = compiled
        return compiled

    def run(self, runnable, workspace_type=None, setup_net_list=None):
        """Run the given runnable.

        Args:
            runnable: Object recognized by the Session. Currently, we support
                TaskGroup, Task, Plan, ExecutionStep, and Net.
            workspace_type: A string defined in the WorkspaceType object.
            setup_net_list: A list of Net objects or a list of NetDef protos.
                So far this is only used by the DistributedSession, in which we
                need to pass a list of special nets to setup the master.

        Raises:
            AssertionError: if the session is closed or `runnable` is None.
        """
        assert self.is_open(), 'Session is closed.'
        assert runnable is not None, 'Got a none runnable.'
        self._run_compiled(self.compile(runnable, workspace_type,
                                        setup_net_list).obj)

    def close(self):
        """Close the session; calling it again is a no-op."""
        if self.is_open():
            self._do_close()
        self._open = False

    def fetch_output(self, output):
        """Abstract: subclasses must override."""
        raise NotImplementedError()

    def _run_compiled(self, task_group):
        """Execute the `obj` of a CompiledRunnable. Subclasses must override."""
        raise NotImplementedError()

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Default compilation: return the TaskGroup unchanged."""
        return task_group

    def _do_close(self):
        """Hook invoked once by close(); subclasses release resources here."""
        pass

    def __enter__(self):
        assert self._open, 'Session already closed.'
        return self

    def __exit__(self, ex_type, value, traceback):
        # The session is closed only on a clean exit; if the with-body
        # raised, the session is left open and the exception propagates.
        if ex_type is None:
            self.close()
|  |  | 
|  |  | 
class LocalSession(Session):
    """
    Session that runs in a single node.
    Tasks are all remapped to run in parallel in the 'local' node.

    Currently, LocalSession runs all parallel tasks in the same workspace,
    but this behavior may change in the future. Only tasks pointing to the
    same logical node are guaranteed to always run in the same workspace.
    """

    def __init__(self, ws=None):
        """Create a session that runs in `ws`.

        Args:
            ws: workspace to run in; when None,
                workspace.C.Workspace.current is used.
        """
        Session.__init__(self)
        # Explicit None check instead of `ws or ...`: a supplied workspace
        # must never be replaced just because the object evaluates as falsy
        # (e.g. if the binding defines __len__/__bool__).
        self._ws = ws if ws is not None else workspace.C.Workspace.current

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Flatten `task_group` into a single-step Plan.

        `setup_net_list` is accepted for interface compatibility and is not
        used here.

        Returns:
            (plan, output_list, workspace_type): a Plan named
            'task_group_plan' holding the step of task_group.to_task(), that
            task's output list, and its workspace type.
        """
        # to_task() runs inside a newly constructed Cluster() context.
        with Cluster():
            task = task_group.to_task()
        plan = core.Plan('task_group_plan')
        plan.AddStep(task.get_step())
        return (plan, task.output_list(), task.workspace_type)

    def _run_compiled(self, compiled):
        """Run a tuple produced by _compile_task_group in this session."""
        plan, output_list, workspace_type = compiled

        # make sure the output blobs belong to the parent workspace
        outputs = []
        for name in output_list.names():
            blob_name = str(name)
            self._ws.create_blob(blob_name)
            outputs.append(core.BlobReference(blob_name))
        output_list.set_values(outputs, _fetch_func=self._fetch_output)
        # PRIVATE tasks run in a new Workspace constructed from self._ws
        # (a child workspace); otherwise the plan runs in self._ws directly.
        task_ws = (
            workspace.C.Workspace(self._ws)
            if workspace_type == WorkspaceType.PRIVATE else self._ws)
        with workspace.WorkspaceGuard(task_ws):
            task_ws.run(plan)

    def _fetch_output(self, output):
        """Fetch the value of blob `output` from the session workspace."""
        return self._ws.blobs[str(output)].fetch()