Caffe2 - Python API
A deep learning, cross platform ML framework
cnn.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package cnn
17 # Module caffe2.python.cnn
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import brew
24 from caffe2.python.model_helper import ModelHelper
25 from caffe2.proto import caffe2_pb2
26 import logging
27 
28 
30  """A helper model so we can write CNN models more easily, without having to
31  manually define parameter initializations and operators separately.
32  """
33 
def __init__(self, order="NCHW", name=None,
             use_cudnn=True, cudnn_exhaustive_search=False,
             ws_nbytes_limit=None, init_params=True,
             skip_sparse_optim=False,
             param_model=None):
    """Create a (deprecated) CNN model helper.

    Args:
        order: Storage order, either "NCHW" or "NHWC". Stored on the
            instance and placed in the base class's ``arg_scope``.
        name: Model name; "CNN" is used when None.
        use_cudnn: Stored on the instance and placed in ``arg_scope``.
        cudnn_exhaustive_search: Stored on the instance and placed in
            ``arg_scope``.
        ws_nbytes_limit: Stored on the instance; added to ``arg_scope``
            only when truthy (None or 0 is left out of the scope).
        init_params, skip_sparse_optim, param_model: Forwarded unchanged
            to the base-class constructor.

    Raises:
        ValueError: if ``order`` is neither "NCHW" nor "NHWC".
    """
    # Reject an unknown storage order before logging or constructing the
    # base model, so an invalid order never reaches the base arg_scope.
    if order != "NHWC" and order != "NCHW":
        raise ValueError(
            "Cannot understand the CNN storage order %s." % (order,)
        )

    logging.warning(
        "[====DEPRECATE WARNING====]: you are creating an "
        "object from CNNModelHelper class which will be deprecated soon. "
        "Please use ModelHelper object with brew module. For more "
        "information, please refer to caffe2.ai and python/brew.py, "
        "python/brew_test.py for more information."
    )

    cnn_arg_scope = {
        'order': order,
        'use_cudnn': use_cudnn,
        'cudnn_exhaustive_search': cudnn_exhaustive_search,
    }
    # Only a truthy workspace limit is put into the scope.
    if ws_nbytes_limit:
        cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit
    super(CNNModelHelper, self).__init__(
        skip_sparse_optim=skip_sparse_optim,
        name="CNN" if name is None else name,
        init_params=init_params,
        param_model=param_model,
        arg_scope=cnn_arg_scope,
    )

    # Kept on the instance so the layer wrappers can forward them.
    self.order = order
    self.use_cudnn = use_cudnn
    self.cudnn_exhaustive_search = cudnn_exhaustive_search
    self.ws_nbytes_limit = ws_nbytes_limit
70 
def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
    """Add an image-input op by delegating to ``brew.image_input``.

    The helper's storage ``order`` and the ``use_gpu_transform`` flag are
    passed as keywords; any other keyword arguments are forwarded as-is.
    """
    storage_order = self.order
    return brew.image_input(self, blob_in, blob_out,
                            order=storage_order,
                            use_gpu_transform=use_gpu_transform,
                            **kwargs)
80 
def VideoInput(self, blob_in, blob_out, **kwargs):
    """Add a video-input op by delegating to ``brew.video_input``.

    No helper settings are injected; keyword arguments pass through verbatim.
    """
    video_input = brew.video_input
    return video_input(self, blob_in, blob_out, **kwargs)
88 
def PadImage(self, blob_in, blob_out, **kwargs):
    """Add a PadImage op directly on ``self.net`` and return its result.

    Every sibling wrapper returns what its underlying call produces; this one
    used to drop it. The returned value is presumably the output blob
    reference created by the net (TODO confirm against core.Net).
    """
    # TODO(wyiming): remove this dummy helper later
    return self.net.PadImage(blob_in, blob_out, **kwargs)
92 
def ConvNd(self, *args, **kwargs):
    """Add an N-dimensional convolution through ``brew.conv_nd``.

    The helper's ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` are injected as keyword arguments; passing any of
    them again in ``kwargs`` raises TypeError (duplicate keyword).
    """
    cudnn = self.use_cudnn
    exhaustive = self.cudnn_exhaustive_search
    workspace_limit = self.ws_nbytes_limit
    return brew.conv_nd(
        self,
        *args,
        order=self.order,
        use_cudnn=cudnn,
        cudnn_exhaustive_search=exhaustive,
        ws_nbytes_limit=workspace_limit,
        **kwargs
    )
103 
def Conv(self, *args, **kwargs):
    """Add a 2D convolution through ``brew.conv``.

    The helper's ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` are injected as keyword arguments; passing any of
    them again in ``kwargs`` raises TypeError (duplicate keyword).
    """
    cudnn = self.use_cudnn
    exhaustive = self.cudnn_exhaustive_search
    workspace_limit = self.ws_nbytes_limit
    return brew.conv(
        self,
        *args,
        order=self.order,
        use_cudnn=cudnn,
        cudnn_exhaustive_search=exhaustive,
        ws_nbytes_limit=workspace_limit,
        **kwargs
    )
