Caffe2 - Python API
A deep learning, cross platform ML framework
session.py
1 ## @package session
2 # Module caffe2.python.session
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 
9 from caffe2.python import core, workspace
10 from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
11 
12 
class CompiledRunnable(object):
    """Opaque result of Session.compile().

    Attributes:
        obj: whatever the session class's _compile_task_group() produced.
        session_class: the Session subclass that produced `obj`;
            Session.compile() asserts that it is only reused with that
            same class.
    """

    def __init__(self, obj, session_class):
        self.obj, self.session_class = obj, session_class
18 
19 
class Session(object):
    """
    Allows to run Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
    A session can potentially run in multiple nodes concurrently.


    Example:
        from caffe2.python import core
        from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType

        net = core.Net('test1')
        net.Add([net.Const(1), net.Const(2)])

        net2 = net.Clone('test1_clone')
        step = core.execution_step('step1', [net2])

        with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
            with Node('node1'):
                n1setup = core.Net('n1setup')
                n1msg = n1setup.Const('Hello from node 1.')
                Task(step=n1setup)

        with TaskGroup() as private_tg:
            with Node('node1'):
                n1 = core.Net('n1')
                n1.Print(n1msg, 0)
                Task(step=n1)
            with Node('node2'):
                n2 = core.Net('n2')
                n2.Print(n2.Const('Hello from node 2.'), 0)
                Task(step=n2)

        session = LocalSession()
        session.run(net)
        session.run(step)
        session.run(init_tg)
        session.run(private_tg)


    Global Workspace:
        At the beginning of the session, a global workspace is created and
        kept alive for the duration of the session.


    Private Workspace:
        Tasks can be run either directly on the global workspace, or they can
        instantiate a private child workspace that is released after each run.

    Blob visibility:
        Tasks running in different nodes in parallel will always run under
        different workspaces, so it must be assumed that they won't be able to
        access each other's blobs. Tasks running on the same node will follow
        Workspace hierarchy rules: tasks running on separate private workspaces
        will only be able to share blobs defined on a common parent Workspace.

    Subclassing:
        Subclasses must override _run_compiled() (the base implementation
        raises NotImplementedError). They may override _compile_task_group()
        (the base returns the TaskGroup unchanged) and the _do_close() hook,
        which close() calls at most once.
    """

    # Compiled runnables keyed by (session class, runnable, workspace type).
    # This single dict is shared by Session and all of its subclasses, which
    # is why the session class is part of the key.
    _compiled_cache = {}

    def __init__(self):
        self._open = True

    def is_open(self):
        """Return True until close() has been called."""
        return self._open

    @classmethod
    def compile(cls, runnable, workspace_type=None, setup_net_list=None):
        """Compile `runnable` into a CompiledRunnable for this session class.

        Args:
            runnable: a CompiledRunnable (returned as-is after checking it was
                compiled for `cls`), a TaskGroup, a Task, a
                core.ExecutionStep, a core.Plan, or anything accepted by
                core.execution_step (e.g. a Net).
            workspace_type: WorkspaceType to run under. For a TaskGroup that
                already has a workspace type, it must match. For a TaskGroup
                without one, it is assigned to the group. For any other
                runnable, it defaults to WorkspaceType.GLOBAL.
            setup_net_list: forwarded to cls._compile_task_group().

        Returns:
            A CompiledRunnable with session_class == cls. Results are cached
            per (session class, runnable, resolved workspace type).

        Raises:
            AssertionError: if a CompiledRunnable from another session class
                is passed, if a TaskGroup's workspace type conflicts with
                `workspace_type`, or if a Plan has no steps.
        """
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        is_task_group = isinstance(runnable, TaskGroup)
        if workspace_type is None:
            # Resolve the default before consulting the cache so that the key
            # reflects the workspace type that will actually be used.
            workspace_type = (
                runnable.workspace_type() if is_task_group
                else WorkspaceType.GLOBAL)

        # Without `cls` in the key, a runnable compiled for one session type
        # would be returned to another. Without `workspace_type`, a later
        # request for a different workspace type would silently reuse the
        # first compilation and skip the TaskGroup consistency check below.
        cache_key = (cls, runnable, workspace_type)
        if cache_key in cls._compiled_cache:
            return cls._compiled_cache[cache_key]

        if is_task_group:
            if workspace_type:
                if runnable.workspace_type():
                    assert runnable.workspace_type() == workspace_type, \
                        "Require {} but already have {}".format(
                            workspace_type, runnable.workspace_type())
                else:
                    runnable._workspace_type = workspace_type
            tg = runnable
        else:
            tg = TaskGroup(workspace_type=workspace_type)
            if isinstance(runnable, Task):
                tg.add(runnable)
            elif isinstance(runnable, core.ExecutionStep):
                tg.add(Task(step=runnable))
            elif isinstance(runnable, core.Plan):
                # ExecutionSteps in a Plan run sequentially, while tasks in a
                # TaskGroup run in parallel. So when a Plan has several
                # ExecutionSteps, they are wrapped under one root step.
                steps = runnable.Steps()
                assert len(steps) > 0, 'Plan has no ExecutionSteps to run.'
                if len(steps) == 1:
                    tg.add(Task(step=steps[0]))
                else:
                    # Task accepts a list of ExecutionSteps and wraps them
                    # into a root ExecutionStep itself.
                    tg.add(Task(step=steps))
            else:
                step = core.execution_step('runnable', runnable)
                tg.add(Task(step=step))
        compiled = CompiledRunnable(
            cls._compile_task_group(tg, setup_net_list), session_class=cls)
        cls._compiled_cache[cache_key] = compiled
        return compiled

    def run(self, runnable, workspace_type=None, setup_net_list=None):
        """Run the given runnable.

        Args:
            runnable: Object recognized by the Session. Currently, we support
                TaskGroup, Task, Plan, ExecutionStep, and Net.
            workspace_type: A string defined in the WorkspaceType object.
            setup_net_list: A list of Net objects or a list of NetDef protos.
                So far this is only used by the DistributedSession, in which we
                need to pass a list of special nets to setup the master.
        """
        assert self.is_open(), 'Session is closed.'
        assert runnable is not None, 'Got a none runnable.'
        self._run_compiled(self.compile(runnable, workspace_type,
                                        setup_net_list).obj)

    def close(self):
        """Close the session. Safe to call more than once."""
        if self.is_open():
            self._do_close()
            self._open = False

    def fetch_output(self, output):
        """Return the value of `output`; must be provided by subclasses."""
        raise NotImplementedError()

    def _run_compiled(self, task_group):
        """Execute the `.obj` of a CompiledRunnable built by this class."""
        raise NotImplementedError()

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Turn a TaskGroup into the session-specific compiled form.

        The base implementation returns `task_group` unchanged.
        """
        return task_group

    def _do_close(self):
        """Hook for subclasses to release resources; called once by close()."""
        pass

    def __enter__(self):
        assert self._open, 'Session already closed.'
        return self

    def __exit__(self, ex_type, value, traceback):
        # The session is closed only when the block exits cleanly; after an
        # exception it is left open.
        # NOTE(review): confirm this is intentional. Callers must close()
        # manually after a failure.
        if ex_type is None:
            self.close()
174 
175 
class LocalSession(Session):
    """
    Session that runs in a single node.
    Tasks are all remapped to run in parallel in the 'local' node.

    Currently, LocalSession runs all parallel tasks in the same workspace,
    but this behavior may change in the future. Only tasks pointing to the
    same logical node are guaranteed to always run in the same workspace.
    """

    def __init__(self, ws=None):
        """Create a session bound to a Caffe2 workspace.

        Args:
            ws: workspace used as the session's parent workspace. When None,
                the current workspace (workspace.C.Workspace.current) is used.
        """
        Session.__init__(self)
        # Explicit None check: `ws or ...` would also discard a supplied
        # workspace object that merely evaluates as falsy.
        self._ws = ws if ws is not None else workspace.C.Workspace.current

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Lower `task_group` to a single Plan for local execution.

        setup_net_list is accepted for interface compatibility with
        Session._compile_task_group and is not used here.

        Returns:
            A (plan, output_list, workspace_type) tuple, which
            _run_compiled unpacks.
        """
        # The TaskGroup is collapsed into one task inside a fresh Cluster
        # context; the Plan is then built from that task's step.
        with Cluster():
            merged_task = task_group.to_task()
        task_plan = core.Plan('task_group_plan')
        task_plan.AddStep(merged_task.get_step())
        return (
            task_plan,
            merged_task.output_list(),
            merged_task.workspace_type,
        )

    def _run_compiled(self, compiled):
        """Run a (plan, output_list, workspace_type) tuple.

        Output blobs are created in the session workspace (self._ws) before
        the plan runs, and their fetch function reads them back from there.
        With WorkspaceType.PRIVATE, the plan runs in a new child workspace
        whose parent is self._ws. Otherwise it runs in self._ws directly.
        """
        plan, output_list, workspace_type = compiled
        parent_ws = self._ws

        # make sure the output blobs belong to the parent workspace
        blob_refs = []
        for blob_name in map(str, output_list.names()):
            parent_ws.create_blob(blob_name)
            blob_refs.append(core.BlobReference(blob_name))
        output_list.set_values(blob_refs, _fetch_func=self._fetch_output)

        if workspace_type == WorkspaceType.PRIVATE:
            run_ws = workspace.C.Workspace(parent_ws)
        else:
            run_ws = parent_ws
        with workspace.WorkspaceGuard(run_ws):
            run_ws.run(plan)

    def _fetch_output(self, output):
        """Return the current value of blob `output` from self._ws."""
        blob = self._ws.blobs[str(output)]
        return blob.fetch()
def run(self, runnable, workspace_type=None, setup_net_list=None)
Definition: session.py:133
def _fetch_output(self, output)
Definition: session.py:212
def _compile_task_group(cls, task_group, setup_net_list=None)
Definition: session.py:161
def compile(cls, runnable, workspace_type=None, setup_net_list=None)
Definition: session.py:85
def _run_compiled(self, task_group)
Definition: session.py:157