3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
10 from caffe2.proto
import caffe2_pb2
15 """A helper model so we can write CNN models more easily, without having to 16 manually define parameter initializations and operators separately. 19 def __init__(self, order="NCHW", name=None,
20 use_cudnn=
True, cudnn_exhaustive_search=
False,
21 ws_nbytes_limit=
None, init_params=
True,
22 skip_sparse_optim=
False,
25 "[====DEPRECATE WARNING====]: you are creating an " 26 "object from CNNModelHelper class which will be deprecated soon. " 27 "Please use ModelHelper object with brew module. For more " 28 "information, please refer to caffe2.ai and python/brew.py, " 29 "python/brew_test.py for more information." 34 'use_cudnn': use_cudnn,
35 'cudnn_exhaustive_search': cudnn_exhaustive_search,
38 cnn_arg_scope[
'ws_nbytes_limit'] = ws_nbytes_limit
39 super(CNNModelHelper, self).__init__(
40 skip_sparse_optim=skip_sparse_optim,
41 name=
"CNN" if name
is None else name,
42 init_params=init_params,
43 param_model=param_model,
44 arg_scope=cnn_arg_scope,
51 if self.
order !=
"NHWC" and self.
order !=
"NCHW":
53 "Cannot understand the CNN storage order %s." % self.
order 56 def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
57 return brew.image_input(
62 use_gpu_transform=use_gpu_transform,
66 def VideoInput(self, blob_in, blob_out, **kwargs):
67 return brew.video_input(
74 def PadImage(self, blob_in, blob_out, **kwargs):
76 self.net.PadImage(blob_in, blob_out, **kwargs)
78 def ConvNd(self, *args, **kwargs):
89 def Conv(self, *args, **kwargs):
100 def ConvTranspose(self, *args, **kwargs):
101 return brew.conv_transpose(
111 def GroupConv(self, *args, **kwargs):
112 return brew.group_conv(
122 def GroupConv_Deprecated(self, *args, **kwargs):
123 return brew.group_conv_deprecated(
def FC(self, *args, **kwargs):
    """Delegate to ``brew.fc`` with this model as the leading argument.

    Every positional and keyword argument is forwarded unchanged, and the
    helper's return value is handed back as-is.
    """
    helper = brew.fc
    return helper(self, *args, **kwargs)
def PackedFC(self, *args, **kwargs):
    """Delegate to ``brew.packed_fc`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    helper = brew.packed_fc
    return helper(self, *args, **kwargs)
def FC_Prune(self, *args, **kwargs):
    """Delegate to ``brew.fc_prune`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    helper = brew.fc_prune
    return helper(self, *args, **kwargs)
def FC_Decomp(self, *args, **kwargs):
    """Delegate to ``brew.fc_decomp`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    helper = brew.fc_decomp
    return helper(self, *args, **kwargs)
def FC_Sparse(self, *args, **kwargs):
    """Delegate to ``brew.fc_sparse`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    helper = brew.fc_sparse
    return helper(self, *args, **kwargs)
148 def Dropout(self, *args, **kwargs):
153 def LRN(self, *args, **kwargs):
def Softmax(self, *args, **kwargs):
    """Delegate to ``brew.softmax`` with this model as the leading argument.

    ``use_cudnn`` defaults to ``self.use_cudnn``. A caller may pass
    ``use_cudnn=...`` explicitly to override it for this call. Previously
    that raised ``TypeError`` because the keyword was supplied twice.
    """
    # setdefault (rather than an explicit keyword next to **kwargs) lets a
    # caller-supplied value win instead of colliding with the model default.
    kwargs.setdefault('use_cudnn', self.use_cudnn)
    return brew.softmax(self, *args, **kwargs)
def SpatialBN(self, *args, **kwargs):
    """Delegate to ``brew.spatial_bn`` with this model as the leading argument.

    ``order`` defaults to ``self.order``. A caller may pass ``order=...``
    explicitly to override it for this call. Previously that raised
    ``TypeError`` because the keyword was supplied twice.
    """
    # Default instead of force, so an explicit caller value does not collide.
    kwargs.setdefault('order', self.order)
    return brew.spatial_bn(self, *args, **kwargs)
def SpatialGN(self, *args, **kwargs):
    """Delegate to ``brew.spatial_gn`` with this model as the leading argument.

    ``order`` defaults to ``self.order``. A caller may pass ``order=...``
    explicitly to override it for this call. Previously that raised
    ``TypeError`` because the keyword was supplied twice.
    """
    # Default instead of force, so an explicit caller value does not collide.
    kwargs.setdefault('order', self.order)
    return brew.spatial_gn(self, *args, **kwargs)
def InstanceNorm(self, *args, **kwargs):
    """Delegate to ``brew.instance_norm`` with this model as the leading argument.

    ``order`` defaults to ``self.order``. A caller may pass ``order=...``
    explicitly to override it for this call. Previously that raised
    ``TypeError`` because the keyword was supplied twice.
    """
    # Default instead of force, so an explicit caller value does not collide.
    kwargs.setdefault('order', self.order)
    return brew.instance_norm(self, *args, **kwargs)
170 def Relu(self, *args, **kwargs):
def PRelu(self, *args, **kwargs):
    """Delegate to ``brew.prelu`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    helper = brew.prelu
    return helper(self, *args, **kwargs)
