# blob: 08361e6c548708c0bc79ba120b5cd9719d579cab [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, workspace
from caffe2.python.task import Task, TaskGroup, WorkspaceType
class Session(object):
    """
    Allows to run Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
    A session can potentially run in multiple nodes concurrently.

    Example (illustrative sketch):
        from caffe2.python import core
        from caffe2.python.task import Task, TaskGroup, WorkspaceType
        # NOTE(review): Node is used below but not imported here; presumably
        # it lives next to Task/TaskGroup -- confirm before copying.

        net = core.Net('test1')
        net.Add([net.Const(1), net.Const(2)])

        net2 = net.Clone()
        step = core.execution_step('step1', [net2])

        with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
            with Node('node1'):
                n1setup = core.Net('n1setup')
                n1msg = n1setup.Const('Hello from node 1.')
                Task(step=n1setup)

        with TaskGroup() as private_tg:
            with Node('node1'):
                n1 = core.Net('n1')
                n1.Print(n1msg, 0)
                Task(step=n1)
            with Node('node2'):
                n2 = core.Net('n2')
                n2.Print(n2.Const('Hello from node 2.'), 0)
                Task(step=n2)

        session = LocalSession()
        session.run(net)
        session.run(step)
        session.run(init_tg)
        session.run(private_tg)

    Global Workspace:
        At the beginning of the session, a global workspace is created and
        kept alive for the duration of the session.

    Private Workspace:
        Tasks can be run either directly on the global workspace, or they can
        instantiate a private child workspace that is released after each run.

    Blob visibility:
        Tasks running in different nodes in parallel will always run under
        different workspaces, so it must be assumed that they won't be able to
        access each other's blobs. On the other hand, tasks running on the
        same node are guaranteed to run on the same workspace within a run.

    Subclasses are expected to override `_run_task_group` (and optionally
    `fetch_output` and `_do_close`); the base implementations of the first
    two raise NotImplementedError.
    """

    def __init__(self):
        # Becomes False once close() has completed.
        self._open = True
        # Maps every object passed to run() to the TaskGroup that was built
        # for it, so running the same object again reuses that TaskGroup.
        self._runnable_cache = {}

    def is_open(self):
        """Return True until the session has been closed."""
        return self._open

    def run(self, runnable):
        """
        Run `runnable` through `_run_task_group`.

        A TaskGroup is run as is. A Task or an ExecutionStep is wrapped in a
        new TaskGroup of type WorkspaceType.GLOBAL. Anything else is first
        turned into a step with `core.execution_step('runnable', runnable)`.
        The resulting TaskGroup is cached per `runnable`, so `runnable` must
        be hashable.

        Raises:
            AssertionError: if the session has already been closed.
        """
        # Explicit raise instead of `assert` so the guard is not stripped
        # when Python runs with -O; exception type and message are unchanged.
        if not self.is_open():
            raise AssertionError('Session is closed.')

        if runnable not in self._runnable_cache:
            if isinstance(runnable, TaskGroup):
                task_group = runnable
            else:
                task_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
                if isinstance(runnable, Task):
                    task_group.add(runnable)
                elif isinstance(runnable, core.ExecutionStep):
                    task_group.add(Task(step=runnable))
                else:
                    step = core.execution_step('runnable', runnable)
                    task_group.add(Task(step=step))
            self._runnable_cache[runnable] = task_group

        self._run_task_group(self._runnable_cache[runnable])

    def close(self):
        """Close the session; calling it again on a closed session is a no-op."""
        if self.is_open():
            self._do_close()
            # Only reached when _do_close() did not raise.
            self._open = False

    def fetch_output(self, output):
        """Hook for subclasses; the base implementation is not available."""
        raise NotImplementedError()

    def _run_task_group(self, task_group):
        """Hook for subclasses: execute the given TaskGroup."""
        raise NotImplementedError()

    def _do_close(self):
        """Hook for subclasses: release resources. Does nothing by default."""
        pass

    def __enter__(self):
        # Same reasoning as in run(): keep the check alive under -O.
        if not self._open:
            raise AssertionError('Session already closed.')
        return self

    def __exit__(self, ex_type, value, traceback):
        # The session is closed only when the `with` body finished without
        # an exception; on error it is left open and the exception
        # propagates (nothing is returned, so it is never suppressed).
        if ex_type is None:
            self.close()
class LocalSession(Session):
    """
    Session that runs in a single node.
    Tasks are all remapped to run in parallel in the 'local' node.

    Currently, LocalSession runs all parallel tasks in the same workspace,
    but this behavior may change in the future. Only tasks pointing to the
    same logical node are guaranteed to always run in the same workspace.

    Args:
        ws: workspace to run in. When None, a new `workspace.C.Workspace()`
            is created and owned by this session.
    """

    def __init__(self, ws=None):
        Session.__init__(self)
        # Explicit None check: a caller-supplied workspace is always used,
        # regardless of how it evaluates in a boolean context.
        self._ws = workspace.C.Workspace() if ws is None else ws
        # Maps each TaskGroup to the (plan, task) pair built for it, so the
        # plan is constructed only once per TaskGroup.
        self._plan_caches = {}

    def _run_task_group(self, task_group):
        """Build (or reuse) the plan for `task_group` and run it."""
        if task_group not in self._plan_caches:
            task = task_group.to_task()
            plan = core.Plan('task_group_plan')
            plan.AddStep(task.get_step())
            self._plan_caches[task_group] = (plan, task)
        plan, task = self._plan_caches[task_group]

        # make sure the output blobs belong to the parent workspace:
        # they are created in self._ws before the plan runs, so they can be
        # fetched from self._ws afterwards.
        outputs = []
        for name in task.output_names():
            blob_name = str(name)
            self._ws.create_blob(blob_name)
            outputs.append(core.BlobReference(blob_name))
        task.set_outputs(outputs, _fetch_func=self._fetch_output)

        # PRIVATE tasks run in a fresh workspace constructed from self._ws;
        # every other workspace type runs directly in self._ws.
        if task.workspace_type == WorkspaceType.PRIVATE:
            task_ws = workspace.C.Workspace(self._ws)
        else:
            task_ws = self._ws
        with workspace.WorkspaceGuard(task_ws):
            task_ws.run(plan)

    def fetch_output(self, output):
        """
        Return the fetched value of blob `str(output)` from the session
        workspace.

        Overrides the base-class hook, which raises NotImplementedError.
        """
        return self._fetch_output(output)

    def _fetch_output(self, output):
        """Fetch blob `str(output)` from self._ws; used as the tasks' fetch function."""
        return self._ws.blobs[str(output)].fetch()