Caffe2 - C++ API
A deep learning, cross platform ML framework
im2col_op.h
1 #ifndef CAFFE2_OPERATORS_IM2COL_OP_H_
2 #define CAFFE2_OPERATORS_IM2COL_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class Im2ColOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  template <class... Args>
16  explicit Im2ColOp(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...),
18  pad_(this->template GetSingleArgument<int>("pad", 0)),
19  kernel_h_(this->template GetSingleArgument<int>(
20  "kernel_h",
21  this->template GetSingleArgument<int>("kernel", 0))),
22  kernel_w_(this->template GetSingleArgument<int>(
23  "kernel_w",
24  this->template GetSingleArgument<int>("kernel", 0))),
25  dilation_h_(this->template GetSingleArgument<int>(
26  "dilation_h",
27  this->template GetSingleArgument<int>("dilation", 1))),
28  dilation_w_(this->template GetSingleArgument<int>(
29  "dilation_w",
30  this->template GetSingleArgument<int>("dilation", 1))),
31  stride_h_(this->template GetSingleArgument<int>(
32  "stride_h",
33  this->template GetSingleArgument<int>("stride", 1))),
34  stride_w_(this->template GetSingleArgument<int>(
35  "stride_w",
36  this->template GetSingleArgument<int>("stride", 1))),
37  order_(StringToStorageOrder(
38  this->template GetSingleArgument<string>("order", "NCHW"))) {
39  CAFFE_ENFORCE(kernel_h_ > 0);
40  CAFFE_ENFORCE(kernel_w_ > 0);
41  CAFFE_ENFORCE(dilation_h_ > 0);
42  CAFFE_ENFORCE(dilation_w_ > 0);
43  CAFFE_ENFORCE(stride_h_ > 0);
44  CAFFE_ENFORCE(stride_w_ > 0);
45  CAFFE_ENFORCE(pad_ >= 0);
46  }
47 
48  bool RunOnDevice() override {
49  auto& X = Input(0);
50 
51  CAFFE_ENFORCE(4 == X.dim());
52 
53  int N = 0, C = 0, H = 0, W = 0;
54  switch (order_) {
55  case StorageOrder::NCHW:
56  N = X.dim32(0);
57  C = X.dim32(1);
58  H = X.dim32(2);
59  W = X.dim32(3);
60  break;
61  case StorageOrder::NHWC:
62  N = X.dim32(0);
63  H = X.dim32(1);
64  W = X.dim32(2);
65  C = X.dim32(3);
66  break;
67  default:
68  CAFFE_THROW("Unknown storage order: ", order_);
69  }
70 
71  const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
72  const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
73  CAFFE_ENFORCE(H >= dkernel_h);
74  CAFFE_ENFORCE(W >= dkernel_w);
75  const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
76  const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
77 
78  switch (order_) {
79  case StorageOrder::NCHW: {
80  auto* Y = Output(
81  0,
82  std::vector<int64_t>{N, C * kernel_h_ * kernel_w_, out_h, out_w},
83  at::dtype<T>());
84 
85  const size_t dx = X.numel() / N;
86  const size_t dy = Y->numel() / N;
87  for (int n = 0; n < N; ++n) {
88  const auto* xdata = X.template data<T>() + (n * dx);
89  auto* ydata = Y->template mutable_data<T>() + (n * dy);
90  math::Im2Col<T, Context, StorageOrder::NCHW>(
91  C,
92  H,
93  W,
94  kernel_h_,
95  kernel_w_,
96  dilation_h_,
97  dilation_w_,
98  pad_,
99  pad_,
100  pad_,
101  pad_,
102  stride_h_,
103  stride_w_,
104  xdata,
105  ydata,
106  &context_);
107  }
108  }; break;
109  case StorageOrder::NHWC: {
110  auto* Y = Output(
111  0,
112  std::vector<int64_t>{N, out_h, out_w, kernel_h_ * kernel_w_ * C},
113  at::dtype<T>());
114 
115  const size_t dx = X.numel() / N;
116  const size_t dy = Y->numel() / N;
117  for (int n = 0; n < N; ++n) {
118  const auto* xdata = X.template data<T>() + (n * dx);
119  auto* ydata = Y->template mutable_data<T>() + (n * dy);
120  math::Im2Col<T, Context, StorageOrder::NHWC>(
121  C,
122  H,
123  W,
124  kernel_h_,
125  kernel_w_,
126  dilation_h_,
127  dilation_w_,
128  pad_,
129  pad_,
130  pad_,
131  pad_,
132  stride_h_,
133  stride_w_,
134  xdata,
135  ydata,
136  &context_);
137  }
138  }; break;
139  default:
140  CAFFE_THROW("Unknown storage order: ", order_);
141  }
142 
143  return true;
144  }
145 
146  private:
147  int pad_;
148  int kernel_h_;
149  int kernel_w_;
150  int dilation_h_;
151  int dilation_w_;
152  int stride_h_;
153  int stride_w_;
154  StorageOrder order_;
155 };
156 
157 template <typename T, class Context>
158 class Col2ImOp final : public Operator<Context> {
159  public:
160  USE_OPERATOR_CONTEXT_FUNCTIONS;
161  template <class... Args>
162  explicit Col2ImOp(Args&&... args)
163  : Operator<Context>(std::forward<Args>(args)...),
164  pad_(this->template GetSingleArgument<int>("pad", 0)),
165  kernel_h_(this->template GetSingleArgument<int>(
166  "kernel_h",
167  this->template GetSingleArgument<int>("kernel", 0))),
168  kernel_w_(this->template GetSingleArgument<int>(
169  "kernel_w",
170  this->template GetSingleArgument<int>("kernel", 0))),
171  dilation_h_(this->template GetSingleArgument<int>(
172  "dilation_h",
173  this->template GetSingleArgument<int>("dilation", 1))),
174  dilation_w_(this->template GetSingleArgument<int>(
175  "dilation_w",
176  this->template GetSingleArgument<int>("dilation", 1))),
177  stride_h_(this->template GetSingleArgument<int>(
178  "stride_h",
179  this->template GetSingleArgument<int>("stride", 1))),
180  stride_w_(this->template GetSingleArgument<int>(
181  "stride_w",
182  this->template GetSingleArgument<int>("stride", 1))),
183  order_(StringToStorageOrder(
184  this->template GetSingleArgument<string>("order", "NCHW"))) {
185  CAFFE_ENFORCE(kernel_h_ > 0);
186  CAFFE_ENFORCE(kernel_w_ > 0);
187  CAFFE_ENFORCE(dilation_h_ > 0);
188  CAFFE_ENFORCE(dilation_w_ > 0);
189  CAFFE_ENFORCE(stride_h_ > 0);
190  CAFFE_ENFORCE(stride_w_ > 0);
191  CAFFE_ENFORCE(pad_ >= 0);
192  }
193 
194  bool RunOnDevice() override {
195  auto& X = Input(0);
196  auto& Z = Input(1);
197 
198  auto* Y = Output(0, Z.sizes(), at::dtype<T>());
199  CAFFE_ENFORCE(4 == Y->dim());
200 
201  int N = 0, C = 0, H = 0, W = 0;
202  switch (order_) {
203  case StorageOrder::NCHW:
204  N = Y->dim32(0);
205  C = Y->dim32(1);
206  H = Y->dim32(2);
207  W = Y->dim32(3);
208  break;
209  case StorageOrder::NHWC:
210  N = Y->dim32(0);
211  H = Y->dim32(1);
212  W = Y->dim32(2);
213  C = Y->dim32(3);
214  break;
215  default:
216  CAFFE_THROW("Unknown storage order: ", order_);
217  }
218 
219  const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
220  const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
221  CAFFE_ENFORCE(H >= dkernel_h);
222  CAFFE_ENFORCE(W >= dkernel_w);
223  const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
224  const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
225  CAFFE_ENFORCE(X.numel() == N * kernel_h_ * kernel_w_ * C * out_h * out_w);
226 
227  const size_t dx = X.numel() / N;
228  const size_t dy = Y->numel() / N;
229 
230  // could template-specialize this, but it's test code...
231  switch (order_) {
232  case StorageOrder::NCHW: {
233  for (int n = 0; n < N; ++n) {
234  const auto* xdata = X.template data<T>() + (n * dx);
235  auto* ydata = Y->template mutable_data<T>() + (n * dy);
236  math::Col2Im<T, Context, StorageOrder::NCHW>(
237  C,
238  H,
239  W,
240  kernel_h_,
241  kernel_w_,
242  dilation_h_,
243  dilation_w_,
244  pad_,
245  pad_,
246  pad_,
247  pad_,
248  stride_h_,
249  stride_w_,
250  xdata,
251  ydata,
252  &context_);
253  }
254  }; break;
255  case StorageOrder::NHWC: {
256  for (int n = 0; n < N; ++n) {
257  const auto* xdata = X.template data<T>() + (n * dx);
258  auto* ydata = Y->template mutable_data<T>() + (n * dy);
259  math::Col2Im<T, Context, StorageOrder::NHWC>(
260  C,
261  H,
262  W,
263  kernel_h_,
264  kernel_w_,
265  dilation_h_,
266  dilation_w_,
267  pad_,
268  pad_,
269  pad_,
270  pad_,
271  stride_h_,
272  stride_w_,
273  xdata,
274  ydata,
275  &context_);
276  }
277  }; break;
278  default:
279  CAFFE_THROW("Unknown storage order: ", order_);
280  }
281 
282  return true;
283  }
284 
285  private:
286  int pad_;
287  int kernel_h_;
288  int kernel_w_;
289  int dilation_h_;
290  int dilation_w_;
291  int stride_h_;
292  int stride_w_;
293  StorageOrder order_;
294 };
295 
296 } // namespace caffe2
297 
298 #endif // CAFFE2_OPERATORS_IM2COL_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64