Caffe2 - Python API
A deep learning, cross platform ML framework
optimizer_context.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package optimizer_context
17 # Module caffe2.python.optimizer_context
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import context
25  ModifierContext, UseModifierBase)
26 
27 
# Name of the default optimizer entry. It matches the 'DEFAULT' key shown in
# the UseOptimizer docstring examples.
DEFAULT_OPTIM = "DEFAULT"
29 
30 
@context.define_context(allow_default=True)
class OptimizerContext(ModifierContext):
    """
    Provide context to allow param_info to have different optimizers.

    Optimizers are looked up by name. The lookups are delegated to
    ``self._has_modifier`` / ``self._get_modifier``. Those are not defined
    here and are presumably provided by ModifierContext.
    """

    def has_optimizer(self, name):
        """Return whether an optimizer is available under ``name``.

        Args:
            name: key of the optimizer to look up.
        """
        return self._has_modifier(name)

    def get_optimizer(self, name):
        """Return the optimizer registered under ``name``.

        Args:
            name: key of the optimizer to look up.

        Raises:
            AssertionError: if no optimizer is available under ``name``.
        """
        # Raise explicitly rather than using an ``assert`` statement, so the
        # check is not stripped when Python runs with -O. AssertionError is
        # kept so existing callers that catch it keep working.
        if not self.has_optimizer(name):
            raise AssertionError(
                "{} optimizer is not provided!".format(name))
        return self._get_modifier(name)
44 
45 
class UseOptimizer(UseModifierBase):
    """
    Context class to allow setting the current optimizer context.

    Example usage with brew:
        - with UseOptimizer(optim):
              brew.func
        - with UseOptimizer({'WEIGHT': weight_optim}):
              brew.func
        - with UseOptimizer({'DEFAULT': optim, 'BIAS': bias_optim,
                             'WEIGHT': weight_optim}):
              brew.func
        - with UseOptimizer(optim1):
              brew.func
              with UseOptimizer(optim2):
                  brew.func

    Example usage with layer:
        optimizers = {'optim1': optim1, 'optim2': optim2}
        with UseOptimizer(optimizers):
            optim = OptimizerContext.current().get_optimizer('optim1')
            layer(optim=optim)
    """

    def _context_class(self):
        """Return the context class this helper operates on.

        Presumably called by UseModifierBase to find which context to set;
        here it is always OptimizerContext.
        """
        return OptimizerContext