Caffe2 - Python API
A deep learning, cross-platform ML framework
arg_scope.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 import contextlib
5 import copy
6 import threading
7 
# Per-thread storage for the active argument scope. threading.local gives each
# thread its own ``current_scope`` attribute (created lazily by
# get_current_scope), so scopes opened in one thread are invisible to others.
_threadlocal_scope = threading.local()


@contextlib.contextmanager
def arg_scope(single_helper_or_list, **kwargs):
    """Temporarily register keyword arguments for one or more helper functions.

    While the ``with`` block is active, ``get_current_scope()[helper.__name__]``
    holds ``kwargs`` merged over whatever enclosing ``arg_scope`` blocks
    registered for the same helper name. When the block exits -- normally or
    because it raised -- the scope that was active on entry is restored.

    Args:
        single_helper_or_list: a callable, or a list/tuple of callables.
            Helpers are keyed by their ``__name__``.
        **kwargs: keyword arguments to record for every listed helper.

    Raises:
        AssertionError: if the argument (or any list/tuple entry) is not
            callable.
    """
    if isinstance(single_helper_or_list, (list, tuple)):
        helpers = list(single_helper_or_list)
    else:
        assert callable(single_helper_or_list), \
            "arg_scope is only supporting single or a list of helper functions."
        helpers = [single_helper_or_list]

    # Validate every entry before touching the scope so a bad entry cannot
    # leave it partially modified.
    for helper in helpers:
        assert callable(helper), \
            "arg_scope is only supporting a list of callable helper functions."

    # Deep copy so the in-place updates below cannot leak into the snapshot
    # that is restored on exit.
    old_scope = copy.deepcopy(get_current_scope())
    current_scope = get_current_scope()
    try:
        for helper in helpers:
            current_scope.setdefault(helper.__name__, {}).update(kwargs)
        yield
    finally:
        # Restore even when the body (or a helper lacking __name__) raised;
        # otherwise these kwargs would stay active for the rest of the thread.
        _threadlocal_scope.current_scope = old_scope
29 
30 
def get_current_scope():
    """Return the calling thread's argument-scope dict, creating it on demand.

    The dict maps helper names to the keyword arguments that ``arg_scope``
    has registered for them. The first call in a thread stores a new empty
    dict on the thread-local object; later calls return that same dict.
    """
    try:
        return _threadlocal_scope.current_scope
    except AttributeError:
        # First access from this thread: nothing has been registered yet.
        _threadlocal_scope.current_scope = fresh_scope = {}
        return fresh_scope