Caffe2 - C++ API
A deep learning, cross platform ML framework
conv_pool_base_op.h
1 #ifndef CAFFE2_IDEEP_OPERATORS_CONV_POOL_BASE_OP_H_
2 #define CAFFE2_IDEEP_OPERATORS_CONV_POOL_BASE_OP_H_
3 
4 #include <vector>
5 
6 #include "caffe2/ideep/ideep_utils.h"
7 #include "caffe2/operators/conv_pool_op_base.h"
8 
9 namespace caffe2 {
10 
11 class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
12  public:
13  IDEEPConvPoolOpBase(const OperatorDef& operator_def, Workspace* ws)
14  : ConvPoolOpBase<IDEEPContext>(operator_def, ws) {
15  OPERATOR_NEEDS_FEATURE(
16  order_ == StorageOrder::NCHW, "Unsupported storage order.");
17  }
18  virtual ~IDEEPConvPoolOpBase() {}
19 
20  inline const ideep::tensor& Input(int index) {
21  return OperatorBase::template Input<ideep::tensor>(index);
22  }
23  inline ideep::tensor* Output(int index) {
24  return OperatorBase::template Output<ideep::tensor>(index);
25  }
26 
27  ideep::tensor::dims pad_tl() const {
28  return {pad_t(), pad_l()};
29  }
30 
31  ideep::tensor::dims pad_br() const {
32  return {pad_b(), pad_r()};
33  }
34 
35  ideep::tensor::dims CalcOutputDims(
36  const ideep::tensor& input,
37  int output_channel) {
38  CAFFE_ENFORCE(input.get_descriptor().get_size() > 0);
39  ideep::tensor::dims output_dims;
40  const auto input_dims = input.get_dims();
41  std::vector<std::int64_t> input_Tdims(
42  input_dims.cbegin(), input_dims.cend());
43  InferOutputSize(
44  input_Tdims,
45  output_channel,
46  order_,
47  global_pooling_,
48  legacy_pad_,
49  dilation_,
50  stride_,
51  &kernel_,
52  &pads_,
53  &output_dims);
54  return output_dims;
55  }
56 
57  bool RunOnDevice() override {
58  if (!global_pooling_) {
59  for (int dim = 0; dim < kernel_.size(); ++dim) {
60  CAFFE_ENFORCE_GT(kernel_[dim], 0);
61  }
62  }
63 
64  try {
65  return RunOnDeviceWithOrderNCHW();
66  } catch (ideep::error& e) {
67  LOG(ERROR) << "IDEEP error:" << e.message;
68  throw;
69  }
70  }
71 };
72 
// Expands USE_OPERATOR_BASE_FUNCTIONS and then adds using-declarations for
// IDEEPConvPoolOpBase::Input and IDEEPConvPoolOpBase::Output, so the
// ideep::tensor-typed accessors are the ones found by name in the class that
// invokes this macro.
#define USE_IDEEP_CONV_POOL_BASE_FUNCTIONS() \
  USE_OPERATOR_BASE_FUNCTIONS;               \
  using IDEEPConvPoolOpBase::Input;          \
  using IDEEPConvPoolOpBase::Output;
77 
78 } // namespace caffe2
79 
80 #endif // CAFFE2_IDEEP_OPERATORS_CONV_POOL_BASE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13