Caffe2 - C++ API
A deep learning, cross platform ML framework
max_pool_with_index_gpu.h
1 #pragma once
2 
3 #include <cfloat>
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/context_gpu.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/operators/conv_pool_op_base.h"
9 #include "caffe2/operators/pool_op.h"
10 #include "caffe2/utils/math.h"
11 
12 namespace caffe2 {
13 
14 class MaxPoolWithIndexOp final : public ConvPoolOpBase<CUDAContext> {
15  public:
16  USE_CONV_POOL_BASE_FUNCTIONS(CUDAContext);
17  MaxPoolWithIndexOp(const OperatorDef& operator_def, Workspace* ws)
18  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {}
19  ~MaxPoolWithIndexOp() {}
20 
21  template <typename T>
22  bool DoRunWithType();
23 
24  bool RunOnDevice() override;
25 
26  // Input: X
27  // Output: Y, mask
28 };
29 
30 class MaxPoolWithIndexGradientOp final : public ConvPoolOpBase<CUDAContext> {
31  public:
32  USE_CONV_POOL_BASE_FUNCTIONS(CUDAContext);
33  MaxPoolWithIndexGradientOp(const OperatorDef& operator_def, Workspace* ws)
34  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {}
36 
37  template <typename T>
38  bool DoRunWithType();
39 
40  bool RunOnDevice() override;
41 
42  // Input: X, dY, mask
43  // Output: dX
44 };
45 
46 }; // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13