1 #ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_ 2 #define CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/operators/conv_op_shared.h" 8 #include "caffe2/operators/conv_pool_op_base.h" 9 #include "caffe2/proto/caffe2_legacy.pb.h" 10 #include "caffe2/utils/math.h" 12 C10_DECLARE_bool(caffe2_force_shared_col_buffer);
16 template <
class Context>
19 USE_OPERATOR_CONTEXT_FUNCTIONS;
23 static_cast<LegacyPadding>(this->
template GetSingleArgument<int>(
25 LegacyPadding::NOTSET))),
26 kernel_(this->
template GetRepeatedArgument<int>(
"kernels")),
27 stride_(this->
template GetRepeatedArgument<int>(
"strides")),
28 pads_(this->
template GetRepeatedArgument<int>(
"pads")),
29 adj_(this->
template GetRepeatedArgument<int>(
"adjs")),
30 order_(StringToStorageOrder(
31 this->
template GetSingleArgument<string>(
"order",
"NCHW"))),
33 this->
template GetSingleArgument<int>(
"shared_buffer", 0)),
37 if (legacy_pad_ == LegacyPadding::VALID ||
38 legacy_pad_ == LegacyPadding::SAME) {
41 "If you use legacy padding VALID or SAME, you should not specify " 42 "any specific padding values.");
46 kernel_.resize(2, this->
template GetSingleArgument<int>(
"kernel", 0));
50 kernel_.push_back(this->
template GetSingleArgument<int>(
"kernel_h", 0));
51 kernel_.push_back(this->
template GetSingleArgument<int>(
"kernel_w", 0));
55 stride_.resize(2, this->
template GetSingleArgument<int>(
"stride", 0));
59 stride_.push_back(this->
template GetSingleArgument<int>(
"stride_h", 0));
60 stride_.push_back(this->
template GetSingleArgument<int>(
"stride_w", 0));
64 adj_.resize(2, this->
template GetSingleArgument<int>(
"adj", 0));
68 adj_.push_back(this->
template GetSingleArgument<int>(
"adj_h", 0));
69 adj_.push_back(this->
template GetSingleArgument<int>(
"adj_w", 0));
74 legacy_pad_ != LegacyPadding::VALID &&
75 legacy_pad_ != LegacyPadding::SAME,
76 "If you use legacy padding VALID or SAME, you should not specify " 77 "any specific padding values.");
78 pads_.resize(4, this->
template GetSingleArgument<int>(
"pad", 0));
85 legacy_pad_ != LegacyPadding::VALID &&
86 legacy_pad_ != LegacyPadding::SAME,
87 "If you use legacy padding VALID or SAME, you should not specify " 88 "any specific padding values.");
89 pads_.push_back(this->
template GetSingleArgument<int>(
"pad_t", 0));
90 pads_.push_back(this->
template GetSingleArgument<int>(
"pad_l", 0));
91 pads_.push_back(this->
template GetSingleArgument<int>(
"pad_b", 0));
92 pads_.push_back(this->
template GetSingleArgument<int>(
"pad_r", 0));
96 if (kernel_.size() == 0) {
97 kernel_.assign({0, 0});
100 if (stride_.size() == 0) {
101 stride_.resize(kernel_.size(), 1);
104 if (pads_.size() == 0) {
105 pads_.resize(kernel_.size() * 2, 0);
108 if (adj_.size() == 0) {
109 adj_.resize(kernel_.size(), 0);
112 CAFFE_ENFORCE_EQ(stride_.size(), kernel_.size());
113 CAFFE_ENFORCE_EQ(adj_.size(), kernel_.size());
115 if (legacy_pad_ != LegacyPadding::VALID &&
116 legacy_pad_ != LegacyPadding::SAME) {
117 CAFFE_ENFORCE_EQ(pads_.size(), 2 * kernel_.size());
120 for (
int dim = 0; dim < kernel_.size(); ++dim) {
121 CAFFE_ENFORCE_GT(kernel_[dim], 0);
122 CAFFE_ENFORCE_GT(stride_[dim], 0);
123 CAFFE_ENFORCE_GE(adj_[dim], 0);
124 CAFFE_ENFORCE_LE(adj_[dim], stride_[dim]);
129 if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
130 createSharedBuffer<Context>(ws_);
134 std::vector<int64_t> GetOutputSize(
const Tensor& input,
int output_channel) {
135 CAFFE_ENFORCE(4 == input.dim());
136 CAFFE_ENFORCE(input.numel() > 0);
137 int N = input.dim32(0);
138 bool channel_first =
false;
142 case StorageOrder::NHWC:
143 channel_first =
false;
148 case StorageOrder::NCHW:
149 channel_first =
true;
155 LOG(FATAL) <<
"Unknown Storage order: " << order_;
157 int output_height = 0, output_width = 0;
174 std::vector<int64_t> sizes;
176 sizes = {N, output_channel, output_height, output_width};
178 sizes = {N, output_height, output_width, output_channel};
180 VLOG(2) <<
"In: N " << N <<
" M " << M <<
" H " << H <<
" W " << W;
181 VLOG(2) <<
"Out: output_channel " << output_channel <<
" H " 182 << output_height <<
" W " << output_width;
186 bool RunOnDevice()
override {
188 case StorageOrder::NHWC:
189 return RunOnDeviceWithOrderNHWC();
190 case StorageOrder::NCHW:
191 return RunOnDeviceWithOrderNCHW();
193 LOG(FATAL) <<
"Unknown storage order: " << order_;
199 virtual bool RunOnDeviceWithOrderNCHW() {
200 CAFFE_THROW(
"Not implemented");
203 virtual bool RunOnDeviceWithOrderNHWC() {
204 CAFFE_THROW(
"Not implemented");
210 LegacyPadding legacy_pad_;
224 inline int pad_t()
const {
228 inline int pad_l()
const {
232 inline int pad_b()
const {
236 inline int pad_r()
const {
240 inline int kernel_h()
const {
244 inline int kernel_w()
const {
248 inline int stride_h()
const {
252 inline int stride_w()
const {
256 inline int adj_h()
const {
260 inline int adj_w()
const {
264 inline void ComputeSizeAndPad(
272 switch (legacy_pad_) {
273 case LegacyPadding::NOTSET:
274 CAFFE_ENFORCE(*pad_head >= 0);
275 CAFFE_ENFORCE(*pad_tail >= 0);
277 (in_size - 1) * stride + kernel + adj - *pad_head - *pad_tail;
281 case LegacyPadding::VALID:
282 case LegacyPadding::SAME:
285 *out_size = (in_size - 1) * stride + kernel + adj;
287 case LegacyPadding::CAFFE_LEGACY_POOLING:
288 LOG(FATAL) <<
"CAFFE_LEGACY_POOLING is no longer supported.";
294 #define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \ 295 USE_OPERATOR_FUNCTIONS(Context); \ 296 using ConvTransposeUnpoolBase<Context>::kernel_; \ 297 using ConvTransposeUnpoolBase<Context>::stride_; \ 298 using ConvTransposeUnpoolBase<Context>::pads_; \ 299 using ConvTransposeUnpoolBase<Context>::adj_; \ 300 using ConvTransposeUnpoolBase<Context>::order_; \ 301 using ConvTransposeUnpoolBase<Context>::shared_buffer_; \ 302 using ConvTransposeUnpoolBase<Context>::ws_ 306 #endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.