Caffe2 - Python API
A deep learning, cross-platform ML framework
context.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package context
17 # Module caffe2.python.context
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import threading
24 import six
25 
26 
class ContextInfo(object):
    """Bookkeeping for one context class: a per-thread stack of instances.

    Args:
        cls: the context class described; called with no arguments when a
            default instance has to be created.
        allow_default: if True, `get_active` creates and pushes `cls()` when
            nothing is active (and a value is required) instead of failing.
        arg_name: stored as an attribute; not read anywhere in this class.
    """

    def __init__(self, cls, allow_default, arg_name):
        self.cls = cls
        self.allow_default = allow_default
        self.arg_name = arg_name
        # threading.local gives every thread its own independent stack.
        self._local_stack = threading.local()

    @property
    def _stack(self):
        """The calling thread's list of active instances (created lazily)."""
        if not hasattr(self._local_stack, 'obj'):
            self._local_stack.obj = []
        return self._local_stack.obj

    def enter(self, value):
        """Push `value` as the innermost active instance for this thread."""
        self._stack.append(value)

    def exit(self, value):
        """Pop the innermost active instance and check that it equals `value`.

        The top instance is popped even when it does not match.

        Raises:
            AssertionError: if the stack is empty or the popped instance does
                not compare equal to `value`.
        """
        stack = self._stack
        assert len(stack) > 0, 'Context %s is empty.' % self.cls
        # The pop must stay outside the assert: `python -O` strips assert
        # statements entirely, and the stack would then never shrink.
        popped = stack.pop()
        assert popped == value

    def get_active(self, required=True):
        """Return the innermost active instance for the calling thread.

        If nothing is active and `required` is False, return None. Otherwise,
        when `allow_default` is set, construct `cls()`, push it and return
        it. The pushed default stays active for this thread afterwards.

        Raises:
            AssertionError: if nothing is active, `required` is True and
                `allow_default` is False.
        """
        stack = self._stack
        if not stack:
            if not required:
                return None
            assert self.allow_default, (
                'Context %s is required but none is active.' % self.cls)
            self.enter(self.cls())
        return stack[-1]
55 
56 
class ContextManager(object):
    """Registry of ContextInfo objects, keyed by their context class."""

    def __init__(self):
        self._ctxs = dict()

    def register(self, ctx_info):
        """Store `ctx_info` under its `cls`.

        Raises:
            AssertionError: if `ctx_info` is not a ContextInfo, or a
                ContextInfo is already registered for the same class.
        """
        assert isinstance(ctx_info, ContextInfo)
        key = ctx_info.cls
        assert (key not in self._ctxs), (
            'Context %s already registered' % key)
        self._ctxs[key] = ctx_info

    def get(self, cls):
        """Return the ContextInfo registered for `cls`.

        Raises:
            AssertionError: if `cls` has not been registered.
        """
        registry = self._ctxs
        assert cls in registry, 'Context %s not registered.' % cls
        return registry[cls]
70 
71 
# Single module-level registry; define_context registers every decorated
# class here and the hooks below look classes up in it.
_CONTEXT_MANAGER = ContextManager()


def context_manager():
    """Return the module-level ContextManager instance."""
    # Read-only access to a module global: no `global` statement needed.
    return _CONTEXT_MANAGER
78 
79 
def __enter__(self):
    """Context-entry hook that define_context installs on decorated classes.

    Calls the class's original __enter__ first, if it had one; that call's
    return value is discarded. Then pushes `self` onto the active stack of
    its context class. Returns `self`.
    """
    original_enter = self._prev_enter
    if original_enter is not None:
        original_enter()
    info = context_manager().get(self._ctx_class)
    info.enter(self)
    return self
85 
86 
def __exit__(self, *args):
    """Context-exit hook that define_context installs on decorated classes.

    Pops `self` off its context class's active stack. Then, if the class had
    its own __exit__, calls it with the same (exc_type, exc_value, traceback)
    arguments.

    Returns:
        The original __exit__'s return value, or None if there was none.
        Returning it matters: under the context-manager protocol a true value
        suppresses the in-flight exception. Discarding it would silently
        change the wrapped class's behavior.
    """
    context_manager().get(self._ctx_class).exit(self)
    if self._prev_exit is not None:
        return self._prev_exit(*args)
    return None
91 
92 
def __call__(self, func):
    """Let a context instance be used as a function decorator.

    Every call to the returned wrapper runs `func` inside `with self:`.
    """
    @six.wraps(func)
    def _run_inside_context(*args, **kwargs):
        with self:
            return func(*args, **kwargs)

    return _run_inside_context
99 
100 
@classmethod
def current(cls, value=None, required=True):
    """Return the instance of `cls` to use (installed by define_context).

    If `value` is given, it is returned after checking that it is a `cls`.
    Otherwise this returns the innermost active instance, or None when
    nothing is active and `required` is False.
    """
    return get_active_context(cls, val=value, required=required)
104 
105 
class define_context(object):
    """Class decorator that registers a class as a stack-tracked context.

    Args:
        arg_name: passed through to the ContextInfo registered for the class.
        allow_default: passed through to ContextInfo; when True, asking for
            the active instance with nothing active builds a default `cls()`.
    """

    def __init__(self, arg_name=None, allow_default=False):
        self.arg_name = arg_name
        self.allow_default = allow_default

    def __call__(self, cls):
        """Register `cls` and install the context hooks on it; return `cls`.

        Raises:
            AssertionError: if `cls` (or a base class) already carries
                `_ctx_class`, or `cls` is already registered.
        """
        assert not hasattr(cls, '_ctx_class'), (
            '%s parent class (%s) already defines context.' % (
                cls, cls._ctx_class))
        info = ContextInfo(cls, self.allow_default, self.arg_name)
        context_manager().register(info)
        # Remember any pre-existing enter/exit so the installed hooks can
        # chain to them.
        cls._prev_enter = getattr(cls, '__enter__', None)
        cls._prev_exit = getattr(cls, '__exit__', None)
        cls._ctx_class = cls
        hooks = (
            ('__enter__', __enter__),
            ('__exit__', __exit__),
            ('__call__', __call__),
            ('current', current),
        )
        for attr_name, hook in hooks:
            setattr(cls, attr_name, hook)
        return cls
125 
126 
def get_active_context(cls, val=None, required=True):
    """Resolve which instance of context class `cls` should be used.

    `cls` is first looked up in the registry. An explicit `val` then wins,
    after an isinstance check. Otherwise the result comes from the
    registered ContextInfo's `get_active(required=required)`.

    Raises:
        AssertionError: if `cls` is not registered, or `val` is given but is
            not an instance of `cls`.
    """
    ctx_info = context_manager().get(cls)
    if val is None:
        return ctx_info.get_active(required=required)
    assert isinstance(val, cls), (
        'Wrong context type. Expected: %s, got %s.' % (cls, type(val)))
    return val