3 from __future__ 
import absolute_import
     4 from __future__ 
import division
     5 from __future__ 
import print_function
     6 from __future__ 
import unicode_literals
    14     """ Wrapper for compiled runnable returned from session.compile() """    15     def __init__(self, obj, session_class):
    22     Allows to run Nets, ExecutionSteps, Plans, Tasks and TaskGroups.    23     A session can potentially run in multiple nodes concurrently.    28         from caffe2.python.task import Task, TaskGroup, WorkspaceType    31         net.Add([net.Const(1), net.Const(2)])    34         step = core.execution_step('step1', [net2])    36         with TaskGroup(WorkspaceType.GLOBAL) as init_tg:    38                 n1setup = net.Net('n1setup')    39                 n1msg = n1setup.Const('Hello from node 1.')    42         with TaskGroup() as private_tg:    49                 n2.Print(n2.Const('Hello from node 2.'), 0)    52         session = LocalSession()    56         session.run(private_tg)    60         At the beggining of the session, a global workspace is created and kept    61         alive for the duration of the session.    65         Tasks can be run either directly on the global workspace, or they can    66         instantiate a private child workspace that is released after each run.    69         Tasks running in different nodes in parallel will always run under    70         different workspaces, so it must be assumed that they won't be able to    71         access each other's blobs. Tasks running on the same node will follow    72         Workspace hierarchy rules: tasks running on separate private workspaces    73         will only be able to share blobs defined on a common parent Workspace.    85     def compile(cls, runnable, workspace_type=None, setup_net_list=None):
    86         if isinstance(runnable, CompiledRunnable):
    87             assert cls == runnable.session_class, (
    88                 'Runnable was compiled for different session type. ' +
    89                 'Need: %s, got: %s' % (
    90                     cls.__name__, runnable.session_class.__name__))
    96         if isinstance(runnable, TaskGroup):
    98                 if runnable.workspace_type():
    99                     assert runnable.workspace_type() == workspace_type, \
   100                         "Require {} but already have {}".format(
   101                             workspace_type, runnable.workspace_type())
   103                     runnable._workspace_type = workspace_type
   106             if workspace_type 
is None:
   107                 workspace_type = WorkspaceType.GLOBAL
   108             tg = 
TaskGroup(workspace_type=workspace_type)
   109             if isinstance(runnable, Task):
   112                 tg.add(
Task(step=runnable))
   118                 assert len(runnable.Steps()) > 0
   119                 if len(runnable.Steps()) == 1:
   120                     tg.add(
Task(step=runnable.Steps()[0]))
   124                     tg.add(
Task(step=runnable.Steps()))
   126                 step = core.execution_step(
'runnable', runnable)
   127                 tg.add(
Task(step=step))
   133     def run(self, runnable, workspace_type=None, setup_net_list=None):
   134         """Run the given runnable.   137             runnable: Object recognized by the Session. Currently, we support   138                 TaskGroup, Task, Plan, ExecutionStep, and Net.   139             workspace_type: A string defined in the WorkspaceType object.   140             setup_net_list: A list of Net objects or a list of NetDef protos.   141                 So far this is only used by the DistributedSession, in which we   142                 need to pass a list of special nets to setup the master.   144         assert self.
is_open(), 
'Session is closed.'   145         assert runnable 
is not None, 
'Got a none runnable.'   154     def fetch_output(self, output):
   155         raise NotImplementedError()
   157     def _run_compiled(self, task_group):
   158         raise NotImplementedError()
   161     def _compile_task_group(cls, task_group, setup_net_list=None):
   168         assert self._open, 
'Session already closed.'   171     def __exit__(self, ex_type, value, traceback):
   178     Session that runs in a single node.   179     Tasks are all remapped to run in parallel in the 'local' node.   181     Currently, LocalSession runs all parallel tasks in the same workspace,   182     but this behavior may change in the future. Only tasks pointing to the   183     same logical node are guaranteed to always run in the same workspace.   185     def __init__(self, ws=None):
   186         Session.__init__(self)
   187         self.
_ws = ws 
or workspace.C.Workspace.current
   190     def _compile_task_group(cls, task_group, setup_net_list=None):
   192             task = task_group.to_task()
   194         plan.AddStep(task.get_step())
   195         return (plan, task.output_list(), task.workspace_type)
   197     def _run_compiled(self, compiled):
   198         plan, output_list, workspace_type = compiled
   202         for name 
in output_list.names():
   203             self._ws.create_blob(str(name))
   205         output_list.set_values(outputs, _fetch_func=self.
_fetch_output)
   207             workspace.C.Workspace(self.
_ws)
   208             if workspace_type == WorkspaceType.PRIVATE 
else self.
_ws)
   209         with workspace.WorkspaceGuard(task_ws):
   212     def _fetch_output(self, output):
   213         return self._ws.blobs[str(output)].fetch()
 
def run(self, runnable, workspace_type=None, setup_net_list=None)
 
def _fetch_output(self, output)
 
def _compile_task_group(cls, task_group, setup_net_list=None)
 
def compile(cls, runnable, workspace_type=None, setup_net_list=None)
 
dictionary _compiled_cache
 
def _run_compiled(self, task_group)