Caffe2 - Python API
A deep learning, cross-platform ML framework
brew.py
1 ## @package model_helper_api
2 # Module caffe2.python.model_helper_api
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import sys
9 import copy
10 import inspect
11 from past.builtins import basestring
12 from caffe2.python.model_helper import ModelHelper
13 
14 # flake8: noqa
19 from caffe2.python.helpers.conv import *
23 from caffe2.python.helpers.fc import *
27 from caffe2.python.helpers.tools import *
28 from caffe2.python.helpers.train import *
29 
30 
class HelperWrapper(object):
    """Module proxy that exposes registered helper functions as attributes.

    The module-level statement at the end of this file replaces this module in
    ``sys.modules`` with an instance of this class, so an access such as
    ``brew.fc`` goes through ``__getattr__`` below. That method looks the name
    up in ``_registry`` and returns a wrapper that fills in default keyword
    arguments from the model's ``arg_scope`` and from ``get_current_scope()``.
    """

    # Maps a helper's public name to the function implementing it. This is a
    # class attribute shared by all instances; ``Register`` adds entries to it.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index': max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'spatial_gn': spatial_gn,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'batch_mat_mul': batch_mat_mul,
        'cond': cond,
        'loop': loop,
        'db_input': db_input,
    }

    def __init__(self, wrapped):
        """Wrap ``wrapped``, the original module object."""
        # Hold a reference to the original module. Dunder lookups are
        # forwarded to it, and this reference presumably keeps the module
        # alive once it is no longer in sys.modules.
        self.wrapped = wrapped

    @staticmethod
    def _keyword_params(func):
        """Return ``(names func accepts as keywords, whether it has **kwargs)``.

        Prefers ``inspect.getfullargspec``. ``inspect.getargspec`` raises
        ValueError for functions with annotations or keyword-only parameters
        and was removed in Python 3.11, so it is only used on Python 2.
        """
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is None:  # Python 2
            arg_names, _, varkw, _ = inspect.getargspec(func)
            return list(arg_names), varkw is not None
        spec = getfullargspec(func)
        return list(spec.args) + list(spec.kwonlyargs), spec.varkw is not None

    def __getattr__(self, helper_name):
        """Return a scope-aware wrapper around the helper ``helper_name``.

        Python only calls this for names that normal attribute lookup did not
        find. Dunder names (``__file__``, ``__spec__``, ...) are forwarded to
        the wrapped module, so code that introspects ``sys.modules`` entries
        still sees the module's metadata.

        Raises:
            AttributeError: if ``helper_name`` is a dunder name the wrapped
                module lacks, or is not a registered helper.
        """
        if helper_name.startswith('__') and helper_name.endswith('__'):
            # Module metadata rather than a helper. getattr raises
            # AttributeError itself when the module lacks the attribute.
            return getattr(self.wrapped, helper_name)
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )

        def scope_wrapper(*args, **kwargs):
            new_kwargs = {}
            if helper_name != 'arg_scope':
                # Every helper except arg_scope receives the model, either as
                # the first positional argument or as model=...
                if len(args) > 0 and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=<your_model>.")
                # Deep copy so that the updates below never modify
                # model.arg_scope itself.
                new_kwargs = copy.deepcopy(model.arg_scope)
            func = self._registry[helper_name]
            accepted_names, accepts_var_kw = self._keyword_params(func)
            if not accepts_var_kw:
                # This helper does not take arbitrary **kwargs: keep only the
                # model-level defaults it has a parameter for.
                new_kwargs = {
                    name: new_kwargs[name]
                    for name in accepted_names if name in new_kwargs
                }

            # Precedence, lowest to highest: model.arg_scope, then the entry
            # for this helper in get_current_scope(), then explicit kwargs.
            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        return scope_wrapper

    def Register(self, helper):
        """Register ``helper`` under ``helper.__name__``.

        The entry is added to the shared class-level registry, so the helper
        becomes reachable as an attribute of the wrapper.

        Raises:
            AttributeError: if a helper with the same name is already
                registered.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Return True if a helper is registered under the given name.

        Args:
            helper_or_helper_name: a helper name (string), or an object whose
                ``__name__`` is used as the name.
        """
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry
128 
129 
# Swap this module's sys.modules entry for a HelperWrapper around it. Any
# attribute access on the imported module (e.g. ``brew.fc``) then goes through
# HelperWrapper.__getattr__, which returns the registered, scope-aware helper.
_original_module = sys.modules[__name__]
sys.modules[__name__] = HelperWrapper(_original_module)
del _original_module