Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_transpose_unpool_op_base.h
#ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_
#define CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_op_shared.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/proto/caffe2_legacy.pb.h"
#include "caffe2/utils/math.h"

C10_DECLARE_bool(caffe2_force_shared_col_buffer);

namespace caffe2 {

template <class Context>
class ConvTransposeUnpoolBase : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit ConvTransposeUnpoolBase(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        legacy_pad_(
            static_cast<LegacyPadding>(this->template GetSingleArgument<int>(
                "legacy_pad",
                LegacyPadding::NOTSET))),
        kernel_(this->template GetRepeatedArgument<int>("kernels")),
        stride_(this->template GetRepeatedArgument<int>("strides")),
        pads_(this->template GetRepeatedArgument<int>("pads")),
        adj_(this->template GetRepeatedArgument<int>("adjs")),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))),
        shared_buffer_(
            this->template GetSingleArgument<int>("shared_buffer", 0)),
        ws_(ws) {
    // The padding should either use a legacy padding strategy
    // (VALID or SAME), or explicit, non-negative values.
    if (legacy_pad_ == LegacyPadding::VALID ||
        legacy_pad_ == LegacyPadding::SAME) {
      CAFFE_ENFORCE(
          !OperatorBase::HasArgument("pads"),
          "If you use legacy padding VALID or SAME, you should not specify "
          "any specific padding values.");
    }
    // Get values from the legacy scalar arguments.
    if (OperatorBase::HasArgument("kernel")) {
      kernel_.resize(2, this->template GetSingleArgument<int>("kernel", 0));
    } else if (
        OperatorBase::HasArgument("kernel_h") &&
        OperatorBase::HasArgument("kernel_w")) {
      kernel_.push_back(this->template GetSingleArgument<int>("kernel_h", 0));
      kernel_.push_back(this->template GetSingleArgument<int>("kernel_w", 0));
    }

    if (OperatorBase::HasArgument("stride")) {
      stride_.resize(2, this->template GetSingleArgument<int>("stride", 0));
    } else if (
        OperatorBase::HasArgument("stride_h") &&
        OperatorBase::HasArgument("stride_w")) {
      stride_.push_back(this->template GetSingleArgument<int>("stride_h", 0));
      stride_.push_back(this->template GetSingleArgument<int>("stride_w", 0));
    }

    if (OperatorBase::HasArgument("adj")) {
      adj_.resize(2, this->template GetSingleArgument<int>("adj", 0));
    } else if (
        OperatorBase::HasArgument("adj_h") &&
        OperatorBase::HasArgument("adj_w")) {
      adj_.push_back(this->template GetSingleArgument<int>("adj_h", 0));
      adj_.push_back(this->template GetSingleArgument<int>("adj_w", 0));
    }

    if (OperatorBase::HasArgument("pad")) {
      CAFFE_ENFORCE(
          legacy_pad_ != LegacyPadding::VALID &&
              legacy_pad_ != LegacyPadding::SAME,
          "If you use legacy padding VALID or SAME, you should not specify "
          "any specific padding values.");
      pads_.resize(4, this->template GetSingleArgument<int>("pad", 0));
    } else if (
        OperatorBase::HasArgument("pad_t") &&
        OperatorBase::HasArgument("pad_l") &&
        OperatorBase::HasArgument("pad_b") &&
        OperatorBase::HasArgument("pad_r")) {
      CAFFE_ENFORCE(
          legacy_pad_ != LegacyPadding::VALID &&
              legacy_pad_ != LegacyPadding::SAME,
          "If you use legacy padding VALID or SAME, you should not specify "
          "any specific padding values.");
      pads_.push_back(this->template GetSingleArgument<int>("pad_t", 0));
      pads_.push_back(this->template GetSingleArgument<int>("pad_l", 0));
      pads_.push_back(this->template GetSingleArgument<int>("pad_b", 0));
      pads_.push_back(this->template GetSingleArgument<int>("pad_r", 0));
    }
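
    // Illustration (not part of the original source): with the legacy scalar
    // arguments kernel=3, stride=2, pad=1, adj=1, the vectors above become
    //   kernel_ = {3, 3}, stride_ = {2, 2}, pads_ = {1, 1, 1, 1}, adj_ = {1, 1};
    // each scalar is broadcast across both spatial dimensions.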

    // Fill default values.
    if (kernel_.size() == 0) {
      kernel_.assign({0, 0});
    }

    if (stride_.size() == 0) {
      stride_.resize(kernel_.size(), 1);
    }

    if (pads_.size() == 0) {
      pads_.resize(kernel_.size() * 2, 0);
    }

    if (adj_.size() == 0) {
      adj_.resize(kernel_.size(), 0);
    }

    CAFFE_ENFORCE_EQ(stride_.size(), kernel_.size());
    CAFFE_ENFORCE_EQ(adj_.size(), kernel_.size());

    if (legacy_pad_ != LegacyPadding::VALID &&
        legacy_pad_ != LegacyPadding::SAME) {
      CAFFE_ENFORCE_EQ(pads_.size(), 2 * kernel_.size());
    }

    for (int dim = 0; dim < kernel_.size(); ++dim) {
      CAFFE_ENFORCE_GT(kernel_[dim], 0);
      CAFFE_ENFORCE_GT(stride_[dim], 0);
      CAFFE_ENFORCE_GE(adj_[dim], 0);
      CAFFE_ENFORCE_LE(adj_[dim], stride_[dim]);
    }

    // Create the shared buffer mutex in the constructor
    // to avoid a race condition in DAGNet.
    if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
      createSharedBuffer<Context>(ws_);
    }
  }
  // Gets the output size. The output channel count is specified manually.
  std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
    CAFFE_ENFORCE(4 == input.dim());
    CAFFE_ENFORCE(input.numel() > 0);
    int N = input.dim32(0);
    bool channel_first = false; // initialized to suppress compiler warning.
    int H = 0, W = 0; // initialized to suppress compiler warning.
    int M = 0;
    switch (order_) {
      case StorageOrder::NHWC:
        channel_first = false;
        H = input.dim32(1);
        W = input.dim32(2);
        M = input.dim32(3);
        break;
      case StorageOrder::NCHW:
        channel_first = true;
        M = input.dim32(1);
        H = input.dim32(2);
        W = input.dim32(3);
        break;
      default:
        LOG(FATAL) << "Unknown storage order: " << order_;
    }
    int output_height = 0, output_width = 0;
    ComputeSizeAndPad(
        H,
        stride_[0],
        kernel_[0],
        adj_[0],
        &pads_[0],
        &pads_[2],
        &output_height);
    ComputeSizeAndPad(
        W,
        stride_[1],
        kernel_[1],
        adj_[1],
        &pads_[1],
        &pads_[3],
        &output_width);
    std::vector<int64_t> sizes;
    if (channel_first) {
      sizes = {N, output_channel, output_height, output_width};
    } else {
      sizes = {N, output_height, output_width, output_channel};
    }
    VLOG(2) << "In: N " << N << " M " << M << " H " << H << " W " << W;
    VLOG(2) << "Out: output_channel " << output_channel << " H "
            << output_height << " W " << output_width;
    return sizes;
  }
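
  // Worked example (values assumed for illustration): an NCHW input of shape
  // {1, 8, 4, 4} with kernel_ = {2, 2}, stride_ = {2, 2}, zero pads and zero
  // adj gives output_height = (4 - 1) * 2 + 2 = 8 (and likewise the width),
  // so GetOutputSize(input, 16) returns {1, 16, 8, 8}.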

  bool RunOnDevice() override {
    switch (order_) {
      case StorageOrder::NHWC:
        return RunOnDeviceWithOrderNHWC();
      case StorageOrder::NCHW:
        return RunOnDeviceWithOrderNCHW();
      default:
        LOG(FATAL) << "Unknown storage order: " << order_;
    }
    // To suppress old compiler warnings.
    return true;
  }

  virtual bool RunOnDeviceWithOrderNCHW() {
    CAFFE_THROW("Not implemented");
  }

  virtual bool RunOnDeviceWithOrderNHWC() {
    CAFFE_THROW("Not implemented");
  }

  virtual ~ConvTransposeUnpoolBase() {}

 private:
  LegacyPadding legacy_pad_;
  int pad_;

 protected:
  vector<int> kernel_;
  vector<int> stride_;
  vector<int> pads_;
  vector<int> adj_;
  StorageOrder order_;
  bool shared_buffer_;
  Workspace* ws_;

  // Accessors for 2D conv params.

  inline int pad_t() const {
    return pads_[0];
  }

  inline int pad_l() const {
    return pads_[1];
  }

  inline int pad_b() const {
    return pads_[2];
  }

  inline int pad_r() const {
    return pads_[3];
  }

  inline int kernel_h() const {
    return kernel_[0];
  }

  inline int kernel_w() const {
    return kernel_[1];
  }

  inline int stride_h() const {
    return stride_[0];
  }

  inline int stride_w() const {
    return stride_[1];
  }

  inline int adj_h() const {
    return adj_[0];
  }

  inline int adj_w() const {
    return adj_[1];
  }

  inline void ComputeSizeAndPad(
      const int in_size,
      const int stride,
      const int kernel,
      const int adj,
      int* pad_head,
      int* pad_tail,
      int* out_size) {
    switch (legacy_pad_) {
      case LegacyPadding::NOTSET:
        CAFFE_ENFORCE(*pad_head >= 0);
        CAFFE_ENFORCE(*pad_tail >= 0);
        *out_size =
            (in_size - 1) * stride + kernel + adj - *pad_head - *pad_tail;
        break;
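      // Example (assumed numbers): in_size = 4, stride = 2, kernel = 3,
      // adj = 0, pads 1/1 give *out_size = (4 - 1) * 2 + 3 + 0 - 1 - 1 = 7.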
      // We handle the LegacyPadding::VALID and LegacyPadding::SAME cases
      // the same way.
      case LegacyPadding::VALID:
      case LegacyPadding::SAME:
        *pad_head = 0;
        *pad_tail = 0;
        *out_size = (in_size - 1) * stride + kernel + adj;
        break;
      case LegacyPadding::CAFFE_LEGACY_POOLING:
        LOG(FATAL) << "CAFFE_LEGACY_POOLING is no longer supported.";
        break;
    }
  }
};

#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \
  USE_OPERATOR_FUNCTIONS(Context);                        \
  using ConvTransposeUnpoolBase<Context>::kernel_;        \
  using ConvTransposeUnpoolBase<Context>::stride_;        \
  using ConvTransposeUnpoolBase<Context>::pads_;          \
  using ConvTransposeUnpoolBase<Context>::adj_;           \
  using ConvTransposeUnpoolBase<Context>::order_;         \
  using ConvTransposeUnpoolBase<Context>::shared_buffer_; \
  using ConvTransposeUnpoolBase<Context>::ws_

} // namespace caffe2

#endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_
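
For orientation, here is a minimal sketch of how a derived operator could plug into this base class. The operator name MyUnpoolOp and its body are hypothetical; only the USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS macro, GetOutputSize, and the RunOnDeviceWithOrderNCHW override point come from the header above.

#include "caffe2/operators/conv_transpose_unpool_op_base.h"

namespace caffe2 {

// Hypothetical example operator -- a sketch, not part of the library.
template <class Context>
class MyUnpoolOp final : public ConvTransposeUnpoolBase<Context> {
 public:
  USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context);
  explicit MyUnpoolOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvTransposeUnpoolBase<Context>(operator_def, ws) {}

  bool RunOnDeviceWithOrderNCHW() override {
    const auto& X = Input(0);
    // Reuse the input channel count for illustration; a real operator
    // would typically derive the output channels from its filter blob.
    const std::vector<int64_t> sizes =
        ConvTransposeUnpoolBase<Context>::GetOutputSize(X, X.dim32(1));
    auto* Y = Output(0, sizes, at::dtype<float>());
    // Fill the output with zeros so the sketch is self-contained.
    math::Set<float, Context>(
        Y->numel(), 0.f, Y->template mutable_data<float>(), &context_);
    return true;
  }
};

} // namespace caffe2

Registration (e.g. REGISTER_CPU_OPERATOR) and the NHWC override are omitted; the base class's RunOnDevice dispatches to whichever order-specific method matches the operator's order argument.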