Caffe2 - Python API
A deep learning, cross platform ML framework
modifier_context.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 # @package modifier_context
17 # Module caffe2.python.modifier_context
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 
# Name under which a single (non-dict) modifier is stored when it is wrapped
# into a one-entry modifiers dict.
DEFAULT_MODIFIER = 'DEFAULT'
25 
26 
class ModifierContext(object):
    """
    Provide context to allow param_info to have different modifiers.

    Maintains a stack of modifier dicts plus a merged view of them. A dict
    pushed later overrides earlier entries with the same name; popping it
    restores the merged view of the remaining stack.
    """

    def __init__(self):
        # Merged view of every dict on the stack (later pushes win).
        self._modifiers = {}
        # The individual pushed dicts, in push order.
        self._modifiers_list = []

    def _rebuild_modifiers(self):
        """Recompute the merged view from the stack, oldest dict first."""
        self._modifiers = {}
        for frame in self._modifiers_list:
            self._modifiers.update(frame)

    def _has_modifier(self, name):
        """Return True if ``name`` is present in the merged view."""
        return name in self._modifiers

    def _get_modifier(self, name):
        """Return the modifier registered under ``name``, or None if absent."""
        return self._modifiers.get(name)

    def push_modifiers(self, modifiers):
        """
        Push a dict of modifiers; its entries override existing names.

        A shallow copy is stored. Otherwise, mutating the caller's dict after
        the push would not show up in the merged view until some unrelated
        nested push/pop triggered a rebuild, making lookups inconsistent.

        Args:
            modifiers: mapping of modifier name to modifier (anything
                accepted by ``dict(...)``).
        """
        # modifier override is allowed
        frame = dict(modifiers)
        self._modifiers_list.append(frame)
        self._modifiers.update(frame)

    def pop_modifiers(self):
        """
        Pop the most recently pushed dict and restore the previous view.

        Raises:
            AssertionError: if no modifiers have been pushed. This is raised
                explicitly rather than via ``assert`` so that the check
                still runs under ``python -O``.
        """
        if not self._modifiers_list:
            raise AssertionError(
                'pop_modifiers called with no modifiers pushed')
        self._modifiers_list.pop()
        self._rebuild_modifiers()
56 
57 
class UseModifierBase(object):
    '''
    Context manager that pushes modifiers onto the current modifier context.

    Subclasses implement ``_context_class`` to return a context class.
    ``current()`` is called on that class to find the active context, which
    receives ``push_modifiers`` on enter and ``pop_modifiers`` on exit.

    Example usage with a layer, where ``UseMyModifier`` is a subclass of this
    class and ``MyModifierContext`` is the class its ``_context_class``
    returns (both names are illustrative):

        modifiers = {'modifier1': modifier1, 'modifier2': modifier2}
        with UseMyModifier(modifiers):
            modifier = MyModifierContext.current()._get_modifier('modifier1')
            layer(modifier=modifier)
    '''

    def __init__(self, modifier_or_dict):
        """
        Args:
            modifier_or_dict: either a dict mapping names to modifiers, which
                is used as-is, or a single modifier, which is stored under
                the name DEFAULT_MODIFIER.
        """
        if isinstance(modifier_or_dict, dict):
            self._modifiers = modifier_or_dict
        else:
            self._modifiers = {DEFAULT_MODIFIER: modifier_or_dict}

    def _context_class(self):
        """Return the context class to push onto; subclasses must override."""
        raise NotImplementedError

    def __enter__(self):
        """Push this object's modifiers onto the current context."""
        self._context_class().current().push_modifiers(self._modifiers)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Pop the modifiers pushed by ``__enter__``."""
        # Returns None (falsy), so any exception raised inside the
        # with-block still propagates after the pop.
        self._context_class().current().pop_modifiers()