Caffe2 - C++ API
A deep learning, cross-platform ML framework
locally_connected_op.h
1 
17 #ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_
18 #define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_
19 
20 #include <vector>
21 
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/operators/conv_op_shared.h"
25 #include "caffe2/operators/conv_pool_op_base.h"
26 
27 namespace caffe2 {
28 
29 template <typename T, class Context>
30 class LocallyConnectedOp final : public ConvPoolOpBase<Context> {
31  public:
32  USE_CONV_POOL_BASE_FUNCTIONS(Context);
33 
34  LocallyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
35  : ConvPoolOpBase<Context>(operator_def, ws) {
36  // Since this is the default locally connected implementation, we will
37  // use CAFFE_ENFORCE instead of OPERATOR_NEEDS_FEATURE.
38  CAFFE_ENFORCE(
39  group_ == 1 || order_ == StorageOrder::NCHW,
40  "Group locally connected only supports NCHW order right now.");
41  }
42 
43  ~LocallyConnectedOp() = default;
44 
45  bool RunOnDeviceWithOrderNCHW() override;
46  bool RunOnDeviceWithOrderNHWC() override;
47 
48  private:
49  struct ShapeParams {
50  int N;
51  int C;
52  int M;
53  int input_image_size;
54  int output_image_size;
55  int kernel_dim;
56  std::vector<int> input_image_dims;
57  std::vector<int> column_dims;
58  std::vector<int> column_transposed_dims;
59  std::vector<int> Y_transposed_dims;
60  };
61 
62  void RunOnDeviceWithOrderNCHWImpl(
63  const ShapeParams& shape,
64  const T* X_data,
65  const T* filter_data,
66  const T* bias_data,
67  T* Y_data,
68  Tensor<Context>* column_buffer,
69  Tensor<Context>* column_transposed_buffer,
70  Tensor<Context>* output_buffer);
71 
72  void RunOnDeviceWithOrderNHWCImpl(
73  const ShapeParams& shape,
74  const T* X_data,
75  const T* filter_data,
76  const T* bias_data,
77  T* Y_data,
78  Tensor<Context>* column_buffer,
79  Tensor<Context>* column_transposed_buffer,
80  Tensor<Context>* Y_transposed_buffer);
81 
82  void SetColumnBufferShape(
83  const int N,
84  const int C,
85  const int kernel_dim,
86  const std::vector<int>& output_image_dims,
87  std::vector<int>* column_dims,
88  std::vector<int>* column_transposed_dims);
89 
90  void SetYTranposedBufferShape(
91  const std::vector<int>& Y_dims,
92  std::vector<int>* Y_transposed_dims);
93 
94  Tensor<Context> bias_multiplier_;
95 
96  // Buffer.
97  Tensor<Context> column_buffer_;
98  Tensor<Context> column_transposed_buffer_;
99  Tensor<Context> Y_transposed_buffer_;
100 
101  // Dims devices.
102  Tensor<Context> input_dims_device_;
103  Tensor<Context> column_dims_device_;
104  Tensor<Context> column_transposed_dims_device_;
105  Tensor<Context> column_axes_device_;
106  Tensor<Context> Y_dims_device_;
107  Tensor<Context> Y_transposed_dims_device_;
108  Tensor<Context> Y_transposed_axes_device_;
109 
110  // Input: X, W, b
111  // Output: Y
112  INPUT_TAGS(INPUT, FILTER, BIAS);
113 };
114 
115 template <typename T, class Context>
116 class LocallyConnectedGradientOp final : public ConvPoolOpBase<Context> {
117  public:
118  USE_CONV_POOL_BASE_FUNCTIONS(Context);
119 
120  LocallyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
121  : ConvPoolOpBase<Context>(operator_def, ws),
122  no_bias_(OperatorBase::GetSingleArgument<int>("no_bias", 0)) {
123  CAFFE_ENFORCE(
124  !(no_bias_ && OutputSize() == 3),
125  "If bias is not present, you should not have 3 grad output.");
126  CAFFE_ENFORCE(
127  group_ == 1 || order_ == StorageOrder::NCHW,
128  "Group locally connected only supports NCHW order right now.");
129  }
130 
131  ~LocallyConnectedGradientOp() = default;
132 
133  bool RunOnDeviceWithOrderNCHW() override;
134  bool RunOnDeviceWithOrderNHWC() override;
135 
136  private:
137  struct ShapeParams {
138  int N;
139  int C;
140  int M;
141  int input_image_size;
142  int output_image_size;
143  int kernel_dim;
144  std::vector<int> input_image_dims;
145  std::vector<int> column_dims;
146  std::vector<int> column_transposed_dims;
147  std::vector<int> dY_transposed_dims;
148  };
149 
150  void RunOnDeviceWithOrderNCHWImpl(
151  const ShapeParams& shape,
152  const T* X_data,
153  const T* filter_data,
154  const T* dY_data,
155  T* dfilter_data,
156  T* dX_data,
157  T* dbias_data,
158  Tensor<Context>* column_buffer,
159  Tensor<Context>* column_transposed_buffer,
160  Tensor<Context>* dY_transposed_buffer);
161 
162  void RunOnDeviceWithOrderNHWCImpl(
163  const ShapeParams& shape,
164  const T* X_data,
165  const T* filter_data,
166  const T* dY_data,
167  T* dfilter_data,
168  T* dX_data,
169  T* dbias_data,
170  Tensor<Context>* column_buffer,
171  Tensor<Context>* column_transposed_buffer,
172  Tensor<Context>* dY_transposed_buffer);
173 
174  void SetColumnBufferShape(
175  const int N,
176  const int C,
177  const int kernel_dim,
178  const std::vector<int>& output_image_dims,
179  std::vector<int>* column_dims,
180  std::vector<int>* column_transposed_dims);
181 
182  void SetDYTranposedBufferShape(
183  const std::vector<int>& dY_dims,
184  std::vector<int>* dY_transposed_dims);
185 
186  bool no_bias_;
187 
188  Tensor<Context> bias_multiplier_;
189 
190  // Buffer.
191  Tensor<Context> column_buffer_;
192  Tensor<Context> column_transposed_buffer_;
193  Tensor<Context> dY_transposed_buffer_;
194 
195  // Dims devices.
196  Tensor<Context> input_dims_device_;
197  Tensor<Context> column_dims_device_;
198  Tensor<Context> column_transposed_dims_device_;
199  Tensor<Context> column_axes_device_;
200  Tensor<Context> column_transposed_axes_device_;
201  Tensor<Context> dY_dims_device_;
202  Tensor<Context> dY_transposed_dims_device_;
203  Tensor<Context> dY_axes_device_;
204 
205  // input: X, W, dY
206  // output: dW, db, and optionally dX
207  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
208  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
209 };
210 
211 } // namespace caffe2
212 
213 #endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.