Caffe2 - Python API
A deep learning, cross-platform ML framework
session.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package session
17 # Module caffe2.python.session
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 
24 from caffe2.python import core, workspace
25 from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
26 
27 
class CompiledRunnable(object):
    """Wrapper around a runnable produced by ``Session.compile()``.

    Attributes:
        obj: the compiled payload; its concrete form is whatever the owning
            session class's ``_compile_task_group`` returned.
        session_class: the Session subclass that produced ``obj``.
    """

    def __init__(self, obj, session_class):
        # Record which session type compiled this, then the payload itself.
        self.session_class = session_class
        self.obj = obj
33 
34 
class Session(object):
    """
    Allows running Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
    A session can potentially run in multiple nodes concurrently.


    Example:
        from caffe2.python import core
        from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType

        net = core.Net('test1')
        net.Add([net.Const(1), net.Const(2)])

        net2 = net.Clone('test2')
        step = core.execution_step('step1', [net2])

        with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
            with Node('node1'):
                n1setup = core.Net('n1setup')
                n1msg = n1setup.Const('Hello from node 1.')
                Task(step=n1setup)

        with TaskGroup() as private_tg:
            with Node('node1'):
                n1 = core.Net('n1')
                n1.Print(n1msg, 0)
                Task(step=n1)
            with Node('node2'):
                n2 = core.Net('n2')
                n2.Print(n2.Const('Hello from node 2.'), 0)
                Task(step=n2)

        session = LocalSession()
        session.run(net)
        session.run(step)
        session.run(init_tg)
        session.run(private_tg)


    Global Workspace:
        At the beginning of the session, a global workspace is created and
        kept alive for the duration of the session.


    Private Workspace:
        Tasks can be run either directly on the global workspace, or they can
        instantiate a private child workspace that is released after each run.

    Blob visibility:
        Tasks running in different nodes in parallel will always run under
        different workspaces, so it must be assumed that they won't be able to
        access each other's blobs. Tasks running on the same node will follow
        Workspace hierarchy rules: tasks running on separate private workspaces
        will only be able to share blobs defined on a common parent Workspace.
    """

    # Cache of compiled runnables. This single dict is shared by Session and
    # all of its subclasses, so entries are keyed by
    # (session class, runnable, workspace_type) -- see _cache_key -- to make
    # sure a runnable compiled for one session type or workspace type is
    # never handed back for another. Entries are never evicted.
    _compiled_cache = {}

    def __init__(self):
        self._open = True

    def is_open(self):
        """Return True until close() has been called."""
        return self._open

    @classmethod
    def _cache_key(cls, runnable, workspace_type):
        """Build the _compiled_cache key for a compile() call.

        Returns None when the key cannot be hashed (e.g. the runnable is a
        list of nets); such calls are compiled without caching.
        """
        key = (cls, runnable, workspace_type)
        try:
            hash(key)
        except TypeError:
            return None
        return key

    @classmethod
    def compile(cls, runnable, workspace_type=None, setup_net_list=None):
        """Compile `runnable` into a CompiledRunnable for this session class.

        Args:
            runnable: a CompiledRunnable (returned as-is after checking it was
                compiled by `cls`), a TaskGroup, a Task, a core.ExecutionStep,
                a core.Plan, or anything else core.execution_step() accepts
                (e.g. a Net).
            workspace_type: optional WorkspaceType. For a TaskGroup it must
                match the group's own workspace type if the group already has
                one, otherwise it is assigned to the group. For every other
                runnable it defaults to WorkspaceType.GLOBAL.
            setup_net_list: forwarded unchanged to `cls._compile_task_group`.
                NOTE(review): not part of the cache key, so a cached result
                compiled with different setup nets may be returned.

        Returns:
            A CompiledRunnable whose `obj` is the value returned by
            `cls._compile_task_group`. The result is memoized in
            `_compiled_cache` whenever the cache key is hashable.
        """
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        # Key on the caller-supplied arguments (before defaulting below) and
        # on the session class, so lookups never cross session types.
        cache_key = cls._cache_key(runnable, workspace_type)
        if cache_key is not None and cache_key in cls._compiled_cache:
            return cls._compiled_cache[cache_key]

        if isinstance(runnable, TaskGroup):
            if workspace_type:
                if runnable.workspace_type():
                    assert runnable.workspace_type() == workspace_type, \
                        "Require {} but already have {}".format(
                            workspace_type, runnable.workspace_type())
                else:
                    runnable._workspace_type = workspace_type
            tg = runnable
        else:
            if workspace_type is None:
                workspace_type = WorkspaceType.GLOBAL
            tg = TaskGroup(workspace_type=workspace_type)
            if isinstance(runnable, Task):
                tg.add(runnable)
            elif isinstance(runnable, core.ExecutionStep):
                tg.add(Task(step=runnable))
            elif isinstance(runnable, core.Plan):
                # ExecutionSteps in a Plan() object are supposed to run
                # sequentially, while tasks in a TaskGroup run in parallel.
                # So if the Plan has multiple ExecutionSteps, we wrap them
                # all in a single root ExecutionStep.
                assert len(runnable.Steps()) > 0
                if len(runnable.Steps()) == 1:
                    tg.add(Task(step=runnable.Steps()[0]))
                else:
                    # Task takes a list of ExecutionSteps and automatically
                    # wraps them into a root ExecutionStep.
                    tg.add(Task(step=runnable.Steps()))
            else:
                step = core.execution_step('runnable', runnable)
                tg.add(Task(step=step))
        compiled = CompiledRunnable(
            cls._compile_task_group(tg, setup_net_list), session_class=cls)
        if cache_key is not None:
            cls._compiled_cache[cache_key] = compiled
        return compiled

    def run(self, runnable, workspace_type=None, setup_net_list=None):
        """Run the given runnable.

        The runnable is compiled with compile() (reusing a cached result when
        available) and the compiled payload is handed to _run_compiled().

        Args:
            runnable: Object recognized by the Session. Currently, we support
                TaskGroup, Task, Plan, ExecutionStep, and Net.
            workspace_type: A string defined in the WorkspaceType object.
            setup_net_list: A list of Net objects or a list of NetDef protos.
                So far this is only used by the DistributedSession, in which we
                need to pass a list of special nets to setup the master.

        Raises:
            AssertionError: if the session is closed or `runnable` is None.
        """
        assert self.is_open(), 'Session is closed.'
        assert runnable is not None, 'Got a none runnable.'
        self._run_compiled(self.compile(runnable, workspace_type,
                                        setup_net_list).obj)

    def close(self):
        """Close the session; calling it again is a no-op."""
        if self.is_open():
            self._do_close()
        self._open = False

    def fetch_output(self, output):
        """Fetch the value of `output`; must be implemented by subclasses."""
        raise NotImplementedError()

    def _run_compiled(self, task_group):
        """Execute a compiled payload; must be implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Turn a TaskGroup into this session's compiled form.

        The base implementation returns the TaskGroup unchanged and ignores
        setup_net_list; subclasses override it.
        """
        return task_group

    def _do_close(self):
        """Release session resources; a hook for subclasses (no-op here)."""
        pass

    def __enter__(self):
        assert self._open, 'Session already closed.'
        return self

    def __exit__(self, ex_type, value, traceback):
        # The session is closed only when the with-block exits cleanly; if an
        # exception propagates, the session is left open.
        # NOTE(review): confirm that leaving it open on error is intentional.
        if ex_type is None:
            self.close()
