Caffe2 - Python API
A deep learning, cross-platform ML framework
pooling.py
1 ## @package pooling
2 # Module caffe2.python.helpers.pooling
5 from __future__ import absolute_import
6 from __future__ import division
7 from __future__ import print_function
8 from __future__ import unicode_literals
9 
10 
def max_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW", **kwargs):
    """Add a MaxPool operator to ``model.net``.

    Args:
        model: Object whose ``net`` attribute builds operators; this helper
            calls ``model.net.MaxPool``.
        blob_in: Input blob, passed through as the op's input.
        blob_out: Output blob, passed through as the op's output.
        use_cudnn: If true, the ``engine`` argument is forced to ``'CUDNN'``,
            replacing any ``engine`` supplied through ``kwargs``.
        order: Forwarded as the op's ``order`` argument (presumably the
            tensor layout such as "NCHW" / "NHWC"; not checked here).
        **kwargs: Additional operator arguments, forwarded unchanged.

    Returns:
        Whatever ``model.net.MaxPool`` returns.
    """
    # Assemble the keyword arguments in the same order the op would see
    # them from ``order=order, **kwargs``: ``order`` first, then caller args.
    op_args = {'order': order}
    op_args.update(kwargs)
    if use_cudnn:
        op_args['engine'] = 'CUDNN'
    return model.net.MaxPool(blob_in, blob_out, **op_args)
16 
17 
def average_pool(model, blob_in, blob_out, use_cudnn=False, order="NCHW",
                 **kwargs):
    """Add an AveragePool operator to ``model.net``.

    Args:
        model: Object whose ``net`` attribute builds operators; this helper
            calls ``model.net.AveragePool``.
        blob_in: Input blob, passed through as the op's input.
        blob_out: Output blob, passed through as the op's output.
        use_cudnn: If true, the ``engine`` argument is set to ``'CUDNN'``,
            overriding any ``engine`` given in ``kwargs``.
        order: Forwarded as the op's ``order`` argument (presumably the
            tensor layout; not validated here).
        **kwargs: Additional operator arguments, forwarded unchanged.

    Returns:
        Whatever ``model.net.AveragePool`` returns.
    """
    forwarded = dict(order=order)
    forwarded.update(kwargs)
    if not use_cudnn:
        return model.net.AveragePool(blob_in, blob_out, **forwarded)
    # An explicit CUDNN request wins over a caller-supplied engine.
    forwarded['engine'] = 'CUDNN'
    return model.net.AveragePool(blob_in, blob_out, **forwarded)
29 
30 
def max_pool_with_index(model, blob_in, blob_out, order="NCHW", **kwargs):
    """Add a MaxPoolWithIndex operator to ``model.net``.

    The op is given two outputs: ``blob_out`` and a companion blob named
    ``blob_out + "_index"`` (presumably holding the positions of the maxima).

    Args:
        model: Object whose ``net`` attribute builds operators; this helper
            calls ``model.net.MaxPoolWithIndex``.
        blob_in: Input blob.
        blob_out: Name/reference of the pooled output; must support
            ``+ "_index"`` to derive the index blob.
        order: Forwarded as the op's ``order`` argument.
        **kwargs: Additional operator arguments, forwarded unchanged.

    Returns:
        Element ``[0]`` of what ``model.net.MaxPoolWithIndex`` returns, i.e.
        the entry corresponding to ``blob_out``; the index output is dropped.
    """
    index_blob = blob_out + "_index"
    outputs = [blob_out, index_blob]
    results = model.net.MaxPoolWithIndex(blob_in, outputs, order=order, **kwargs)
    return results[0]