3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
10 from past.builtins
import basestring
12 from caffe2.proto
import caffe2_pb2
# Separator appended after each non-empty name-scope prefix, so nested scopes
# read like "outer/inner/".
_NAMESCOPE_SEPARATOR = '/'

# Thread-local holder for the active scopes; the functions below read and
# write its `namescope` and `devicescope` attributes.
_threadlocal_scope = threading.local()
def CurrentNameScope():
    """Return the name scope stored for the calling thread.

    If the thread-local ``namescope`` attribute has not been set yet it is
    initialised to the empty string before being returned.

    Returns:
        The value of ``_threadlocal_scope.namescope`` ('' on first access).
    """
    # No ``global`` statement is needed: the module-level object is only
    # mutated through attribute access, never rebound.
    if not hasattr(_threadlocal_scope, "namescope"):
        _threadlocal_scope.namescope = ''
    return _threadlocal_scope.namescope
def CurrentDeviceScope():
    """Return the device scope stored for the calling thread.

    If the thread-local ``devicescope`` attribute has not been set yet it is
    initialised to ``None`` before being returned.

    Returns:
        The value of ``_threadlocal_scope.devicescope`` (``None`` on first
        access).
    """
    # No ``global`` statement is needed: the module-level object is only
    # mutated through attribute access, never rebound.
    if not hasattr(_threadlocal_scope, "devicescope"):
        _threadlocal_scope.devicescope = None
    return _threadlocal_scope.devicescope
@contextlib.contextmanager
def NameScope(prefix, reset=False):
    """Context manager that applies ``prefix`` to the thread-local name scope.

    A non-empty ``prefix`` has ``_NAMESCOPE_SEPARATOR`` appended to it; a
    ``None`` or empty ``prefix`` contributes the empty string. The previous
    scope is put back when the ``with`` block exits.

    Args:
        prefix: String to add to the name scope, or ``None``.
        reset: If true, the scope is replaced by the new prefix; otherwise
            the new prefix is appended to the current scope.

    Raises:
        AssertionError: If ``prefix`` is neither a string nor ``None``, or if
            on exit the scope no longer ends with the prefix applied here.
    """
    assert isinstance(prefix, basestring) or prefix is None, \
        "NameScope takes in a string as its argument."
    old_scope = CurrentNameScope()
    prefix = prefix + _NAMESCOPE_SEPARATOR if prefix else ''
    if reset:
        _threadlocal_scope.namescope = prefix
    else:
        # CurrentNameScope() just returned the stored value, so old_scope is
        # the current scope text.
        _threadlocal_scope.namescope = old_scope + prefix

    try:
        yield
    finally:
        try:
            assert _threadlocal_scope.namescope.endswith(prefix), \
                "The namescope variable is changed from outside NameScope() calls."
        finally:
            # Always restore the caller's scope, even when the consistency
            # check above fails, so one bad block cannot leak its scope.
            _threadlocal_scope.namescope = old_scope
@contextlib.contextmanager
def DeviceScope(scope, node_name=None):
    """Context manager that installs a device scope for the calling thread.

    A new ``caffe2_pb2.DeviceOption`` is created, filled from ``scope`` when
    one is given, optionally overridden with ``node_name``, and merged with
    the enclosing device scope. It is stored as the thread-local device scope
    for the duration of the ``with`` block, after which the previous scope is
    put back.

    Args:
        scope: A ``caffe2_pb2.DeviceOption`` to copy, or a falsy value when
            only ``node_name`` is supplied.
        node_name: Optional node name; when truthy it is written into the new
            scope, replacing any value copied from ``scope``.

    Raises:
        AssertionError: If ``scope`` is truthy but not a
            ``caffe2_pb2.DeviceOption``, if both arguments are falsy, or if
            on exit the stored device scope differs from the one set here.
    """
    new_scope = caffe2_pb2.DeviceOption()
    if scope:
        assert isinstance(scope, caffe2_pb2.DeviceOption), \
            "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
        new_scope.CopyFrom(scope)
    else:
        assert node_name, "At least one argument should be non-null in DeviceScope"

    # An explicitly supplied node_name wins over whatever came from ``scope``.
    if node_name:
        new_scope.node_name = node_name

    old_scope = CurrentDeviceScope()
    # A nested scope inherits the enclosing node_name unless it has its own.
    if old_scope and old_scope.HasField('node_name') and \
            not new_scope.HasField('node_name'):
        new_scope.node_name = old_scope.node_name

    # A nested scope also picks up the enclosing scope's extra_info entries.
    if old_scope and hasattr(old_scope, 'extra_info'):
        new_scope.extra_info.extend(old_scope.extra_info)
    # NOTE(review): sorting is applied unconditionally here; confirm it was
    # not meant to run only when entries were inherited.
    new_scope.extra_info.sort()

    _threadlocal_scope.devicescope = new_scope
    try:
        yield
    finally:
        try:
            assert _threadlocal_scope.devicescope == new_scope, \
                "The device scope is changed from outside DeviceScope() calls."
        finally:
            # Always restore the caller's scope, even when the consistency
            # check above fails, so one bad block cannot leak its scope.
            _threadlocal_scope.devicescope = old_scope
@contextlib.contextmanager
def EmptyNameScope():
    """
    Allow users to 'disable' the name scope behaviour.

    This sets the thread-local name scope to the empty string for the
    duration of the ``with`` block, so that CurrentNameScope() returns ''
    and the field is not set in CreateOperator(...), etc. The previous
    scope is put back on exit.
    """
    # NOTE(review): the name and the empty parameter list mirror
    # EmptyDeviceScope(); confirm against callers.
    old_scope = CurrentNameScope()
    try:
        _threadlocal_scope.namescope = ''
        yield
    finally:
        _threadlocal_scope.namescope = old_scope
@contextlib.contextmanager
def EmptyDeviceScope():
    """
    Allow users to 'disable' the device scope behaviour (so it can be
    controlled at a NetDef::DeviceOption level, not overridden at
    OperatorDef::DeviceOption level).

    This sets the CurrentDeviceScope() to None for the duration of the
    ``with`` block, so that the field is not set in CreateOperator(...),
    etc. The previous device scope is put back on exit.
    """
    old_scope = CurrentDeviceScope()
    try:
        _threadlocal_scope.devicescope = None
        yield
    finally:
        _threadlocal_scope.devicescope = old_scope