Caffe2 - Python API
A deep learning, cross-platform ML framework
parameter_sharing.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from caffe2.python import scope
7 
8 import contextlib
9 import logging
10 
logger = logging.getLogger(__name__)


class ParameterSharingContext(object):
    """
    Manages scope-driven parameter sharing across different NameScopes.

    Keeps a stack of override layers. Each layer maps a normalized source
    namescope to the namescope whose parameters it should reuse.
    ParameterSharing() pushes a layer on entry and pops it on exit.
    get_parameter_name() resolves the current NameScope through every
    active override.
    """

    def __init__(self):
        # Merged view of all active layers; later layers win on key clashes.
        self._scope_overrides = {}
        # Stack of override layers, one per add_scope_overrides() call.
        self._contexts = []

    def _resolve_scope_overrides(self, candidate_scope):
        """
        Recursively resolves all scope overrides, i.e. multiple steps of
        override can be used.

        The scope is split on the namescope separator. The deepest prefix
        that has an override wins. Its target is itself resolved
        recursively, and the components below the matched prefix are
        re-attached to the result.

        For example, if one provides following scope overrides:
        {'scope_b': 'scope_a'} and within 'scope_b' - {'shared_child': ''},
        then name 'w' will get resolved to the following blobs depending on the
        namescope:
          a. 'scope_a' -> 'scope_a/w'
          b. 'scope_b' -> 'scope_a/w'
          c. 'scope_c' -> 'scope_c/w'
          d. 'scope_b/shared_child' -> 'scope_a/w'
          e. 'scope_b/unshared_child' -> 'scope_a/unshared_child/w'

        Args:
            candidate_scope: namescope string to resolve.

        Returns:
            The resolved namescope string (candidate_scope itself when no
            override applies).
        """
        # NOTE(review): cyclic overrides (e.g. {'a/': 'b/', 'b/': 'a/'})
        # recurse without bound here; nothing in this file rejects them.
        best_scope = candidate_scope
        best_scope_idx = 0
        sub_scopes = candidate_scope.split(scope._NAMESCOPE_SEPARATOR)

        cur_scope = ''
        for idx, sub_scope in enumerate(sub_scopes):
            cur_scope = cur_scope + sub_scope + scope._NAMESCOPE_SEPARATOR
            if cur_scope in self._scope_overrides:
                # Keep overwriting so the deepest matching prefix wins.
                best_scope = self._scope_overrides[cur_scope]
                best_scope_idx = idx
        # Comparing with candidate_scope (instead of tracking "any match")
        # also stops the recursion when an override targets the candidate
        # itself.
        if best_scope == candidate_scope:
            return candidate_scope
        # The override target may itself be overridden; resolve it, then
        # re-attach the components that sat below the matched prefix.
        return (self._resolve_scope_overrides(best_scope) +
                scope._NAMESCOPE_SEPARATOR.join(
                    sub_scopes[best_scope_idx + 1:]))

    def get_parameter_name(self, name):
        """
        Returns the blob name for parameter `name` in the current NameScope
        after applying all active scope overrides.
        """
        candidate_scope = scope.CurrentNameScope()
        best_scope = self._resolve_scope_overrides(candidate_scope)
        if best_scope != candidate_scope:
            logger.info("Overwriting scope %s with scope %s",
                        candidate_scope, best_scope)

        return best_scope + name

    def add_scope_overrides(self, shared_scopes):
        """
        Pushes a new layer of scope overrides onto the stack.

        Args:
            shared_scopes: mapping of source namescope -> target namescope.
                A private copy is stored, so later changes to the caller's
                mapping do not affect this context.
        """
        # pop() rebuilds the merged view from the stored layers. Storing the
        # caller's object would let its later mutations leak in after a pop,
        # while the merged view built below was snapshotted now.
        overrides = dict(shared_scopes)
        self._contexts.append(overrides)
        self._scope_overrides.update(overrides)

    def pop(self):
        """
        Removes the most recently pushed layer and rebuilds the merged
        override mapping from the remaining layers, oldest first.

        Raises:
            AssertionError: if there is no layer to pop.
        """
        assert len(self._contexts) > 0
        self._contexts.pop()
        self._scope_overrides = {}
        for layer in self._contexts:
            self._scope_overrides.update(layer)


# Module-wide instance that ParameterSharing() pushes override layers onto and
# pops them from.
parameter_sharing_context = ParameterSharingContext()
78 
79 
def _normalize_namescope(namescope):
    """
    Ensures a non-empty namescope ends with the namescope separator.

    Empty (falsy) namescopes and ones that already end with the separator
    are returned unchanged.
    """
    if not namescope:
        return namescope
    separator = scope._NAMESCOPE_SEPARATOR
    if namescope[-1] == separator:
        return namescope
    return namescope + separator
85 
86 
@contextlib.contextmanager
def ParameterSharing(shared_scopes):
    """
    Helper function for sharing scopes.

    All the parameters within shared_scopes will be remapped relative to
    the CurrentNameScope() at the time of the call.

    I.e. if one calls ParameterSharing with {'scope_b': 'scope_a'}, from the
    scope 'some_global_scope', it'll effectively mean that all parameters from
    'some_global_scope/scope_b' will be shared with the parameters from
    'some_global_scope/scope_a'.

    Args:
        shared_scopes: dict mapping a namescope (relative to the current
            NameScope) to the namescope whose parameters it should reuse.
            A value of '' refers to the current NameScope itself.

    Raises:
        AssertionError: if shared_scopes is not a dict, or if a target
            namescope is the same as, or nested inside, its source namescope.
    """
    assert isinstance(shared_scopes, dict)

    shared_scope_overrides = {}
    current_scope = scope.CurrentNameScope()
    for k, v in shared_scopes.items():
        # Compare normalized forms so "prefix" means "same or nested
        # namescope": 'scope_b/x' is inside 'scope_b', 'scope_b2' is not.
        assert not _normalize_namescope(v).startswith(
            _normalize_namescope(k)), (
            "Illegal override for parameter sharing. {} is prefix of {}".
            format(k, v))
        k = current_scope + k
        v = current_scope + v
        # Normalize all the scopes, so scope_a and scope_a/ are equivalent
        k = _normalize_namescope(k)
        v = _normalize_namescope(v)
        shared_scope_overrides[k] = v

    # Push before entering the try: if pushing fails there is nothing of ours
    # to pop, and pop() would discard an enclosing context's layer instead.
    parameter_sharing_context.add_scope_overrides(shared_scope_overrides)
    try:
        yield
    finally:
        parameter_sharing_context.pop()