def Concat(self, *args, **kwargs):
    """Delegate to ``brew.concat`` with this model as the leading argument.

    ``order`` defaults to ``self.order``. A caller may pass ``order=...``
    explicitly to override it for this call. Previously that raised
    ``TypeError`` because the keyword was supplied twice.
    """
    # Default instead of force, so an explicit caller value does not collide.
    kwargs.setdefault('order', self.order)
    return brew.concat(self, *args, **kwargs)
182 """The old depth concat function - we should move to use concat.""" 183 print(
"DepthConcat is deprecated. use Concat instead.")
184 return self.
Concat(*args, **kwargs)
def Sum(self, *args, **kwargs):
    """Delegate to ``brew.sum`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    result = brew.sum(self, *args, **kwargs)
    return result
def Transpose(self, *args, **kwargs):
    """Delegate to ``brew.transpose`` with this model as the leading argument.

    ``use_cudnn`` defaults to ``self.use_cudnn``. A caller may pass
    ``use_cudnn=...`` explicitly to override it for this call. Previously
    that raised ``TypeError`` because the keyword was supplied twice.
    """
    # Default instead of force, so an explicit caller value does not collide.
    kwargs.setdefault('use_cudnn', self.use_cudnn)
    return brew.transpose(self, *args, **kwargs)
def Iter(self, *args, **kwargs):
    """Delegate to ``brew.iter`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    result = brew.iter(self, *args, **kwargs)
    return result
def Accuracy(self, *args, **kwargs):
    """Delegate to ``brew.accuracy`` with this model as the leading argument.

    Arguments are forwarded untouched; the helper's result is returned.
    """
    result = brew.accuracy(self, *args, **kwargs)
    return result
198 def MaxPool(self, *args, **kwargs):
199 return brew.max_pool(
def MaxPoolWithIndex(self, *args, **kwargs):
    """Delegate to ``brew.max_pool_with_index`` with this model first.

    ``order`` defaults to ``self.order``. A caller may pass ``order=...``
    explicitly to override it for this call. Previously that raised
    ``TypeError`` because the keyword was supplied twice.
    """
    # Default instead of force, so an explicit caller value does not collide.
    kwargs.setdefault('order', self.order)
    return brew.max_pool_with_index(self, *args, **kwargs)
206 def AveragePool(self, *args, **kwargs):
207 return brew.average_pool(
def XavierInit(self):
    """Return the ``(fill_op_type, fill_kwargs)`` pair for Xavier filling.

    The kwargs dict is empty and freshly created on every call.
    """
    fill_op, fill_kwargs = 'XavierFill', {}
    return fill_op, fill_kwargs
def ConstantInit(self, value):
    """Return the ``(fill_op_type, fill_kwargs)`` pair for a constant fill.

    ``fill_kwargs`` holds the single entry ``value`` taken from the argument.
    """
    fill_kwargs = {'value': value}
    return 'ConstantFill', fill_kwargs
220 return (
'MSRAFill', {})
224 return (
'ConstantFill', {})
def AddWeightDecay(self, weight_decay):
    """Delegate to ``brew.add_weight_decay`` for this model.

    ``weight_decay`` is passed through positionally and the helper's result
    is returned unchanged.
    """
    apply_decay = brew.add_weight_decay
    return apply_decay(self, weight_decay)
231 device_option = caffe2_pb2.DeviceOption()
232 device_option.device_type = caffe2_pb2.CPU
236 def GPU(self, gpu_id=0):
237 device_option = caffe2_pb2.DeviceOption()
238 device_option.device_type = workspace.GpuDeviceType
239 device_option.device_id = gpu_id
def DepthConcat(self, args, kwargs)
def Concat(self, args, kwargs)