3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
11 from past.builtins
import basestring
33 'arg_scope': arg_scope,
35 'packed_fc': packed_fc,
36 'fc_decomp': fc_decomp,
37 'fc_sparse': fc_sparse,
41 'average_pool': average_pool,
42 'max_pool_with_index' : max_pool_with_index,
45 'instance_norm': instance_norm,
46 'spatial_bn': spatial_bn,
47 'spatial_gn': spatial_gn,
52 'depth_concat': depth_concat,
54 'transpose': transpose,
59 'conv_transpose': conv_transpose,
60 'group_conv': group_conv,
61 'group_conv_deprecated': group_conv_deprecated,
62 'image_input': image_input,
63 'video_input': video_input,
64 'add_weight_decay': add_weight_decay,
65 'elementwise_linear': elementwise_linear,
66 'layer_norm': layer_norm,
67 'batch_mat_mul' : batch_mat_mul,
70 'db_input' : db_input,
73 def __init__(self, wrapped):
# Constructor for the helper wrapper; receives the object being wrapped.
# NOTE(review): the constructor body is not present in this excerpt --
# presumably it stores `wrapped` on the instance; confirm against the full file.
76 def __getattr__(self, helper_name):
# Attribute-access hook (Python calls __getattr__ only when normal lookup
# fails). Builds `scope_wrapper`, a closure over `helper_name` that invokes
# the named helper with merged keyword arguments.
# NOTE(review): several original lines of this method are missing from this
# excerpt (the registry check, the `func` lookup, the `varkw` guard and the
# final return); comments below describe only the visible lines.
# Error text used for an unregistered helper name (the guarding condition and
# exception type are not visible here).
79 "Helper function {} not " 80 "registered.".format(helper_name)
83 def scope_wrapper(*args, **kwargs):
# Every helper except 'arg_scope' needs a model: taken from args[0] when it is
# a ModelHelper, otherwise from kwargs['model'].
# NOTE(review): the `model = args[0]` assignment and the initial value of
# new_kwargs on the 'arg_scope' path are not visible -- confirm.
85 if helper_name !=
'arg_scope':
86 if len(args) > 0
and isinstance(args[0], ModelHelper):
88 elif 'model' in kwargs:
89 model = kwargs[
'model']
# Message for the case where no model was supplied either way (the `else:` /
# raise around it is not visible here).
92 "The first input of helper function should be model. " \
93 "Or you can provide it in kwargs as model=<your_model>.")
# Start from a deep copy of the model's arg_scope, so the update() calls below
# never mutate model.arg_scope itself.
94 new_kwargs = copy.deepcopy(model.arg_scope)
# Introspect the helper's signature: declared parameter names and whether it
# accepts **kwargs (varkw).
# NOTE(review): inspect.getargspec was deprecated since Python 3.0 and removed
# in Python 3.11. inspect.getfullargspec (7-field result, varkw is its 3rd field)
# or inspect.signature is the replacement. The 4-name unpacking here must be
# adjusted if this is switched.
96 var_names, _, varkw, _= inspect.getargspec(func)
# Keep only arg_scope entries whose key names a declared parameter of the
# helper. Presumably this filter applies only when the helper takes no
# **kwargs (varkw is None); that condition is not visible here -- confirm.
100 var_name: new_kwargs[var_name]
101 for var_name
in var_names
if var_name
in new_kwargs
# Precedence, lowest to highest: model.arg_scope, then the current scope's
# entry for this helper, then the caller's explicit kwargs.
104 cur_scope = get_current_scope()
105 new_kwargs.update(cur_scope.get(helper_name, {}))
106 new_kwargs.update(kwargs)
# Call the helper with the original positional args and the merged kwargs.
107 return func(*args, **new_kwargs)
# Give the wrapper the helper's name (e.g. for repr/debugging).
109 scope_wrapper.__name__ = helper_name
112 def Register(self, helper):
# Register `helper` under its own __name__.
# Raises AttributeError when a helper with that name already exists.
# NOTE(review): the duplicate-name check and the line that stores the helper
# are not visible in this excerpt -- confirm against the full file.
113 name = helper.__name__
115 raise AttributeError(
116 "Helper {} already exists. Please change your " 117 "helper name.".format(name)
121 def has_helper(self, helper_or_helper_name):
123 helper_or_helper_name
124 if isinstance(helper_or_helper_name, basestring)
else 125 helper_or_helper_name.__name__