Caffe2 - Python API
A deep learning, cross-platform ML framework
context.py
1 ## @package context
2 # Module caffe2.python.context
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import threading
9 import six
10 
11 
12 class _ContextInfo(object):
13  def __init__(self, cls, allow_default, arg_name):
14  self.cls = cls
15  self.allow_default = allow_default
16  self.arg_name = arg_name
17  self._local_stack = threading.local()
18 
19  @property
20  def _stack(self):
21  if not hasattr(self._local_stack, 'obj'):
22  self._local_stack.obj = []
23  return self._local_stack.obj
24 
25  def enter(self, value):
26  self._stack.append(value)
27 
28  def exit(self, value):
29  assert len(self._stack) > 0, 'Context %s is empty.' % self.cls
30  assert self._stack.pop() == value
31 
32  def get_active(self, required=True):
33  if len(self._stack) == 0:
34  if not required:
35  return None
36  assert self.allow_default, (
37  'Context %s is required but none is active.' % self.cls)
38  self.enter(self.cls())
39  return self._stack[-1]
40 
41 
42 class _ContextRegistry(object):
43  def __init__(self):
44  self._ctxs = {}
45 
46  def register(self, ctx_info):
47  assert isinstance(ctx_info, _ContextInfo)
48  assert (ctx_info.cls not in self._ctxs), (
49  'Context %s already registered' % ctx_info.cls)
50  self._ctxs[ctx_info.cls] = ctx_info
51 
52  def get(self, cls):
53  assert cls in self._ctxs, 'Context %s not registered.' % cls
54  return self._ctxs[cls]
55 
56 
# Module-level registry; define_context registers each decorated class here.
_CONTEXT_REGISTRY = _ContextRegistry()


def _context_registry():
    """Return the module-level context registry."""
    # Only reading the global, so no `global` declaration is needed.
    return _CONTEXT_REGISTRY
63 
64 
def __enter__(self):
    """Context-manager entry installed on classes by ``define_context``.

    First runs the class's own pre-existing ``__enter__`` (if it had one;
    its return value is discarded), then pushes ``self`` as the innermost
    active context of its context class. Returns ``self``.
    """
    chained_enter = self._prev_enter
    if chained_enter is not None:
        chained_enter()
    info = _context_registry().get(self._ctx_class)
    info.enter(self)
    return self
70 
71 
def __exit__(self, *args):
    """Context-manager exit installed on classes by ``define_context``.

    First pops ``self`` from the active stack of its context class, then
    forwards the exit arguments to the class's own pre-existing ``__exit__``
    if it had one. Always returns None, so exceptions propagate.
    """
    # NOTE(review): the chained __exit__'s return value is ignored, so a
    # class whose own __exit__ returns True to suppress an exception loses
    # that behavior once decorated -- confirm this is intended.
    info = _context_registry().get(self._ctx_class)
    info.exit(self)
    chained_exit = self._prev_exit
    if chained_exit is None:
        return
    chained_exit(*args)
76 
77 
def __call__(self, func):
    """Let a context instance be used as a function decorator.

    Returns a wrapper around ``func`` (with ``func``'s metadata copied via
    ``six.wraps``) that executes every call inside ``with self:``.
    """
    def run_in_context(*args, **kwargs):
        with self:
            return func(*args, **kwargs)
    return six.wraps(func)(run_in_context)
84 
85 
@classmethod
def _current(cls, value=None, required=True):
    """Installed as ``cls.current`` by ``define_context``.

    Returns ``value`` when given (after checking it is a ``cls`` instance);
    otherwise the active ``cls`` context, delegating to
    ``_get_active_context``.
    """
    return _get_active_context(cls, value, required=required)
89 
90 
class define_context(object):
    """Class decorator that turns a class into a registered context.

    Example::

        @define_context(allow_default=True)
        class MyContext(object):
            ...

    The decorated class is registered in the module's context registry and
    gains:
      * ``__enter__``/``__exit__`` that push/pop the instance on a per-thread
        stack, chaining to any ``__enter__``/``__exit__`` the class already
        had;
      * ``__call__`` so an instance can decorate functions;
      * a ``current()`` classmethod returning the active instance.

    Args:
        arg_name: stored on the class's ``_ContextInfo``; not used here.
        allow_default: if true, ``current()`` (with the default
            ``required=True``) instantiates the class with no arguments and
            makes it active when no instance is active.
    """

    def __init__(self, arg_name=None, allow_default=False):
        self.arg_name = arg_name
        self.allow_default = allow_default

    def __call__(self, cls):
        """Decorate ``cls`` in place and return it.

        Raises:
            AssertionError: if ``cls`` (or a base class) already defines a
                context, or if ``cls`` is already registered.
        """
        # Explicit raise instead of `assert`: with the check stripped under
        # `python -O`, a subclass of a context class could be decorated, and
        # its inherited (already patched) __enter__ would be recorded as its
        # own _prev_enter, making __enter__ call itself forever.
        if hasattr(cls, '_ctx_class'):
            raise AssertionError(
                '%s parent class (%s) already defines context.' % (
                    cls, cls._ctx_class))
        cls._ctx_class = cls

        _context_registry().register(
            _ContextInfo(cls, self.allow_default, self.arg_name)
        )

        # Remember pre-existing enter/exit so the installed ones can chain.
        cls._prev_enter = getattr(cls, '__enter__', None)
        cls._prev_exit = getattr(cls, '__exit__', None)

        cls.__enter__ = __enter__
        cls.__exit__ = __exit__
        # NOTE: any __call__ the class defined is replaced, not chained.
        cls.__call__ = __call__
        cls.current = _current

        return cls
115 
116 
def _get_active_context(cls, val=None, required=True):
    """Resolve which ``cls`` context instance to use.

    Returns ``val`` when it is given, asserting that it is an instance of
    ``cls``. Otherwise returns the innermost active ``cls`` context for the
    calling thread, which may be None when ``required`` is false. ``cls``
    must be registered even when ``val`` is supplied.
    """
    # Look the class up first so an unregistered cls fails even when an
    # explicit value is supplied.
    info = _context_registry().get(cls)
    if val is None:
        return info.get_active(required=required)
    assert isinstance(val, cls), (
        'Wrong context type. Expected: %s, got %s.' % (cls, type(val)))
    return val