Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_op.h
1 
17 #ifndef CAFFE2_OPERATORS_CONV_OP_H_
18 #define CAFFE2_OPERATORS_CONV_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/operators/conv_op_shared.h"
23 #include "caffe2/operators/conv_pool_op_base.h"
24 
25 CAFFE2_DECLARE_bool(caffe2_force_shared_col_buffer);
26 
27 namespace caffe2 {
28 
29 template <typename T, class Context>
30 class ConvOp final : public ConvPoolOpBase<Context> {
31  public:
32  USE_CONV_POOL_BASE_FUNCTIONS(Context);
33  ConvOp(const OperatorDef& operator_def, Workspace* ws)
34  : ConvPoolOpBase<Context>(operator_def, ws) {
35  // Since this is the default convolution implementation, we will
36  // use CAFFE_ENFORCE instead of OPERATOR_NEEDS_FEATURE.
37  CAFFE_ENFORCE(
38  group_ == 1 || order_ == StorageOrder::NCHW,
39  "Group convolution only supports NCHW order right now.");
40 
41  // Create shared buffer mutex in the constructor
42  // to avoid race-condition in DAGNet.
43  if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
44  createSharedBuffer<Context>(ws_);
45  }
46  }
47  ~ConvOp() {}
48 
49  bool RunOnDeviceWithOrderNCHW() override;
50  bool RunOnDeviceWithOrderNHWC() override;
51 
52  private:
53  Tensor<Context> col_buffer_;
54  Tensor<Context> bias_multiplier_;
55  Tensor<Context> img_shape_device_;
56  Tensor<Context> col_buffer_shape_device_;
57  // Input: X, W, b
58  // Output: Y
59  INPUT_TAGS(INPUT, FILTER, BIAS);
60 };
61 
62 template <typename T, class Context>
63 class ConvGradientOp final : public ConvPoolOpBase<Context> {
64  public:
65  USE_CONV_POOL_BASE_FUNCTIONS(Context);
66  ConvGradientOp(const OperatorDef& operator_def, Workspace* ws)
67  : ConvPoolOpBase<Context>(operator_def, ws),
68  no_bias_(OperatorBase::GetSingleArgument<int>("no_bias", 0)) {
69  CAFFE_ENFORCE(
70  !(no_bias_ && OutputSize() == 3),
71  "If bias is not present, you should not have 3 grad output.");
72  CAFFE_ENFORCE(
73  group_ == 1 || order_ == StorageOrder::NCHW,
74  "Group convolution only supports NCHW order right now.");
75  }
76  ~ConvGradientOp() {}
77 
78  bool RunOnDeviceWithOrderNCHW() override;
79  bool RunOnDeviceWithOrderNHWC() override;
80 
81  private:
82  Tensor<Context> col_buffer_;
83  Tensor<Context> bias_multiplier_;
84  Tensor<Context> img_shape_device_;
85  Tensor<Context> col_buffer_shape_device_;
86  bool no_bias_;
87  // input: X, W, dY
88  // output: dW, db, and optionally dX
89  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
90  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
91 };
92 
93 } // namespace caffe2
94 
95 #endif // CAFFE2_OPERATORS_CONV_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.