Caffe2 - Python API
A deep learning, cross-platform ML framework
arg_scope.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 import contextlib
20 import copy
21 import threading
22 
# Per-thread holder for the active argument scope. Each thread sees its own
# ``current_scope`` attribute; it is created lazily by get_current_scope()
# and swapped in/out by arg_scope().
_threadlocal_scope = threading.local()
24 
25 
@contextlib.contextmanager
def arg_scope(single_helper_or_list, **kwargs):
    """Register default keyword arguments for helper functions within a block.

    For every helper, ``kwargs`` are merged into the current thread's scope
    under the key ``helper.__name__``. Keys set here override the same keys
    set by an enclosing ``arg_scope``; other keys are inherited. On exit the
    enclosing scope is restored exactly, including when the body raises.

    How helpers consume these defaults is not shown in this file; presumably
    they read ``get_current_scope()[helper.__name__]`` -- verify at call sites.

    Args:
        single_helper_or_list: A callable, or a list/tuple of callables.
        **kwargs: Keyword arguments to associate with each helper.

    Raises:
        AssertionError: If the argument (or any list element) is not callable.
            The active scope is left untouched in that case.
    """
    if not isinstance(single_helper_or_list, (list, tuple)):
        assert callable(single_helper_or_list), \
            "arg_scope is only supporting single or a list of helper functions."
        single_helper_or_list = [single_helper_or_list]
    old_scope = get_current_scope()
    # Copy each helper's argument dict so the enclosing scope is never
    # mutated. Argument values themselves are shared rather than deep-copied,
    # so callers keep seeing the exact objects they registered.
    new_scope = {key: dict(args) for key, args in old_scope.items()}
    for helper in single_helper_or_list:
        assert callable(helper), \
            "arg_scope is only supporting a list of callable helper functions."
        new_scope.setdefault(helper.__name__, {}).update(kwargs)
    # Install the new scope only after every helper validated, so a failed
    # assertion cannot leave a half-updated scope behind.
    _threadlocal_scope.current_scope = new_scope
    try:
        yield
    finally:
        # Restore the enclosing scope even if the with-body raised.
        _threadlocal_scope.current_scope = old_scope
44 
45 
def get_current_scope():
    """Return the calling thread's argument-scope dictionary.

    The dictionary maps a helper function's ``__name__`` to the dict of
    keyword arguments registered for it by the active ``arg_scope`` blocks.
    A thread that has never asked before gets a fresh, empty dict.

    Returns:
        The live dict stored on the thread-local holder (not a copy), so
        mutating it changes the active scope.
    """
    try:
        return _threadlocal_scope.current_scope
    except AttributeError:
        # First access from this thread: start with an empty scope.
        scope = {}
        _threadlocal_scope.current_scope = scope
        return scope