114 
def ConvTranspose(self, *args, **kwargs):
    """Add a transposed convolution through ``brew.conv_transpose``.

    The helper's ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` are injected as keyword arguments; passing any of
    them again in ``kwargs`` raises TypeError (duplicate keyword).
    """
    cudnn = self.use_cudnn
    exhaustive = self.cudnn_exhaustive_search
    workspace_limit = self.ws_nbytes_limit
    return brew.conv_transpose(
        self,
        *args,
        order=self.order,
        use_cudnn=cudnn,
        cudnn_exhaustive_search=exhaustive,
        ws_nbytes_limit=workspace_limit,
        **kwargs
    )
125 
def GroupConv(self, *args, **kwargs):
    """Add a grouped convolution through ``brew.group_conv``.

    The helper's ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` are injected as keyword arguments; passing any of
    them again in ``kwargs`` raises TypeError (duplicate keyword).
    """
    cudnn = self.use_cudnn
    exhaustive = self.cudnn_exhaustive_search
    workspace_limit = self.ws_nbytes_limit
    return brew.group_conv(
        self,
        *args,
        order=self.order,
        use_cudnn=cudnn,
        cudnn_exhaustive_search=exhaustive,
        ws_nbytes_limit=workspace_limit,
        **kwargs
    )
136 
def GroupConv_Deprecated(self, *args, **kwargs):
    """Add a grouped convolution via ``brew.group_conv_deprecated``.

    The helper's ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` are injected as keyword arguments; passing any of
    them again in ``kwargs`` raises TypeError (duplicate keyword).
    """
    cudnn = self.use_cudnn
    exhaustive = self.cudnn_exhaustive_search
    workspace_limit = self.ws_nbytes_limit
    return brew.group_conv_deprecated(
        self,
        *args,
        order=self.order,
        use_cudnn=cudnn,
        cudnn_exhaustive_search=exhaustive,
        ws_nbytes_limit=workspace_limit,
        **kwargs
    )
147 
def FC(self, *args, **kwargs):
    """Add a fully-connected layer; plain pass-through to ``brew.fc``."""
    layer = brew.fc(self, *args, **kwargs)
    return layer
150 
def PackedFC(self, *args, **kwargs):
    """Add a packed fully-connected layer; pass-through to ``brew.packed_fc``."""
    layer = brew.packed_fc(self, *args, **kwargs)
    return layer
153 
def FC_Prune(self, *args, **kwargs):
    """Add a prunable fully-connected layer; pass-through to ``brew.fc_prune``."""
    layer = brew.fc_prune(self, *args, **kwargs)
    return layer
156 
def FC_Decomp(self, *args, **kwargs):
    """Add a decomposed fully-connected layer; pass-through to ``brew.fc_decomp``."""
    layer = brew.fc_decomp(self, *args, **kwargs)
    return layer
159 
def FC_Sparse(self, *args, **kwargs):
    """Add a sparse fully-connected layer; pass-through to ``brew.fc_sparse``."""
    layer = brew.fc_sparse(self, *args, **kwargs)
    return layer
162 
def Dropout(self, *args, **kwargs):
    """Add a dropout layer via ``brew.dropout``.

    Injects the helper's ``order`` and ``use_cudnn`` as keyword arguments;
    everything else is forwarded unchanged.
    """
    storage_order, cudnn = self.order, self.use_cudnn
    return brew.dropout(self, *args, use_cudnn=cudnn, order=storage_order,
                        **kwargs)
167 
def LRN(self, *args, **kwargs):
    """Add a local response normalization layer via ``brew.lrn``.

    Injects the helper's ``order`` and ``use_cudnn`` as keyword arguments;
    everything else is forwarded unchanged.
    """
    storage_order, cudnn = self.order, self.use_cudnn
    return brew.lrn(self, *args, use_cudnn=cudnn, order=storage_order,
                    **kwargs)
172 
def Softmax(self, *args, **kwargs):
    """Add a softmax via ``brew.softmax``, injecting only ``use_cudnn``."""
    cudnn = self.use_cudnn
    return brew.softmax(self, *args, use_cudnn=cudnn, **kwargs)
175 
def SpatialBN(self, *args, **kwargs):
    """Add spatial batch normalization via ``brew.spatial_bn``, injecting ``order``."""
    storage_order = self.order
    return brew.spatial_bn(self, *args, order=storage_order, **kwargs)
178 
def InstanceNorm(self, *args, **kwargs):
    """Add instance normalization via ``brew.instance_norm``, injecting ``order``."""
    storage_order = self.order
    return brew.instance_norm(self, *args, order=storage_order, **kwargs)
181 
def Relu(self, *args, **kwargs):
    """Add a ReLU via ``brew.relu``.

    Injects the helper's ``order`` and ``use_cudnn`` as keyword arguments;
    everything else is forwarded unchanged.
    """
    storage_order, cudnn = self.order, self.use_cudnn
    return brew.relu(self, *args, use_cudnn=cudnn, order=storage_order,
                     **kwargs)
