Caffe2 - C++ API
A deep learning, cross-platform ML framework
locally_connected_op.h
1 #ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_
2 #define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_
3 
4 #include <vector>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/operators/conv_op_shared.h"
9 #include "caffe2/operators/conv_pool_op_base.h"
10 #include "caffe2/operators/locally_connected_op_util.h"
11 
12 namespace caffe2 {
13 
14 template <typename T, class Context>
15 class LocallyConnectedOp final : public ConvPoolOpBase<Context> {
16  public:
17  USE_CONV_POOL_BASE_FUNCTIONS(Context);
18 
19  template <class... Args>
20  explicit LocallyConnectedOp(Args&&... args)
21  : ConvPoolOpBase<Context>(std::forward<Args>(args)...) {
22  // Since this is the default locally connected implementation, we will
23  // use CAFFE_ENFORCE instead of OPERATOR_NEEDS_FEATURE.
24  CAFFE_ENFORCE(
25  group_ == 1 || order_ == StorageOrder::NCHW,
26  "Group locally connected only supports NCHW order right now.");
27  }
28 
29  ~LocallyConnectedOp() = default;
30 
31  bool RunOnDeviceWithOrderNCHW() override;
32  bool RunOnDeviceWithOrderNHWC() override;
33 
34  private:
35  void RunOnDeviceWithOrderNCHWImpl(
36  const lc_op_util::ShapeParams& shape,
37  const T* X_data,
38  const T* filter_data,
39  const T* bias_data,
40  T* Y_data,
41  Tensor* column_buffer,
42  Tensor* column_transposed_buffer,
43  Tensor* output_buffer);
44 
45  void RunOnDeviceWithOrderNHWCImpl(
46  const lc_op_util::ShapeParams& shape,
47  const T* X_data,
48  const T* filter_data,
49  const T* bias_data,
50  T* Y_data,
51  Tensor* column_buffer,
52  Tensor* column_transposed_buffer,
53  Tensor* Y_transposed_buffer);
54 
55  Tensor bias_multiplier_{Context::GetDeviceType()};
56 
57  // Buffer.
58  Tensor column_buffer_{Context::GetDeviceType()};
59  Tensor column_transposed_buffer_{Context::GetDeviceType()};
60  Tensor Y_transposed_buffer_{Context::GetDeviceType()};
61 
62  // Input: X, W, b
63  // Output: Y
64  INPUT_TAGS(INPUT, FILTER, BIAS);
65 };
66 
67 template <typename T, class Context>
68 class LocallyConnectedGradientOp final : public ConvPoolOpBase<Context> {
69  public:
70  USE_CONV_POOL_BASE_FUNCTIONS(Context);
71 
72  template <class... Args>
73  explicit LocallyConnectedGradientOp(Args&&... args)
74  : ConvPoolOpBase<Context>(std::forward<Args>(args)...),
75  OP_SINGLE_ARG(bool, "no_bias", no_bias_, false) {
76  CAFFE_ENFORCE(
77  !(no_bias_ && OutputSize() == 3),
78  "If bias is not present, you should not have 3 grad output.");
79  CAFFE_ENFORCE(
80  group_ == 1 || order_ == StorageOrder::NCHW,
81  "Group locally connected only supports NCHW order right now.");
82  }
83 
84  ~LocallyConnectedGradientOp() = default;
85 
86  bool RunOnDeviceWithOrderNCHW() override;
87  bool RunOnDeviceWithOrderNHWC() override;
88 
89  private:
90  void RunOnDeviceWithOrderNCHWImpl(
91  const lc_op_util::ShapeParams& shape,
92  const T* X_data,
93  const T* filter_data,
94  const T* dY_data,
95  T* dfilter_data,
96  T* dX_data,
97  T* dbias_data,
98  Tensor* column_buffer,
99  Tensor* column_transposed_buffer,
100  Tensor* dY_transposed_buffer);
101 
102  void RunOnDeviceWithOrderNHWCImpl(
103  const lc_op_util::ShapeParams& shape,
104  const T* X_data,
105  const T* filter_data,
106  const T* dY_data,
107  T* dfilter_data,
108  T* dX_data,
109  T* dbias_data,
110  Tensor* column_buffer,
111  Tensor* column_transposed_buffer,
112  Tensor* dY_transposed_buffer);
113 
114  const bool no_bias_;
115 
116  Tensor bias_multiplier_{Context::GetDeviceType()};
117 
118  // Buffer.
119  Tensor column_buffer_{Context::GetDeviceType()};
120  Tensor column_transposed_buffer_{Context::GetDeviceType()};
121  Tensor dY_transposed_buffer_{Context::GetDeviceType()};
122 
123  // input: X, W, dY
124  // output: dW, db, and optionally dX
125  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
126  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
127 };
128 
129 } // namespace caffe2
130 
131 #endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13