Caffe2 - Python API
A deep learning, cross-platform ML framework
modifier_context.py
1 # @package modifier_context
2 # Module caffe2.python.modifier_context
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 
# Key under which UseModifierBase stores a single modifier that was passed
# on its own rather than inside a dict.
DEFAULT_MODIFIER = "DEFAULT"
10 
11 
class ModifierContext(object):
    """
    Provide context so that param_info can use different modifiers.

    Holds a stack of pushed modifier dicts together with a merged view of
    all of them. In the merged view, a name from a later push shadows the
    same name from an earlier push.
    """

    def __init__(self):
        # Merged view: name -> modifier, later pushes taking precedence.
        self._modifiers = {}
        # Every pushed dict, oldest first; the source for rebuilding the view.
        self._modifiers_list = []

    def _rebuild_modifiers(self):
        """Recompute the merged view from the pushed dicts, oldest first."""
        merged = {}
        self._modifiers = merged
        for layer in self._modifiers_list:
            # Newer layers are applied later, so their entries win.
            merged.update(layer)

    def _has_modifier(self, name):
        """Return True if ``name`` is present in the merged view."""
        return name in self._modifiers

    def _get_modifier(self, name):
        """Return the modifier registered as ``name``, or None if absent."""
        return self._modifiers.get(name, None)

    def push_modifiers(self, modifiers):
        """
        Push a dict of modifiers on top of the stack.

        Overriding a name that is already active is allowed; the newly pushed
        value wins until this dict is popped again. The dict is stored by
        reference, so a later rebuild (on pop) sees any mutations made to it
        after pushing.
        """
        self._modifiers_list.append(modifiers)
        self._modifiers.update(modifiers)

    def pop_modifiers(self):
        """
        Drop the most recently pushed dict and rebuild the merged view.

        Raises AssertionError if nothing has been pushed (when assertions are
        enabled; under ``python -O`` the empty ``list.pop`` raises IndexError).
        """
        assert self._modifiers_list
        self._modifiers_list.pop()
        self._rebuild_modifiers()
41 
42 
class UseModifierBase(object):
    '''
    Context manager that pushes modifiers onto the current modifier context
    on entry and pops them again on exit.

    Subclasses must override ``_context_class`` to return the context class
    to use. That class must provide a ``current()`` method returning the
    context instance that receives the ``push_modifiers`` /
    ``pop_modifiers`` calls. ``ModifierContext`` itself defines neither
    ``current()`` nor a public lookup method, so a concrete context subclass
    is needed.

    Example usage with a layer, where ``UseMyModifier`` and
    ``MyModifierContext`` stand for a concrete subclass pair and
    ``get_modifier`` stands for whatever lookup method that context defines:

        modifiers = {'modifier1': modifier1, 'modifier2': modifier2}
        with UseMyModifier(modifiers):
            modifier = MyModifierContext.current().get_modifier('modifier1')
            layer(modifier=modifier)
    '''

    def __init__(self, modifier_or_dict):
        """
        Args:
            modifier_or_dict: either a dict mapping names to modifiers, or a
                single modifier, which is stored under the
                ``DEFAULT_MODIFIER`` key. A dict is kept by reference, not
                copied.
        """
        if isinstance(modifier_or_dict, dict):
            self._modifiers = modifier_or_dict
        else:
            self._modifiers = {DEFAULT_MODIFIER: modifier_or_dict}

    def _context_class(self):
        """Return the context class to push to; subclasses must override."""
        raise NotImplementedError

    def __enter__(self):
        """Push this object's modifiers onto the current context; return self."""
        self._context_class().current().push_modifiers(self._modifiers)
        return self

    def __exit__(self, type, value, traceback):
        """
        Pop the modifiers pushed by ``__enter__``.

        Returns None, so an exception raised inside the ``with`` body is
        propagated rather than suppressed.
        """
        # current() is looked up again rather than reusing the instance from
        # __enter__; NOTE(review): presumably this is the same context object
        # whenever ``with`` blocks nest properly -- depends on current(),
        # which is defined by the concrete context class, not here.
        self._context_class().current().pop_modifiers()