186 
def PRelu(self, *args, **kwargs):
    """Add a parametric ReLU; plain pass-through to ``brew.prelu``."""
    layer = brew.prelu(self, *args, **kwargs)
    return layer
189 
def Concat(self, *args, **kwargs):
    """Add a concatenation via ``brew.concat``, injecting the helper's ``order``."""
    storage_order = self.order
    return brew.concat(self, *args, order=storage_order, **kwargs)
192 
def DepthConcat(self, *args, **kwargs):
    """The old depth concat function - we should move to use concat.

    Deprecated alias of ``Concat``. It logs a deprecation warning and
    forwards every argument unchanged to ``self.Concat``, returning its
    result.
    """
    # Use logging (as the constructor does for its deprecation notice) rather
    # than print(), so the notice follows logging configuration instead of
    # always writing to stdout.
    logging.warning("DepthConcat is deprecated. use Concat instead.")
    return self.Concat(*args, **kwargs)
197 
def Sum(self, *args, **kwargs):
    """Add an elementwise sum; plain pass-through to ``brew.sum``."""
    result = brew.sum(self, *args, **kwargs)
    return result
200 
def Transpose(self, *args, **kwargs):
    """Add a transpose via ``brew.transpose``, injecting only ``use_cudnn``."""
    cudnn = self.use_cudnn
    return brew.transpose(self, *args, use_cudnn=cudnn, **kwargs)
203 
def Iter(self, *args, **kwargs):
    """Add an iteration counter; plain pass-through to ``brew.iter``."""
    result = brew.iter(self, *args, **kwargs)
    return result
206 
def Accuracy(self, *args, **kwargs):
    """Add an accuracy op; plain pass-through to ``brew.accuracy``."""
    result = brew.accuracy(self, *args, **kwargs)
    return result
209 
def MaxPool(self, *args, **kwargs):
    """Add max pooling via ``brew.max_pool``.

    Injects the helper's ``use_cudnn`` and ``order`` as keyword arguments;
    everything else is forwarded unchanged.
    """
    cudnn, storage_order = self.use_cudnn, self.order
    return brew.max_pool(self, *args, order=storage_order, use_cudnn=cudnn,
                         **kwargs)
214 
def MaxPoolWithIndex(self, *args, **kwargs):
    """Add max pooling with indices via ``brew.max_pool_with_index``, injecting ``order``."""
    storage_order = self.order
    return brew.max_pool_with_index(self, *args, order=storage_order, **kwargs)
217 
def AveragePool(self, *args, **kwargs):
    """Add average pooling via ``brew.average_pool``.

    Injects the helper's ``use_cudnn`` and ``order`` as keyword arguments;
    everything else is forwarded unchanged.
    """
    cudnn, storage_order = self.use_cudnn, self.order
    return brew.average_pool(self, *args, order=storage_order,
                             use_cudnn=cudnn, **kwargs)
222 
@property
def XavierInit(self):
    """Return ``('XavierFill', {})``: a fill-op name plus an empty argument dict.

    Presumably consumed as an initializer spec by layer helpers; a fresh
    dict is built on every access.
    """
    fill_args = {}
    return 'XavierFill', fill_args
226 
def ConstantInit(self, value):
    """Return ``('ConstantFill', {'value': value})`` for the given constant."""
    fill_args = dict(value=value)
    return 'ConstantFill', fill_args
229 
@property
def MSRAInit(self):
    """Return ``('MSRAFill', {})``: a fill-op name plus an empty argument dict."""
    fill_args = {}
    return 'MSRAFill', fill_args
233 
@property
def ZeroInit(self):
    """Return ``('ConstantFill', {})``: ConstantFill with no explicit value.

    Relies on ConstantFill's own default value, presumably 0 (hence the
    name); TODO confirm against the operator schema.
    """
    fill_args = {}
    return 'ConstantFill', fill_args
237 
def AddWeightDecay(self, weight_decay):
    """Apply weight decay to this model; delegates to ``brew.add_weight_decay``."""
    add_decay = brew.add_weight_decay
    return add_decay(self, weight_decay)
240 
@property
def CPU(self):
    """Build a fresh ``caffe2_pb2.DeviceOption`` whose device type is CPU."""
    option = caffe2_pb2.DeviceOption()
    option.device_type = caffe2_pb2.CPU
    return option
246 
@property
def GPU(self, gpu_id=0):
    """Build a fresh ``caffe2_pb2.DeviceOption`` targeting CUDA device ``gpu_id``.

    NOTE(review): because this is a property, attribute access
    (``model.GPU``) can never pass ``gpu_id``, so the device id is
    effectively always 0. The parameter is kept only for signature
    compatibility.
    """
    option = caffe2_pb2.DeviceOption()
    option.device_type, option.cuda_gpu_id = caffe2_pb2.CUDA, gpu_id
    return option
def DepthConcat(self, args, kwargs)
Definition: cnn.py:193
def Concat(self, args, kwargs)
Definition: cnn.py:190