Caffe2 - Python API
A deep learning, cross-platform ML framework
scope.py
1 ## @package scope
2 # Module caffe2.python.scope
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import contextlib
9 import threading
10 from past.builtins import basestring
11 
12 from caffe2.proto import caffe2_pb2
13 
14 
# Separator appended to every non-empty NameScope() prefix, so nested scopes
# read like "outer/inner/".
_NAMESCOPE_SEPARATOR = '/'

# Per-thread storage for the active name scope (attribute "namescope") and
# device scope (attribute "devicescope"). Each thread sees its own values;
# the attributes are created lazily by the Current*Scope() accessors.
_threadlocal_scope = threading.local()


def CurrentNameScope():
    """Return the name-scope prefix active in the calling thread.

    On the first call in a thread the stored prefix is initialised to ''
    so that later direct reads of ``_threadlocal_scope.namescope`` succeed.
    When maintained by NameScope(), the value is either '' or a string
    ending in _NAMESCOPE_SEPARATOR.
    """
    try:
        return _threadlocal_scope.namescope
    except AttributeError:
        pass
    # First access from this thread: create the attribute.
    _threadlocal_scope.namescope = ''
    return _threadlocal_scope.namescope
26 
27 
def CurrentDeviceScope():
    """Return the device scope active in the calling thread.

    The value is the DeviceOption installed by DeviceScope(), or None when
    no device scope is active. On the first call in a thread the stored
    value is initialised to None so that later direct reads of
    ``_threadlocal_scope.devicescope`` succeed.
    """
    try:
        return _threadlocal_scope.devicescope
    except AttributeError:
        pass
    # First access from this thread: create the attribute.
    _threadlocal_scope.devicescope = None
    return _threadlocal_scope.devicescope
33 
34 
@contextlib.contextmanager
def NameScope(prefix, reset=False):
    """Apply a name-scope prefix to operators created inside the block.

    Args:
        prefix: scope name (a string) or None. A non-empty prefix gets
            _NAMESCOPE_SEPARATOR appended; None or '' adds nothing.
        reset: when True, replace the current scope with `prefix` instead
            of nesting `prefix` under the current scope.

    The previous name scope is always restored on exit, even if the block
    raises or the balance check below fails.

    Raises:
        AssertionError: if `prefix` is neither a string nor None, or if the
            thread's name scope was changed outside NameScope() calls while
            the block ran. (Asserts are skipped under ``python -O``.)
    """
    assert isinstance(prefix, basestring) or prefix is None, \
        "NameScope takes in a string as its argument."
    old_scope = CurrentNameScope()
    prefix = prefix + _NAMESCOPE_SEPARATOR if prefix else ''
    if reset:
        _threadlocal_scope.namescope = prefix
    else:
        # CurrentNameScope() just returned the stored value, so this nests
        # `prefix` under the enclosing scope.
        _threadlocal_scope.namescope = old_scope + prefix

    try:
        yield
    finally:
        # Evaluate the balance check first, but restore the previous scope
        # BEFORE asserting. Otherwise a failed check would leave this
        # thread's name scope corrupted for every later operator.
        balanced = _threadlocal_scope.namescope.endswith(prefix)
        _threadlocal_scope.namescope = old_scope
        assert balanced, \
            "The namescope variable is changed from outside NameScope() calls."
53 
54 
@contextlib.contextmanager
def DeviceScope(scope, node_name=None):
    """Install a device scope for operators created inside the block.

    Args:
        scope: a caffe2_pb2.DeviceOption to copy, or a falsy value such as
            None to start from an empty DeviceOption.
        node_name: optional node name. When given, it overrides any
            node_name copied from `scope`. At least one of `scope` and
            `node_name` must be truthy.

    The new scope is a fresh DeviceOption, so the caller's `scope` message
    is never mutated. When nested inside another DeviceScope, the new scope
    inherits the enclosing node_name (unless the new scope sets one). It
    also inherits the enclosing extra_info entries, merged with its own and
    sorted.

    The previous device scope is always restored on exit, even if the block
    raises or the balance check below fails.

    Raises:
        AssertionError: on invalid arguments, or if the thread's device
            scope was replaced outside DeviceScope() calls while the block
            ran. (Asserts are skipped under ``python -O``.)
    """
    new_scope = caffe2_pb2.DeviceOption()
    if scope:
        assert isinstance(scope, caffe2_pb2.DeviceOption), \
            "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
        new_scope.CopyFrom(scope)
    else:
        assert node_name, "At least one argument should be non-null in DeviceScope"

    # Rewrite node_name if it is explicitly given.
    if node_name:
        new_scope.node_name = node_name
    old_scope = CurrentDeviceScope()
    # A nested scope inherits the node_name if it is not explicitly set.
    if old_scope and old_scope.HasField('node_name') and \
            not new_scope.HasField('node_name'):
        new_scope.node_name = old_scope.node_name

    # A nested scope inherits the enclosing extra_info, merged with its own.
    # NOTE(review): entries present in both scopes are kept twice (there is
    # no de-duplication). Confirm this is intended.
    if old_scope and hasattr(old_scope, 'extra_info'):
        new_scope.extra_info.extend(old_scope.extra_info)
        new_scope.extra_info.sort()

    _threadlocal_scope.devicescope = new_scope
    try:
        yield
    finally:
        # Evaluate the balance check first (protobuf `==` compares field
        # values), but restore the previous scope BEFORE asserting so a
        # failed check cannot leave this thread's device scope corrupted.
        balanced = _threadlocal_scope.devicescope == new_scope
        _threadlocal_scope.devicescope = old_scope
        assert balanced, \
            "The device scope is changed from outside DeviceScope() calls."
87 
88 
@contextlib.contextmanager
def EmptyNameScope():
    """
    Allow users to 'disable' the name scope behaviour.

    This sets the CurrentNameScope() to '' (the empty prefix) for the
    duration of the block, so no enclosing name scope applies in
    CreateOperator(...), etc. The previous name scope is restored on exit.

    Exceptions raised inside the block propagate to the caller unchanged.
    """
    old_scope = CurrentNameScope()
    _threadlocal_scope.namescope = ''
    try:
        yield
    finally:
        # Deliberately no `return` here. A return inside `finally` discards
        # any in-flight exception, and contextlib.contextmanager would then
        # treat an exception raised in the with-body as handled, silently
        # suppressing it.
        _threadlocal_scope.namescope = old_scope
104 
105 
@contextlib.contextmanager
def EmptyDeviceScope():
    """
    Allow users to 'disable' the device scope behaviour (so it can be
    controlled at a NetDef::DeviceOption level, not overridden at
    OperatorDef::DeviceOption level).

    This sets the CurrentDeviceScope() to None for the duration of the
    block, so that the field is not set in CreateOperator(...), etc. The
    previous device scope is restored on exit.

    Exceptions raised inside the block propagate to the caller unchanged.
    """
    old_scope = CurrentDeviceScope()
    _threadlocal_scope.devicescope = None
    try:
        yield
    finally:
        # Deliberately no `return` here. A return inside `finally` discards
        # any in-flight exception, and contextlib.contextmanager would then
        # treat an exception raised in the with-body as handled, silently
        # suppressing it.
        _threadlocal_scope.devicescope = old_scope