Caffe2 - Python API
A deep learning, cross platform ML framework
parameter_sharing.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20 
21 from caffe2.python import scope
22 
23 import contextlib
24 import logging
25 
26 logger = logging.getLogger(__name__)
27 
28 
# NOTE(review): the `class` statement itself was missing from the listing; the
# name is taken from the instantiation at the bottom of this block and the
# `object` base is assumed (new-style class under Python 2) - confirm.
class ParameterSharingContext(object):
    """
    This class manages scope driven way of parameter sharing across different
    NameScopes.

    Overrides are kept as a stack of dicts (one per add_scope_overrides()
    call) plus a flattened dict that merges them, later entries winning.
    Keys and values are namescope strings; the callers in this module pass
    them normalized so that they end with the namescope separator.
    """

    def __init__(self):
        # Flattened view of all active overrides: overridden scope -> target.
        self._scope_overrides = {}
        # Stack of the override dicts in the order they were added.
        self._contexts = []

    def _resolve_scope_overrides(self, candidate_scope):
        """
        Recursively resolves all scope overrides, i.e multiple steps of
        override can be used.

        For example, if one provides following scope overrides:
        {'scope_b': 'scope_a'} and within 'scope_b' - {'shared_child': ''},
        then name 'w' will get resolved to the following blobs depending on the
        namescope:
          a. 'scope_a' -> 'scope_a/w'
          b. 'scope_b' -> 'scope_a/w'
          c. 'scope_c' -> 'scope_c/w'
          d. 'scope_b/shared_child' -> 'scope_a/w'
          e. 'scope_b/unshared_child' -> 'scope_a/unshared_child/w'

        Args:
            candidate_scope: namescope string to resolve.

        Returns:
            The scope to use instead of candidate_scope, or candidate_scope
            itself when no override applies.

        NOTE(review): a cyclic set of overrides (e.g. 'a/' -> 'b/' together
        with 'b/' -> 'a/') would make this recurse without bound; nothing in
        this class guards against it.
        """
        best_scope = candidate_scope
        best_scope_idx = 0
        sub_scopes = candidate_scope.split(scope._NAMESCOPE_SEPARATOR)

        # Walk the prefixes of candidate_scope from shortest to longest; the
        # last (i.e. longest) prefix that has an override wins.
        cur_scope = ''
        for idx, sub_scope in enumerate(sub_scopes):
            cur_scope = cur_scope + sub_scope + scope._NAMESCOPE_SEPARATOR
            if cur_scope in self._scope_overrides:
                best_scope = self._scope_overrides[cur_scope]
                best_scope_idx = idx

        if best_scope == candidate_scope:
            return candidate_scope

        # The override target may itself be overridden, so resolve it too,
        # then re-attach the sub-scopes that followed the matched prefix.
        return (self._resolve_scope_overrides(best_scope) +
                scope._NAMESCOPE_SEPARATOR.join(
                    sub_scopes[best_scope_idx + 1:]))

    def get_parameter_name(self, name):
        """
        Returns `name` prefixed with the current namescope after all active
        scope overrides have been applied to that namescope.
        """
        candidate_scope = scope.CurrentNameScope()
        best_scope = self._resolve_scope_overrides(candidate_scope)
        if best_scope != candidate_scope:
            logger.info("Overwriting scope %s with scope %s",
                        candidate_scope, best_scope)

        return best_scope + name

    def add_scope_overrides(self, shared_scopes):
        """
        Pushes a dict of overrides (overridden scope -> target scope) and
        merges it into the active overrides.
        """
        self._contexts.append(shared_scopes)
        self._scope_overrides.update(shared_scopes)

    def pop(self):
        """
        Removes the most recently added overrides and rebuilds the active
        overrides from the ones that remain on the stack.

        Raises:
            AssertionError: if there is nothing to pop.
        """
        assert len(self._contexts) > 0
        self._contexts.pop()
        # Rebuild from scratch: the popped dict may have shadowed keys that
        # an earlier context also defines.
        self._scope_overrides = {}
        for x in self._contexts:
            self._scope_overrides.update(x)


# Module-level instance used by ParameterSharing() below.
parameter_sharing_context = ParameterSharingContext()
93 
94 
95 def _normalize_namescope(namescope):
96  if namescope and namescope[-1] != scope._NAMESCOPE_SEPARATOR:
97  return namescope + scope._NAMESCOPE_SEPARATOR
98  else:
99  return namescope
100 
101 
@contextlib.contextmanager
def ParameterSharing(shared_scopes):
    """
    Helper function for sharing scopes.
    All the parameters within the shared_scopes, will be remapped with the
    respect of CurrentNamescope()

    I.e. if one calls ParameterSharing with {'scope_b': 'scope_a'}, from the
    scope 'some_global_scope', it'll effectively mean, that all parameters from
    'some_global_scope/scope_b' will shared with the parameters from
    'some_global_scope/scope_a'

    Args:
        shared_scopes: dict mapping an overridden scope name to the scope
            name whose parameters should be used instead. Both are taken
            relative to the current namescope; 'scope_a' and 'scope_a/' are
            treated as the same scope.

    Raises:
        AssertionError: if shared_scopes is not a dict, or if an overridden
            scope is the same as, or a parent of, its target scope.
    """
    assert isinstance(shared_scopes, dict)

    shared_scope_overrides = {}
    current_scope = scope.CurrentNameScope()
    for k, v in shared_scopes.items():
        # Check legality on the normalized names. Comparing the raw strings
        # would wrongly reject sibling scopes that merely share a textual
        # prefix ('scope_a' vs 'scope_ab') and would miss an override that
        # differs from its target only by a trailing separator.
        assert not _normalize_namescope(v).startswith(
            _normalize_namescope(k)), (
            "Illegal override for parameter sharing. {} is prefix of {}".
            format(k, v))
        # Normalize all the scopes, so scope_a and scope_a/ are equivalent
        overridden = _normalize_namescope(current_scope + k)
        target = _normalize_namescope(current_scope + v)
        shared_scope_overrides[overridden] = target

    # Push outside the try block so that the pop in `finally` only ever
    # removes overrides that were actually pushed here.
    parameter_sharing_context.add_scope_overrides(shared_scope_overrides)
    try:
        yield
    finally:
        parameter_sharing_context.pop()