1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# This class manages a scope-driven way of sharing parameters across different
# name scopes.
def _resolve_scope_overrides(self, candidate_scope):
    """
    Recursively resolves all scope overrides, i.e. multiple steps of
    override can be chained.

    For example, if one provides following scope overrides:
    {'scope_b': 'scope_a'} and within 'scope_b' - {'shared_child': ''},
    then name 'w' will get resolved to the following blobs depending on the
    current name scope:
      a. 'scope_a' -> 'scope_a/w'
      b. 'scope_b' -> 'scope_a/w'
      c. 'scope_c' -> 'scope_c/w'
      d. 'scope_b/shared_child' -> 'scope_a/w'
      e. 'scope_b/unshared_child' -> 'scope_a/unshared_child/w'

    Args:
        candidate_scope: the name scope to resolve.

    Returns:
        `candidate_scope` itself when no prefix of it is overridden, otherwise
        the (recursively resolved) override target with the un-overridden
        tail of `candidate_scope` appended.
    """
    separator = scope._NAMESCOPE_SEPARATOR
    best_scope = candidate_scope
    best_scope_idx = 0
    sub_scopes = candidate_scope.split(separator)

    # Walk progressively longer prefixes ('a/', 'a/b/', ...); the longest
    # prefix that has an entry in the override map wins.
    cur_scope = ''
    for idx, sub_scope in enumerate(sub_scopes):
        cur_scope = cur_scope + sub_scope + separator
        if cur_scope in self._scope_overrides:
            best_scope = self._scope_overrides[cur_scope]
            best_scope_idx = idx

    if best_scope == candidate_scope:
        return candidate_scope
    # The override target may itself be overridden, hence the recursion; the
    # part of the candidate below the matched prefix is re-attached.
    return (self._resolve_scope_overrides(best_scope) +
            separator.join(sub_scopes[best_scope_idx + 1:]))
def get_parameter_name(self, name):
    """
    Returns the full blob name for parameter `name`, taking the active scope
    overrides into account.

    Args:
        name: parameter name relative to the current name scope.

    Returns:
        The current name scope, after override resolution, concatenated
        with `name`.
    """
    candidate_scope = scope.CurrentNameScope()
    best_scope = self._resolve_scope_overrides(candidate_scope)
    if best_scope != candidate_scope:
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Overwriting scope %s with scope %s",
                    candidate_scope, best_scope)

    return best_scope + name
def add_scope_overrides(self, shared_scopes):
    """
    Pushes `shared_scopes` as a new override layer and merges its entries
    into the active override map; on key collisions the newly added layer
    wins.

    Note: the `shared_scopes` object itself is stored on the layer stack,
    not a copy.
    """
    layers = self._contexts
    layers.append(shared_scopes)
    active_overrides = self._scope_overrides
    active_overrides.update(shared_scopes)
# NOTE(review): lone fragment of a method whose `def` line is not visible
# here. It merges `x` into self._scope_overrides; `x` is presumably one of the
# stacked override dicts kept in self._contexts (confirm). The leading "74"
# looks like a stray line-number artifact from extraction.
74 self._scope_overrides.update(x)
def _normalize_namescope(namescope, separator=None):
    """
    Returns `namescope` with a trailing separator, so that e.g. 'scope_a' and
    'scope_a/' are treated as the same scope.

    Args:
        namescope: name scope string; may be empty.
        separator: scope separator; defaults to scope._NAMESCOPE_SEPARATOR.

    Returns:
        `namescope` unchanged when it is empty or already ends with the
        separator, otherwise `namescope` + separator. Never returns None.
    """
    if separator is None:
        separator = scope._NAMESCOPE_SEPARATOR
    # endswith() (rather than comparing the last character) also handles
    # multi-character separators correctly.
    if namescope and not namescope.endswith(separator):
        return namescope + separator
    return namescope
@contextlib.contextmanager
def ParameterSharing(shared_scopes):
    """
    Helper function for sharing scopes.
    All the parameters within the shared_scopes will be remapped relative to
    the current name scope (scope.CurrentNameScope()).

    I.e. if one calls ParameterSharing with {'scope_b': 'scope_a'}, from the
    scope 'some_global_scope', it'll effectively mean that all parameters from
    'some_global_scope/scope_b' will be shared with the parameters from
    'some_global_scope/scope_a'.

    Args:
        shared_scopes: dict mapping a scope name (relative to the current
            name scope) to the scope whose parameters it should reuse.

    Raises:
        AssertionError: if `shared_scopes` is not a dict, or if an override
            target is equal to or nested inside its own source scope. These
            are `assert` checks, so they are skipped under `python -O`.
    """
    assert isinstance(shared_scopes, dict)

    shared_scope_overrides = {}
    current_scope = scope.CurrentNameScope()
    for k, v in shared_scopes.items():
        # Normalize first so 'scope_a' and 'scope_a/' are equivalent, and so
        # the prefix check compares whole scope components: 'scope_ab/' is a
        # sibling of 'scope_a/', not nested inside it.
        full_k = _normalize_namescope(current_scope + k)
        full_v = _normalize_namescope(current_scope + v)
        assert not full_v.startswith(full_k), (
            "Illegal override for parameter sharing. {} is prefix of {}".
            format(k, v))
        shared_scope_overrides[full_k] = full_v

    # Register before entering `try` so `pop()` only ever undoes a push that
    # actually happened, and always runs even if the body raises.
    parameter_sharing_context.add_scope_overrides(shared_scope_overrides)
    try:
        yield
    finally:
        parameter_sharing_context.pop()
# NOTE(review): stray, incomplete repeat of the _resolve_scope_overrides
# signature (no trailing colon and no body); it is not valid Python on its own
# and looks like an extraction artifact — likely safe to delete.
def _resolve_scope_overrides(self, candidate_scope)