Caffe2 - Python API
A deep learning, cross-platform ML framework
brew.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package model_helper_api
17 # Module caffe2.python.model_helper_api
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import sys
24 import copy
25 import inspect
26 from past.builtins import basestring
27 from caffe2.python.model_helper import ModelHelper
28 
29 # flake8: noqa
34 from caffe2.python.helpers.conv import *
38 from caffe2.python.helpers.fc import *
42 from caffe2.python.helpers.tools import *
43 from caffe2.python.helpers.train import *
44 
45 
class HelperWrapper(object):
    """Module proxy that exposes registered helper functions as attributes.

    At the bottom of this file the module object in ``sys.modules`` is
    replaced by an instance of this class, so attribute access such as
    ``<module>.fc`` is resolved by ``__getattr__`` below. Only names present
    in ``_registry`` (or methods defined on this class) are reachable; any
    other attribute raises ``AttributeError``.
    """

    # Maps helper name -> helper function. This dict lives on the class, so
    # helpers added through Register() are shared by every instance.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index' : max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'batch_mat_mul' : batch_mat_mul,
        'cond' : cond,
        'loop' : loop,
        'db_input' : db_input,
    }

    def __init__(self, wrapped):
        # Hold a reference to the original module object.
        # NOTE(review): presumably this keeps the module, and therefore the
        # globals that the helpers and scope_wrapper look up, alive after
        # sys.modules no longer points at it. Confirm before removing.
        self.wrapped = wrapped

    @staticmethod
    def _accepted_kwarg_names(func):
        """Return ``(names, varkw)`` describing how ``func`` takes keywords.

        ``names`` lists the positional-or-keyword parameter names of
        ``func`` plus, on Python 3, its keyword-only parameter names.
        ``varkw`` is the name of its ``**kwargs`` parameter, or None when
        ``func`` has no such parameter.

        ``inspect.getargspec`` is deprecated on Python 3. It rejects
        functions with keyword-only parameters or annotations and was
        removed in Python 3.11, so ``getfullargspec`` is used when it exists.
        ``getargspec`` remains the fallback for Python 2.
        """
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is not None:
            spec = getfullargspec(func)
            return list(spec.args) + list(spec.kwonlyargs), spec.varkw
        spec = inspect.getargspec(func)
        return list(spec.args), spec.keywords

    def __getattr__(self, helper_name):
        """Return a callable that invokes helper ``helper_name`` with scoped kwargs.

        The returned wrapper builds the keyword arguments for the helper in
        increasing order of priority:
          1. a deep copy of ``model.arg_scope`` (skipped for ``arg_scope``
             itself). If the helper has no ``**kwargs`` parameter, only the
             entries whose names the helper accepts are kept;
          2. the ``helper_name`` entry of ``get_current_scope()``;
          3. the keyword arguments passed by the caller.
        Positional arguments are forwarded unchanged.

        The model is taken from the first positional argument when that is
        a ``ModelHelper``, otherwise from a ``model=`` keyword argument. If
        neither is present, the wrapper raises ``RuntimeError``.

        Raises:
            AttributeError: if ``helper_name`` is not registered.
        """
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )

        def scope_wrapper(*args, **kwargs):
            new_kwargs = {}
            if helper_name != 'arg_scope':
                if len(args) > 0 and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=<your_model>.")
                # Work on a copy so the update() calls below (and the
                # helper itself) never mutate model.arg_scope.
                new_kwargs = copy.deepcopy(model.arg_scope)
            func = self._registry[helper_name]
            var_names, varkw = self._accepted_kwarg_names(func)
            if varkw is None:
                # This helper does not take arbitrary **kwargs, so drop
                # arg_scope entries it would reject as unexpected keywords.
                new_kwargs = {
                    var_name: new_kwargs[var_name]
                    for var_name in var_names if var_name in new_kwargs
                }

            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        return scope_wrapper

    def Register(self, helper):
        """Add ``helper`` to the registry under ``helper.__name__``.

        Because the registry is stored on the class, the new helper becomes
        reachable as an attribute of every HelperWrapper instance.

        Raises:
            AttributeError: if a helper with the same name is already
                registered.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Return True if a helper is registered under the given name.

        Accepts either a name (any ``basestring``) or an object with a
        ``__name__`` attribute, such as the helper function itself.
        """
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry
142 
143 
# Swap this module's entry in sys.modules for a HelperWrapper proxy. The
# import system hands importers the sys.modules entry, so code that imports
# this module receives the wrapper, and attribute lookups such as
# ``<module>.fc`` go through HelperWrapper.__getattr__. The original module
# object is passed in so the wrapper keeps a reference to it.
sys.modules[__name__] = HelperWrapper(wrapped=sys.modules[__name__])