189 
190 
class LocalSession(Session):
    """
    Session that runs in a single node.
    Tasks are all remapped to run in parallel in the 'local' node.

    Currently, LocalSession runs all parallel tasks in the same workspace,
    but this behavior may change in the future. Only tasks pointing to the
    same logical node are guaranteed to always run in the same workspace.
    """
    def __init__(self, ws=None):
        """Create a session bound to a workspace.

        Args:
            ws: the workspace to run in. When None, the workspace returned by
                workspace.C.Workspace.current at construction time is used.
        """
        Session.__init__(self)
        self._ws = ws if ws is not None else workspace.C.Workspace.current

    @classmethod
    def _compile_task_group(cls, task_group, setup_net_list=None):
        """Flatten `task_group` into a single-step Plan.

        setup_net_list is accepted for interface compatibility and ignored.

        Returns:
            A (plan, output_list, workspace_type) tuple, consumed by
            _run_compiled().
        """
        # Convert the group under a fresh Cluster context.
        # NOTE(review): presumably isolates node registration from any
        # enclosing cluster -- confirm against caffe2.python.task.
        with Cluster():
            task = task_group.to_task()
        plan = core.Plan('task_group_plan')
        plan.AddStep(task.get_step())
        return (plan, task.output_list(), task.workspace_type)

    def _run_compiled(self, compiled):
        """Run a (plan, output_list, workspace_type) tuple in this session.

        Output blobs are created in the session workspace so they outlive a
        private child workspace; the plan itself runs in a new child
        workspace when workspace_type is PRIVATE, otherwise in the session
        workspace directly.
        """
        plan, output_list, workspace_type = compiled

        # make sure the output blobs belong to the parent workspace
        outputs = []
        for name in output_list.names():
            blob_name = str(name)
            self._ws.create_blob(blob_name)
            outputs.append(core.BlobReference(blob_name))
        output_list.set_values(outputs, _fetch_func=self._fetch_output)
        task_ws = (
            workspace.C.Workspace(self._ws)
            if workspace_type == WorkspaceType.PRIVATE else self._ws)
        with workspace.WorkspaceGuard(task_ws):
            task_ws.run(plan)

    def _fetch_output(self, output):
        """Fetch the value of blob `output` from the session workspace."""
        return self._ws.blobs[str(output)].fetch()
def run(self, runnable, workspace_type=None, setup_net_list=None)
Definition: session.py:148
def _fetch_output(self, output)
Definition: session.py:227
def _compile_task_group(cls, task_group, setup_net_list=None)
Definition: session.py:176
def compile(cls, runnable, workspace_type=None, setup_net_list=None)
Definition: session.py:100
def _run_compiled(self, task_group)
Definition: session.py:172