Caffe2 - Python API
A deep learning, cross-platform ML framework
pooling.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package pooling
17 # Module caffe2.python.helpers.pooling
18 ## @package pooling
19 # Module caffe2.python.helpers.pooling
20 from __future__ import absolute_import
21 from __future__ import division
22 from __future__ import print_function
23 from __future__ import unicode_literals
24 
25 
def max_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW", **kwargs):
    """Add a max pooling operator via ``model.net.MaxPool``.

    Args:
        model: object exposing a ``net`` attribute that provides ``MaxPool``.
        blob_in: input passed through as the first positional argument.
        blob_out: output passed through as the second positional argument.
        use_cudnn: when truthy, the ``engine`` keyword is set to ``'CUDNN'``,
            replacing any ``engine`` value supplied by the caller.
        order: forwarded as the ``order`` keyword (defaults to ``"NCHW"``).
        **kwargs: any further keyword arguments, forwarded unchanged.

    Returns:
        Whatever ``model.net.MaxPool`` returns.
    """
    # An empty override leaves a caller-supplied ``engine`` untouched.
    engine_override = {'engine': 'CUDNN'} if use_cudnn else {}
    kwargs.update(engine_override)
    pool_op = model.net.MaxPool
    return pool_op(blob_in, blob_out, order=order, **kwargs)
31 
32 
def average_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW",
                 **kwargs):
    """Add an average pooling operator via ``model.net.AveragePool``.

    Args:
        model: object exposing a ``net`` attribute that provides
            ``AveragePool``.
        blob_in: input passed through as the first positional argument.
        blob_out: output passed through as the second positional argument.
        use_cudnn: when truthy, the ``engine`` keyword is set to ``'CUDNN'``,
            replacing any ``engine`` value supplied by the caller.
        order: forwarded as the ``order`` keyword (defaults to ``"NCHW"``).
        **kwargs: any further keyword arguments, forwarded unchanged.

    Returns:
        Whatever ``model.net.AveragePool`` returns.
    """
    if use_cudnn:
        # Force the cuDNN engine regardless of what the caller passed.
        kwargs.update(engine='CUDNN')
    pool_op = model.net.AveragePool
    return pool_op(blob_in, blob_out, order=order, **kwargs)
44 
45 
def max_pool_with_index(model, blob_in, blob_out, order="NCHW", **kwargs):
    """Add a max pooling operator that also emits an index output.

    Calls ``model.net.MaxPoolWithIndex`` with two outputs: ``blob_out`` and
    a second one named ``blob_out + "_index"``.

    Args:
        model: object exposing a ``net`` attribute that provides
            ``MaxPoolWithIndex``.
        blob_in: input passed through as the first positional argument.
        blob_out: name of the primary output; must support ``+`` with a
            string so the index output name can be derived from it.
        order: forwarded as the ``order`` keyword (defaults to ``"NCHW"``).
        **kwargs: any further keyword arguments, forwarded unchanged.

    Returns:
        The first element of the value returned by
        ``model.net.MaxPoolWithIndex``.
    """
    pool_op = model.net.MaxPoolWithIndex
    index_blob = blob_out + "_index"
    outputs = pool_op(blob_in, [blob_out, index_blob], order=order, **kwargs)
    # Only the first output is handed back to the caller.
    return outputs[0]