Caffe2 - Python API
A deep learning, cross-platform ML framework
scope.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package scope
17 # Module caffe2.python.scope
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import contextlib
24 import threading
25 from past.builtins import basestring
26 
27 from caffe2.proto import caffe2_pb2
28 
29 
# State consulted when a new operator is created: the current name scope
# and the current device scope.

# Separator appended after every non-empty name-scope prefix.
_NAMESCOPE_SEPARATOR = '/'

# Per-thread holder for the `namescope` and `devicescope` attributes; the
# attributes are created lazily by the accessor functions below.
_threadlocal_scope = threading.local()
34 
35 
def CurrentNameScope():
    """
    Return the name scope stored for the calling thread.

    The thread-local `namescope` attribute is created as an empty string
    the first time it is requested on a thread.
    """
    try:
        return _threadlocal_scope.namescope
    except AttributeError:
        # First access on this thread: start from the empty (root) scope.
        _threadlocal_scope.namescope = ''
        return _threadlocal_scope.namescope
41 
42 
def CurrentDeviceScope():
    """
    Return the device scope stored for the calling thread.

    The thread-local `devicescope` attribute is created as None the first
    time it is requested on a thread.
    """
    try:
        return _threadlocal_scope.devicescope
    except AttributeError:
        # First access on this thread: no device scope is active yet.
        _threadlocal_scope.devicescope = None
        return _threadlocal_scope.devicescope
48 
49 
@contextlib.contextmanager
def NameScope(prefix, reset=False):
    """
    Context manager that changes the calling thread's name scope for the
    duration of the `with` block and restores the previous value on exit.

    prefix: a string or None. A non-empty prefix has _NAMESCOPE_SEPARATOR
        appended to it; an empty string or None contributes nothing.
    reset: when true, the name scope becomes just the new prefix; otherwise
        the prefix is appended to the name scope that was active on entry.

    Raises AssertionError if `prefix` is neither a string nor None, or if,
    on exit, the name scope no longer ends with the prefix installed here.
    """
    assert prefix is None or isinstance(prefix, basestring), \
        "NameScope takes in a string as its argument."
    saved_scope = CurrentNameScope()
    suffix = prefix + _NAMESCOPE_SEPARATOR if prefix else ''
    # CurrentNameScope() has just returned this thread's value, so
    # `saved_scope` is exactly what the non-reset case extends.
    _threadlocal_scope.namescope = suffix if reset else saved_scope + suffix

    try:
        yield
    finally:
        # Detect code that rewrote the name scope behind our back.
        assert _threadlocal_scope.namescope.endswith(suffix), \
            "The namescope variable is changed from outside NameScope() calls."
        _threadlocal_scope.namescope = saved_scope
68 
69 
@contextlib.contextmanager
def DeviceScope(scope, node_name=None):
    """
    Context manager that installs a device scope for the calling thread
    for the duration of the `with` block and restores the previous one on
    exit.

    scope: a caffe2_pb2.DeviceOption to copy, or a falsy value when only
        `node_name` is being supplied.
    node_name: optional node name; when given it overrides the node_name
        of the copied option.

    A fresh DeviceOption is always built, so the caller's `scope` object
    is never stored or modified. If the resulting option has no node_name
    but the enclosing device scope has one, the enclosing node_name is
    inherited.

    Raises AssertionError if `scope` is truthy but not a DeviceOption, if
    both arguments are falsy, or if, on exit, the current device scope is
    no longer equal to the one installed here.
    """
    device_option = caffe2_pb2.DeviceOption()
    if not scope:
        assert node_name, "At least one argument should be non-null in DeviceScope"
    else:
        assert isinstance(scope, caffe2_pb2.DeviceOption), \
            "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument."
        device_option.CopyFrom(scope)

    # An explicitly supplied node_name wins over whatever was copied.
    if node_name:
        device_option.node_name = node_name

    enclosing_scope = CurrentDeviceScope()
    # A nested scope picks up the outer node_name unless it set its own.
    inherits_node_name = (
        enclosing_scope
        and enclosing_scope.HasField('node_name')
        and not device_option.HasField('node_name')
    )
    if inherits_node_name:
        device_option.node_name = enclosing_scope.node_name

    _threadlocal_scope.devicescope = device_option
    try:
        yield
    finally:
        # Detect code that replaced the device scope behind our back.
        assert _threadlocal_scope.devicescope == device_option, \
            "The device scope is changed from outside DeviceScope() calls."
        _threadlocal_scope.devicescope = enclosing_scope
96 
97 
@contextlib.contextmanager
def EmptyDeviceScope():
    """
    Allow users to 'disable' the device scope behaviour (so it can be
    controlled at a NetDef::DeviceOption level, not overridden at
    OperatorDef::DeviceOption level).

    This sets the CurrentDeviceScope() to None, so that the field is
    not set in CreateOperator(...), etc.

    The device scope that was active on entry is restored on exit, and
    any exception raised inside the `with` body propagates to the caller.
    """
    old_scope = CurrentDeviceScope()
    try:
        _threadlocal_scope.devicescope = None
        yield
    finally:
        # Restore only. There is deliberately no `return` here: returning
        # from a `finally` clause discards an in-flight exception, which
        # would make this context manager silently suppress errors raised
        # by the `with` body.
        _threadlocal_scope.devicescope = old_scope