Caffe2 - C++ API
A deep learning, cross platform ML framework
im2col_op.h
1 
17 #ifndef CAFFE2_OPERATORS_IM2COL_OP_H_
18 #define CAFFE2_OPERATORS_IM2COL_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class Im2ColOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  Im2ColOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  pad_(OperatorBase::GetSingleArgument<int>("pad", 0)),
34  kernel_h_(OperatorBase::GetSingleArgument<int>(
35  "kernel_h",
36  OperatorBase::GetSingleArgument<int>("kernel", 0))),
37  kernel_w_(OperatorBase::GetSingleArgument<int>(
38  "kernel_w",
39  OperatorBase::GetSingleArgument<int>("kernel", 0))),
40  dilation_h_(OperatorBase::GetSingleArgument<int>(
41  "dilation_h",
42  OperatorBase::GetSingleArgument<int>("dilation", 1))),
43  dilation_w_(OperatorBase::GetSingleArgument<int>(
44  "dilation_w",
45  OperatorBase::GetSingleArgument<int>("dilation", 1))),
46  stride_h_(OperatorBase::GetSingleArgument<int>(
47  "stride_h",
48  OperatorBase::GetSingleArgument<int>("stride", 1))),
49  stride_w_(OperatorBase::GetSingleArgument<int>(
50  "stride_w",
51  OperatorBase::GetSingleArgument<int>("stride", 1))),
52  order_(StringToStorageOrder(
53  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
54  CAFFE_ENFORCE(kernel_h_ > 0);
55  CAFFE_ENFORCE(kernel_w_ > 0);
56  CAFFE_ENFORCE(dilation_h_ > 0);
57  CAFFE_ENFORCE(dilation_w_ > 0);
58  CAFFE_ENFORCE(stride_h_ > 0);
59  CAFFE_ENFORCE(stride_w_ > 0);
60  CAFFE_ENFORCE(pad_ >= 0);
61  }
62 
63  bool RunOnDevice() override {
64  auto& X = Input(0);
65  auto* Y = Output(0);
66  CAFFE_ENFORCE(4 == X.ndim());
67 
68  int N = 0, C = 0, H = 0, W = 0;
69  switch (order_) {
70  case StorageOrder::NCHW:
71  N = X.dim32(0);
72  C = X.dim32(1);
73  H = X.dim32(2);
74  W = X.dim32(3);
75  break;
76  case StorageOrder::NHWC:
77  N = X.dim32(0);
78  H = X.dim32(1);
79  W = X.dim32(2);
80  C = X.dim32(3);
81  break;
82  default:
83  CAFFE_THROW("Unknown storage order: ", order_);
84  }
85 
86  const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
87  const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
88  CAFFE_ENFORCE(H >= dkernel_h);
89  CAFFE_ENFORCE(W >= dkernel_w);
90  const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
91  const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
92 
93  switch (order_) {
94  case StorageOrder::NCHW: {
95  Y->Resize(
96  std::vector<TIndex>{N, C * kernel_h_ * kernel_w_, out_h, out_w});
97 
98  const size_t dx = X.size() / N;
99  const size_t dy = Y->size() / N;
100  for (int n = 0; n < N; ++n) {
101  const auto* xdata = X.template data<T>() + (n * dx);
102  auto* ydata = Y->template mutable_data<T>() + (n * dy);
103  math::Im2col<T, Context, StorageOrder::NCHW>(
104  xdata,
105  C,
106  H,
107  W,
108  kernel_h_,
109  kernel_w_,
110  dilation_h_,
111  dilation_w_,
112  pad_,
113  pad_,
114  pad_,
115  pad_,
116  stride_h_,
117  stride_w_,
118  ydata,
119  &context_);
120  }
121  }; break;
122  case StorageOrder::NHWC: {
123  Y->Resize(
124  std::vector<TIndex>{N, out_h, out_w, kernel_h_ * kernel_w_ * C});
125 
126  const size_t dx = X.size() / N;
127  const size_t dy = Y->size() / N;
128  for (int n = 0; n < N; ++n) {
129  const auto* xdata = X.template data<T>() + (n * dx);
130  auto* ydata = Y->template mutable_data<T>() + (n * dy);
131  math::Im2col<T, Context, StorageOrder::NHWC>(
132  xdata,
133  C,
134  H,
135  W,
136  kernel_h_,
137  kernel_w_,
138  dilation_h_,
139  dilation_w_,
140  pad_,
141  pad_,
142  pad_,
143  pad_,
144  stride_h_,
145  stride_w_,
146  ydata,
147  &context_);
148  }
149  }; break;
150  default:
151  CAFFE_THROW("Unknown storage order: ", order_);
152  }
153 
154  return true;
155  }
156 
157  private:
158  int pad_;
159  int kernel_h_;
160  int kernel_w_;
161  int dilation_h_;
162  int dilation_w_;
163  int stride_h_;
164  int stride_w_;
165  StorageOrder order_;
166 };
167 
168 template <typename T, class Context>
169 class Col2ImOp final : public Operator<Context> {
170  public:
171  USE_OPERATOR_CONTEXT_FUNCTIONS;
172  Col2ImOp(const OperatorDef& operator_def, Workspace* ws)
173  : Operator<Context>(operator_def, ws),
174  pad_(OperatorBase::GetSingleArgument<int>("pad", 0)),
175  kernel_h_(OperatorBase::GetSingleArgument<int>(
176  "kernel_h",
177  OperatorBase::GetSingleArgument<int>("kernel", 0))),
178  kernel_w_(OperatorBase::GetSingleArgument<int>(
179  "kernel_w",
180  OperatorBase::GetSingleArgument<int>("kernel", 0))),
181  dilation_h_(OperatorBase::GetSingleArgument<int>(
182  "dilation_h",
183  OperatorBase::GetSingleArgument<int>("dilation", 1))),
184  dilation_w_(OperatorBase::GetSingleArgument<int>(
185  "dilation_w",
186  OperatorBase::GetSingleArgument<int>("dilation", 1))),
187  stride_h_(OperatorBase::GetSingleArgument<int>(
188  "stride_h",
189  OperatorBase::GetSingleArgument<int>("stride", 1))),
190  stride_w_(OperatorBase::GetSingleArgument<int>(
191  "stride_w",
192  OperatorBase::GetSingleArgument<int>("stride", 1))),
193  order_(StringToStorageOrder(
194  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
195  CAFFE_ENFORCE(kernel_h_ > 0);
196  CAFFE_ENFORCE(kernel_w_ > 0);
197  CAFFE_ENFORCE(dilation_h_ > 0);
198  CAFFE_ENFORCE(dilation_w_ > 0);
199  CAFFE_ENFORCE(stride_h_ > 0);
200  CAFFE_ENFORCE(stride_w_ > 0);
201  CAFFE_ENFORCE(pad_ >= 0);
202  }
203 
204  bool RunOnDevice() override {
205  auto& X = Input(0);
206  auto& Z = Input(1);
207  auto* Y = Output(0);
208  Y->ResizeLike(Z);
209  CAFFE_ENFORCE(4 == Y->ndim());
210 
211  int N = 0, C = 0, H = 0, W = 0;
212  switch (order_) {
213  case StorageOrder::NCHW:
214  N = Y->dim32(0);
215  C = Y->dim32(1);
216  H = Y->dim32(2);
217  W = Y->dim32(3);
218  break;
219  case StorageOrder::NHWC:
220  N = Y->dim32(0);
221  H = Y->dim32(1);
222  W = Y->dim32(2);
223  C = Y->dim32(3);
224  break;
225  default:
226  CAFFE_THROW("Unknown storage order: ", order_);
227  }
228 
229  const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
230  const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
231  CAFFE_ENFORCE(H >= dkernel_h);
232  CAFFE_ENFORCE(W >= dkernel_w);
233  const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
234  const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
235  CAFFE_ENFORCE(X.size() == N * kernel_h_ * kernel_w_ * C * out_h * out_w);
236 
237  const size_t dx = X.size() / N;
238  const size_t dy = Y->size() / N;
239 
240  // could template-specialize this, but it's test code...
241  switch (order_) {
242  case StorageOrder::NCHW: {
243  for (int n = 0; n < N; ++n) {
244  const auto* xdata = X.template data<T>() + (n * dx);
245  auto* ydata = Y->template mutable_data<T>() + (n * dy);
246  math::Col2im<T, Context, StorageOrder::NCHW>(
247  xdata,
248  C,
249  H,
250  W,
251  kernel_h_,
252  kernel_w_,
253  dilation_h_,
254  dilation_w_,
255  pad_,
256  pad_,
257  pad_,
258  pad_,
259  stride_h_,
260  stride_w_,
261  ydata,
262  &context_);
263  }
264  }; break;
265  case StorageOrder::NHWC: {
266  for (int n = 0; n < N; ++n) {
267  const auto* xdata = X.template data<T>() + (n * dx);
268  auto* ydata = Y->template mutable_data<T>() + (n * dy);
269  math::Col2im<T, Context, StorageOrder::NHWC>(
270  xdata,
271  C,
272  H,
273  W,
274  kernel_h_,
275  kernel_w_,
276  dilation_h_,
277  dilation_w_,
278  pad_,
279  pad_,
280  pad_,
281  pad_,
282  stride_h_,
283  stride_w_,
284  ydata,
285  &context_);
286  }
287  }; break;
288  default:
289  CAFFE_THROW("Unknown storage order: ", order_);
290  }
291 
292  return true;
293  }
294 
295  private:
296  int pad_;
297  int kernel_h_;
298  int kernel_w_;
299  int dilation_h_;
300  int dilation_w_;
301  int stride_h_;
302  int stride_w_;
303  StorageOrder order_;
304 };
305 
306 } // namespace caffe2
307 
308 #endif // CAFFE2_OPERATORS_IM2COL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.