1 #ifndef CAFFE2_OPERATORS_IM2COL_OP_H_ 2 #define CAFFE2_OPERATORS_IM2COL_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename T,
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
15 template <
class... Args>
18 pad_(this->
template GetSingleArgument<int>(
"pad", 0)),
19 kernel_h_(this->
template GetSingleArgument<int>(
21 this->
template GetSingleArgument<int>(
"kernel", 0))),
22 kernel_w_(this->
template GetSingleArgument<int>(
24 this->
template GetSingleArgument<int>(
"kernel", 0))),
25 dilation_h_(this->
template GetSingleArgument<int>(
27 this->
template GetSingleArgument<int>(
"dilation", 1))),
28 dilation_w_(this->
template GetSingleArgument<int>(
30 this->
template GetSingleArgument<int>(
"dilation", 1))),
31 stride_h_(this->
template GetSingleArgument<int>(
33 this->
template GetSingleArgument<int>(
"stride", 1))),
34 stride_w_(this->
template GetSingleArgument<int>(
36 this->
template GetSingleArgument<int>(
"stride", 1))),
37 order_(StringToStorageOrder(
38 this->
template GetSingleArgument<string>(
"order",
"NCHW"))) {
39 CAFFE_ENFORCE(kernel_h_ > 0);
40 CAFFE_ENFORCE(kernel_w_ > 0);
41 CAFFE_ENFORCE(dilation_h_ > 0);
42 CAFFE_ENFORCE(dilation_w_ > 0);
43 CAFFE_ENFORCE(stride_h_ > 0);
44 CAFFE_ENFORCE(stride_w_ > 0);
45 CAFFE_ENFORCE(pad_ >= 0);
48 bool RunOnDevice()
override {
51 CAFFE_ENFORCE(4 == X.dim());
53 int N = 0,
C = 0, H = 0, W = 0;
55 case StorageOrder::NCHW:
61 case StorageOrder::NHWC:
68 CAFFE_THROW(
"Unknown storage order: ", order_);
71 const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
72 const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
73 CAFFE_ENFORCE(H >= dkernel_h);
74 CAFFE_ENFORCE(W >= dkernel_w);
75 const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
76 const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
79 case StorageOrder::NCHW: {
82 std::vector<int64_t>{N,
C * kernel_h_ * kernel_w_, out_h, out_w},
85 const size_t dx = X.numel() / N;
86 const size_t dy = Y->numel() / N;
87 for (
int n = 0; n < N; ++n) {
88 const auto* xdata = X.template data<T>() + (n * dx);
89 auto* ydata = Y->template mutable_data<T>() + (n * dy);
90 math::Im2Col<T, Context, StorageOrder::NCHW>(
109 case StorageOrder::NHWC: {
112 std::vector<int64_t>{N, out_h, out_w, kernel_h_ * kernel_w_ *
C},
115 const size_t dx = X.numel() / N;
116 const size_t dy = Y->numel() / N;
117 for (
int n = 0; n < N; ++n) {
118 const auto* xdata = X.template data<T>() + (n * dx);
119 auto* ydata = Y->template mutable_data<T>() + (n * dy);
120 math::Im2Col<T, Context, StorageOrder::NHWC>(
140 CAFFE_THROW(
"Unknown storage order: ", order_);
157 template <
typename T,
class Context>
160 USE_OPERATOR_CONTEXT_FUNCTIONS;
161 template <
class... Args>
164 pad_(this->
template GetSingleArgument<int>(
"pad", 0)),
165 kernel_h_(this->
template GetSingleArgument<int>(
167 this->
template GetSingleArgument<int>(
"kernel", 0))),
168 kernel_w_(this->
template GetSingleArgument<int>(
170 this->
template GetSingleArgument<int>(
"kernel", 0))),
171 dilation_h_(this->
template GetSingleArgument<int>(
173 this->
template GetSingleArgument<int>(
"dilation", 1))),
174 dilation_w_(this->
template GetSingleArgument<int>(
176 this->
template GetSingleArgument<int>(
"dilation", 1))),
177 stride_h_(this->
template GetSingleArgument<int>(
179 this->
template GetSingleArgument<int>(
"stride", 1))),
180 stride_w_(this->
template GetSingleArgument<int>(
182 this->
template GetSingleArgument<int>(
"stride", 1))),
183 order_(StringToStorageOrder(
184 this->
template GetSingleArgument<string>(
"order",
"NCHW"))) {
185 CAFFE_ENFORCE(kernel_h_ > 0);
186 CAFFE_ENFORCE(kernel_w_ > 0);
187 CAFFE_ENFORCE(dilation_h_ > 0);
188 CAFFE_ENFORCE(dilation_w_ > 0);
189 CAFFE_ENFORCE(stride_h_ > 0);
190 CAFFE_ENFORCE(stride_w_ > 0);
191 CAFFE_ENFORCE(pad_ >= 0);
194 bool RunOnDevice()
override {
198 auto* Y = Output(0, Z.sizes(), at::dtype<T>());
199 CAFFE_ENFORCE(4 == Y->dim());
201 int N = 0,
C = 0, H = 0, W = 0;
203 case StorageOrder::NCHW:
209 case StorageOrder::NHWC:
216 CAFFE_THROW(
"Unknown storage order: ", order_);
219 const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
220 const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
221 CAFFE_ENFORCE(H >= dkernel_h);
222 CAFFE_ENFORCE(W >= dkernel_w);
223 const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
224 const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
225 CAFFE_ENFORCE(X.numel() == N * kernel_h_ * kernel_w_ *
C * out_h * out_w);
227 const size_t dx = X.numel() / N;
228 const size_t dy = Y->numel() / N;
232 case StorageOrder::NCHW: {
233 for (
int n = 0; n < N; ++n) {
234 const auto* xdata = X.template data<T>() + (n * dx);
235 auto* ydata = Y->template mutable_data<T>() + (n * dy);
236 math::Col2Im<T, Context, StorageOrder::NCHW>(
255 case StorageOrder::NHWC: {
256 for (
int n = 0; n < N; ++n) {
257 const auto* xdata = X.template data<T>() + (n * dx);
258 auto* ydata = Y->template mutable_data<T>() + (n * dy);
259 math::Col2Im<T, Context, StorageOrder::NHWC>(
279 CAFFE_THROW(
"Unknown storage order: ", order_);
298 #endif // CAFFE2_OPERATORS_IM2COL_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...