3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
14 """ Wrapper for compiled runnable returned from session.compile() """ 15 def __init__(self, obj, session_class):
22 Allows to run Nets, ExecutionSteps, Plans, Tasks and TaskGroups. 23 A session can potentially run in multiple nodes concurrently. 28 from caffe2.python.task import Task, TaskGroup, WorkspaceType 31 net.Add([net.Const(1), net.Const(2)]) 34 step = core.execution_step('step1', [net2]) 36 with TaskGroup(WorkspaceType.GLOBAL) as init_tg: 38 n1setup = net.Net('n1setup') 39 n1msg = n1setup.Const('Hello from node 1.') 42 with TaskGroup() as private_tg: 49 n2.Print(n2.Const('Hello from node 2.'), 0) 52 session = LocalSession() 56 session.run(private_tg) 60 At the beggining of the session, a global workspace is created and kept 61 alive for the duration of the session. 65 Tasks can be run either directly on the global workspace, or they can 66 instantiate a private child workspace that is released after each run. 69 Tasks running in different nodes in parallel will always run under 70 different workspaces, so it must be assumed that they won't be able to 71 access each other's blobs. Tasks running on the same node will follow 72 Workspace hierarchy rules: tasks running on separate private workspaces 73 will only be able to share blobs defined on a common parent Workspace. 85 def compile(cls, runnable, workspace_type=None, setup_net_list=None):
86 if isinstance(runnable, CompiledRunnable):
87 assert cls == runnable.session_class, (
88 'Runnable was compiled for different session type. ' +
89 'Need: %s, got: %s' % (
90 cls.__name__, runnable.session_class.__name__))
96 if isinstance(runnable, TaskGroup):
98 if runnable.workspace_type():
99 assert runnable.workspace_type() == workspace_type, \
100 "Require {} but already have {}".format(
101 workspace_type, runnable.workspace_type())
103 runnable._workspace_type = workspace_type
106 if workspace_type
is None:
107 workspace_type = WorkspaceType.GLOBAL
108 tg =
TaskGroup(workspace_type=workspace_type)
109 if isinstance(runnable, Task):
112 tg.add(
Task(step=runnable))
118 assert len(runnable.Steps()) > 0
119 if len(runnable.Steps()) == 1:
120 tg.add(
Task(step=runnable.Steps()[0]))
124 tg.add(
Task(step=runnable.Steps()))
126 step = core.execution_step(
'runnable', runnable)
127 tg.add(
Task(step=step))
133 def run(self, runnable, workspace_type=None, setup_net_list=None):
134 """Run the given runnable. 137 runnable: Object recognized by the Session. Currently, we support 138 TaskGroup, Task, Plan, ExecutionStep, and Net. 139 workspace_type: A string defined in the WorkspaceType object. 140 setup_net_list: A list of Net objects or a list of NetDef protos. 141 So far this is only used by the DistributedSession, in which we 142 need to pass a list of special nets to setup the master. 144 assert self.
is_open(),
'Session is closed.' 145 assert runnable
is not None,
'Got a none runnable.' 154 def fetch_output(self, output):
155 raise NotImplementedError()
157 def _run_compiled(self, task_group):
158 raise NotImplementedError()
161 def _compile_task_group(cls, task_group, setup_net_list=None):
168 assert self._open,
'Session already closed.' 171 def __exit__(self, ex_type, value, traceback):
178 Session that runs in a single node. 179 Tasks are all remapped to run in parallel in the 'local' node. 181 Currently, LocalSession runs all parallel tasks in the same workspace, 182 but this behavior may change in the future. Only tasks pointing to the 183 same logical node are guaranteed to always run in the same workspace. 185 def __init__(self, ws=None):
186 Session.__init__(self)
187 self.
_ws = ws
or workspace.C.Workspace.current
190 def _compile_task_group(cls, task_group, setup_net_list=None):
192 task = task_group.to_task()
194 plan.AddStep(task.get_step())
195 return (plan, task.output_list(), task.workspace_type)
197 def _run_compiled(self, compiled):
198 plan, output_list, workspace_type = compiled
202 for name
in output_list.names():
203 self._ws.create_blob(str(name))
205 output_list.set_values(outputs, _fetch_func=self.
_fetch_output)
207 workspace.C.Workspace(self.
_ws)
208 if workspace_type == WorkspaceType.PRIVATE
else self.
_ws)
209 with workspace.WorkspaceGuard(task_ws):
212 def _fetch_output(self, output):
213 return self._ws.blobs[str(output)].fetch()
def run(self, runnable, workspace_type=None, setup_net_list=None)
def _fetch_output(self, output)
def _compile_task_group(cls, task_group, setup_net_list=None)
def compile(cls, runnable, workspace_type=None, setup_net_list=None)
dictionary _compiled_cache
def _run_compiled(self, task_group)