Caffe2 - Python API
A deep learning, cross-platform ML framework
cnn.py
1 ## @package cnn
2 # Module caffe2.python.cnn
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import brew, workspace
9 from caffe2.python.model_helper import ModelHelper
10 from caffe2.proto import caffe2_pb2
11 import logging
12 
13 
15  """A helper model so we can write CNN models more easily, without having to
16  manually define parameter initializations and operators separately.
17  """
18 
19  def __init__(self, order="NCHW", name=None,
20  use_cudnn=True, cudnn_exhaustive_search=False,
21  ws_nbytes_limit=None, init_params=True,
22  skip_sparse_optim=False,
23  param_model=None):
24  logging.warning(
25  "[====DEPRECATE WARNING====]: you are creating an "
26  "object from CNNModelHelper class which will be deprecated soon. "
27  "Please use ModelHelper object with brew module. For more "
28  "information, please refer to caffe2.ai and python/brew.py, "
29  "python/brew_test.py for more information."
30  )
31 
32  cnn_arg_scope = {
33  'order': order,
34  'use_cudnn': use_cudnn,
35  'cudnn_exhaustive_search': cudnn_exhaustive_search,
36  }
37  if ws_nbytes_limit:
38  cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit
39  super(CNNModelHelper, self).__init__(
40  skip_sparse_optim=skip_sparse_optim,
41  name="CNN" if name is None else name,
42  init_params=init_params,
43  param_model=param_model,
44  arg_scope=cnn_arg_scope,
45  )
46 
47  self.order = order
48  self.use_cudnn = use_cudnn
49  self.cudnn_exhaustive_search = cudnn_exhaustive_search
50  self.ws_nbytes_limit = ws_nbytes_limit
51  if self.order != "NHWC" and self.order != "NCHW":
52  raise ValueError(
53  "Cannot understand the CNN storage order %s." % self.order
54  )
55 
56  def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
57  return brew.image_input(
58  self,
59  blob_in,
60  blob_out,
61  order=self.order,
62  use_gpu_transform=use_gpu_transform,
63  **kwargs
64  )
65 
66  def VideoInput(self, blob_in, blob_out, **kwargs):
67  return brew.video_input(
68  self,
69  blob_in,
70  blob_out,
71  **kwargs
72  )
73 
74  def PadImage(self, blob_in, blob_out, **kwargs):
75  # TODO(wyiming): remove this dummy helper later
76  self.net.PadImage(blob_in, blob_out, **kwargs)
77 
78  def ConvNd(self, *args, **kwargs):
79  return brew.conv_nd(
80  self,
81  *args,
82  use_cudnn=self.use_cudnn,
83  order=self.order,
84  cudnn_exhaustive_search=self.cudnn_exhaustive_search,
85  ws_nbytes_limit=self.ws_nbytes_limit,
86  **kwargs
87  )
88 
89  def Conv(self, *args, **kwargs):
90  return brew.conv(
91  self,
92  *args,
93  use_cudnn=self.use_cudnn,
94  order=self.order,
95  cudnn_exhaustive_search=self.cudnn_exhaustive_search,
96  ws_nbytes_limit=self.ws_nbytes_limit,
97  **kwargs
98  )
99 
100  def ConvTranspose(self, *args, **kwargs):
101  return brew.conv_transpose(
102  self,
103  *args,
104  use_cudnn=self.use_cudnn,
105  order=self.order,
106  cudnn_exhaustive_search=self.cudnn_exhaustive_search,
107  ws_nbytes_limit=self.ws_nbytes_limit,
108  **kwargs
109  )
110 
111  def GroupConv(self, *args, **kwargs):
112  return brew.group_conv(
113  self,
114  *args,
115  use_cudnn=self.use_cudnn,
116  order=self.order,
117  cudnn_exhaustive_search=self.cudnn_exhaustive_search,
118  ws_nbytes_limit=self.ws_nbytes_limit,
119  **kwargs
120  )
121 
122  def GroupConv_Deprecated(self, *args, **kwargs):
123  return brew.group_conv_deprecated(
124  self,
125  *args,
126  use_cudnn=self.use_cudnn,
127  order=self.order,
128  cudnn_exhaustive_search=self.cudnn_exhaustive_search,
129  ws_nbytes_limit=self.ws_nbytes_limit,
130  **kwargs
131  )
132 
133  def FC(self, *args, **kwargs):
134  return brew.fc(self, *args, **kwargs)
135 
136  def PackedFC(self, *args, **kwargs):
137  return brew.packed_fc(self, *args, **kwargs)
138 
139  def FC_Prune(self, *args, **kwargs):
140  return brew.fc_prune(self, *args, **kwargs)
141 
142  def FC_Decomp(self, *args, **kwargs):
143  return brew.fc_decomp(self, *args, **kwargs)
144 
145  def FC_Sparse(self, *args, **kwargs):
146  return brew.fc_sparse(self, *args, **kwargs)
147 
148  def Dropout(self, *args, **kwargs):
149  return brew.dropout(
150  self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
151  )
152 
153  def LRN(self, *args, **kwargs):
154  return brew.lrn(
155  self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
156  )
157 
158  def Softmax(self, *args, **kwargs):
159  return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)
160 
161  def SpatialBN(self, *args, **kwargs):
162  return brew.spatial_bn(self, *args, order=self.order, **kwargs)
163 
164  def SpatialGN(self, *args, **kwargs):
165  return brew.spatial_gn(self, *args, order=self.order, **kwargs)
166 
167  def InstanceNorm(self, *args, **kwargs):
168  return brew.instance_norm(self, *args, order=self.order, **kwargs)
169 
170  def Relu(self, *args, **kwargs):
171  return brew.relu(
172  self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
173  )
174 
175  def PRelu(self, *args, **kwargs):
176  return brew.prelu(self, *args, **kwargs)
177 
178  def Concat(self, *args, **kwargs):
179  return brew.concat(self, *args, order=self.order, **kwargs)
180 
181  def DepthConcat(self, *args, **kwargs):
182  """The old depth concat function - we should move to use concat."""
183  print("DepthConcat is deprecated. use Concat instead.")
184  return self.Concat(*args, **kwargs)
185 
186  def Sum(self, *args, **kwargs):
187  return brew.sum(self, *args, **kwargs)
188 
189  def Transpose(self, *args, **kwargs):
190  return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)
191 
192  def Iter(self, *args, **kwargs):
193  return brew.iter(self, *args, **kwargs)
194 
195  def Accuracy(self, *args, **kwargs):
196  return brew.accuracy(self, *args, **kwargs)
197 
198  def MaxPool(self, *args, **kwargs):
199  return brew.max_pool(
200  self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
201  )
202 
203  def MaxPoolWithIndex(self, *args, **kwargs):
204  return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)
205 
206  def AveragePool(self, *args, **kwargs):
207  return brew.average_pool(
208  self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
209  )
210 
211  @property
212  def XavierInit(self):
213  return ('XavierFill', {})
214 
215  def ConstantInit(self, value):
216  return ('ConstantFill', dict(value=value))
217 
218  @property
219  def MSRAInit(self):
220  return ('MSRAFill', {})
221 
222  @property
223  def ZeroInit(self):
224  return ('ConstantFill', {})
225 
226  def AddWeightDecay(self, weight_decay):
227  return brew.add_weight_decay(self, weight_decay)
228 
229  @property
230  def CPU(self):
231  device_option = caffe2_pb2.DeviceOption()
232  device_option.device_type = caffe2_pb2.CPU
233  return device_option
234 
235  @property
236  def GPU(self, gpu_id=0):
237  device_option = caffe2_pb2.DeviceOption()
238  device_option.device_type = workspace.GpuDeviceType
239  device_option.device_id = gpu_id
240  return device_option
def DepthConcat(self, *args, **kwargs)
Definition: cnn.py:181
def Concat(self, *args, **kwargs)
Definition: cnn.py:178