Caffe2 - C++ API
A deep learning, cross platform ML framework
max_pool_with_index.h
1 
17 #ifndef CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
18 #define CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
19 
20 #include <cfloat>
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/context_gpu.h"
23 #include "caffe2/core/logging.h"
24 #include "caffe2/core/operator.h"
25 #include "caffe2/operators/conv_pool_op_base.h"
26 #include "caffe2/operators/pool_op.h"
27 #include "caffe2/utils/math.h"
28 
29 namespace caffe2 {
30 
31 class MaxPoolWithIndexOp final : public ConvPoolOpBase<CUDAContext> {
32  public:
33  USE_CONV_POOL_BASE_FUNCTIONS(CUDAContext);
34  MaxPoolWithIndexOp(const OperatorDef& operator_def, Workspace* ws)
35  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {}
36  ~MaxPoolWithIndexOp() {}
37 
38  template <typename T>
39  bool DoRunWithType();
40 
41  bool RunOnDevice() override;
42 
43  // Input: X
44  // Output: Y, mask
45 };
46 
47 class MaxPoolWithIndexGradientOp final : public ConvPoolOpBase<CUDAContext> {
48  public:
49  USE_CONV_POOL_BASE_FUNCTIONS(CUDAContext);
50  MaxPoolWithIndexGradientOp(const OperatorDef& operator_def, Workspace* ws)
51  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {}
53 
54  template <typename T>
55  bool DoRunWithType();
56 
57  bool RunOnDevice() override;
58 
59  // Input: X, dY, mask
60  // Output: dX
61 };
62 
63 }; // namespace caffe2
64 
65 #endif // CAFFE2_OPERATORS_MAX_POOL_WITH_INDEX_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.