1 #ifndef CAFFE2_OPERATORS_SPACE_BATCH_OP_H_ 2 #define CAFFE2_OPERATORS_SPACE_BATCH_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
// spaceToBatch: fills a pre-sized rank-4 `output` from a rank-4 `input`
// (both ranks enforced below), reading and writing float data. Each
// output batch entry out_b takes its values from input sample
// in_b = out_b % input_batch, sampled at one (offset_h, offset_w) phase of a
// block_size x block_size spatial cell, shifted by pad_t / pad_l. Output
// positions whose source coordinate falls outside the input are written as
// 0.0, i.e. they act as zero padding.
// Indexing is ((b * C + d) * H + h) * W + w, i.e. NCHW-contiguous.
// NOTE(review): this extract omits the parameter list (orig lines 12-18),
// several expression continuations (43-44, 47, 49, 52) and the closing
// braces; restore them from the canonical header before touching code.
typename Context>
19 CAFFE_ENFORCE(input.dim() == 4);
20 CAFFE_ENFORCE(output->dim() == 4);
22 const int output_batch = output->dim32(0);
23 const int output_depth = output->dim32(1);
24 const int output_height = output->dim32(2);
25 const int output_width = output->dim32(3);
27 const int input_batch = input.dim32(0);
28 const int input_depth = input.dim32(1);
29 const int input_height = input.dim32(2);
30 const int input_width = input.dim32(3);
// NOTE(review): unlike batchToSpace, there is no
// CAFFE_ENFORCE(input_depth == output_depth) here, although the loop runs
// d over input_depth while addressing the output with output_depth.
32 for (
int out_b = 0; out_b < output_batch; ++out_b) {
// Decompose the output batch index: the remainder picks the source sample,
// the quotient enumerates the spatial phase inside a block (w fastest).
33 const int in_b = out_b % input_batch;
34 const int offset_w = (out_b / input_batch) % block_size;
35 const int offset_h = (out_b / input_batch) / block_size;
36 for (
int d = 0; d < input_depth; ++d) {
37 for (
int out_h = 0; out_h < output_height; ++out_h) {
// Source row for this output row; can be negative (top padding) or
// >= input_height (bottom padding).
38 const int in_h = out_h * block_size + offset_h - pad_t;
39 for (
int out_w = 0; out_w < output_width; ++out_w) {
// Source column; can fall into left (negative) or right padding.
40 const int in_w = out_w * block_size + offset_w - pad_l;
// Flat element offsets into output and input (continuation lines of both
// expressions are missing from this extract).
41 const auto output_offset =
42 ((out_b * output_depth + d) * output_height + out_h) *
45 const auto input_offset =
46 ((in_b * input_depth + d) * input_height + in_h) * input_width +
// In-bounds source -> copy the element; otherwise write 0.0. The rest of
// this condition and the `else` line are not visible here; presumably the
// 0.0 store below is the else branch.
48 if (in_h >= 0 && in_w >= 0 && in_h < input_height &&
50 output->template mutable_data<float>()[output_offset] =
51 input.template data<float>()[input_offset];
53 output->template mutable_data<float>()[output_offset] = 0.0;
61 template <
// batchToSpace: scatters a rank-4 `input` (float data) back into a
// pre-sized rank-4 `output`, acting as the inverse of spaceToBatch.
// Input batch entry in_b lands in output sample out_b = in_b % output_batch
// at spatial phase (offset_h, offset_w), shifted up/left by pad_t / pad_l.
// Scattered positions outside [0, output_height) x [0, output_width) are
// dropped, i.e. the padding is cropped away.
// NOTE(review): output elements are only written when the bounds check
// passes; there is no zero-fill pass. Full coverage of `output` therefore
// depends on the caller sizing output_batch as input_batch / block_size^2.
// NOTE(review): the parameter list (orig lines 62-68), expression
// continuations (96-97, 100) and closing braces are missing from this
// extract.
typename Context>
69 CAFFE_ENFORCE(input.dim() == 4);
70 CAFFE_ENFORCE(output->dim() == 4);
72 const int output_batch = output->dim32(0);
73 const int output_depth = output->dim32(1);
74 const int output_height = output->dim32(2);
75 const int output_width = output->dim32(3);
77 const int input_batch = input.dim32(0);
78 const int input_depth = input.dim32(1);
79 const int input_height = input.dim32(2);
80 const int input_width = input.dim32(3);
82 CAFFE_ENFORCE(input_depth == output_depth);
83 for (
int in_b = 0; in_b < input_batch; ++in_b) {
// output_batch is used as a modulus/divisor here, so it must be nonzero
// whenever input_batch > 0; this function does not check that itself.
84 const int out_b = in_b % output_batch;
85 const int offset_w = (in_b / output_batch) % block_size;
86 const int offset_h = (in_b / output_batch) / block_size;
87 for (
int d = 0; d < input_depth; ++d) {
88 for (
int in_h = 0; in_h < input_height; ++in_h) {
// Destination row; negative or >= output_height means it falls in the
// cropped padding.
89 const int out_h = in_h * block_size + offset_h - pad_t;
90 for (
int in_w = 0; in_w < input_width; ++in_w) {
91 const int out_w = in_w * block_size + offset_w - pad_l;
// Only copy elements whose destination is inside the output.
92 if (out_h >= 0 && out_w >= 0 && out_h < output_height &&
93 out_w < output_width) {
94 const auto output_offset =
95 ((out_b * output_depth + d) * output_height + out_h) *
98 const auto input_offset =
99 ((in_b * input_depth + d) * input_height + in_h) * input_width +
101 output->template mutable_data<float>()[output_offset] =
102 input.template data<float>()[input_offset];
110 template <
typename Context>
// Shared argument parsing for the space<->batch operators below.
// NOTE(review): the class header (orig lines 111-112), the constructor
// signature (115-116) and the member declarations are missing from this
// extract.
113 USE_OPERATOR_CONTEXT_FUNCTIONS;
114 template <
class... Args>
// Arguments:
//   pad                     - padding applied to every side (default 0)
//   pad_t/pad_l/pad_b/pad_r - per-side overrides, each defaulting to `pad`
//   block_size              - spatial block edge (default 2), must be > 0
//   order                   - storage order; only "NCHW" is accepted
117 pad_(this->
template GetSingleArgument<int>(
"pad", 0)),
118 pad_t_(this->
template GetSingleArgument<int>(
"pad_t", pad_)),
// Each side reads its own key, mirroring pad_t_. Reading "pad" here would
// silently ignore explicit "pad_l"/"pad_b"/"pad_r" arguments.
119 pad_l_(this->
template GetSingleArgument<int>(
"pad_l", pad_)),
120 pad_b_(this->
template GetSingleArgument<int>(
"pad_b", pad_)),
121 pad_r_(this->
template GetSingleArgument<int>(
"pad_r", pad_)),
122 block_size_(this->
template GetSingleArgument<int>(
"block_size", 2)),
123 order_(StringToStorageOrder(
124 this->
template GetSingleArgument<string>(
"order",
"NCHW"))) {
125 CAFFE_ENFORCE(order_ == StorageOrder::NCHW);
// Both operators divide and take the modulus by block_size_. Reject
// non-positive values here rather than dividing by zero later.
CAFFE_ENFORCE(
    block_size_ > 0, "block_size must be positive, got ", block_size_);
138 template <
// Operator wrapping spaceToBatch (class header not visible in this extract).
typename Context>
141 USE_OPERATOR_CONTEXT_FUNCTIONS;
// RunOnDevice steps:
// 1. Pad the input height by pad_t_ + pad_b_ and the width by
//    pad_l_ + pad_r_.
// 2. Require both padded extents to be divisible by block_size_.
// 3. Resize Output(0) to
//    [N * block_size^2, C, H_pad / block_size, W_pad / block_size].
// 4. Delegate the data movement to spaceToBatch<Context> (its arguments,
//    orig lines 165+, are not visible here).
144 bool RunOnDevice()
override {
145 const auto& input =
Input(0);
146 auto* output = Output(0);
147 const int batch = input.dim32(0);
148 const int depth = input.dim32(1);
// Padded spatial extents.
149 const int height = this->pad_b_ + this->pad_t_ + input.dim32(2);
150 const int width = this->pad_l_ + this->pad_r_ + input.dim32(3);
// NOTE(review): the `CAFFE_ENFORCE(` opener (orig line 151) and its message
// arguments (153-156) for the height check are missing from this extract.
152 height % this->block_size_ == 0,
157 CAFFE_ENFORCE(width % this->block_size_ == 0);
// Every block_size x block_size phase becomes its own batch entry.
159 const int output_batch = batch * this->block_size_ * this->block_size_;
160 const int output_height = height / this->block_size_;
161 const int output_width = width / this->block_size_;
162 Output(0)->Resize(output_batch, depth, output_height, output_width);
164 spaceToBatch<Context>(
176 template <
// Operator wrapping batchToSpace (class header not visible in this extract).
typename Context>
179 USE_OPERATOR_CONTEXT_FUNCTIONS;
// RunOnDevice steps:
// 1. Require the input batch to be a multiple of block_size^2.
// 2. Resize Output(0) to
//    [N / block_size^2, C, H * block_size - pad_t_ - pad_b_,
//     W * block_size - pad_l_ - pad_r_].
// 3. Delegate the data movement to batchToSpace<Context> (its arguments,
//    orig lines 197+, are not visible here).
182 bool RunOnDevice()
override {
183 const auto& input =
Input(0);
184 auto* output = Output(0);
185 const int batch = input.dim32(0);
186 const int depth = input.dim32(1);
187 const int height = input.dim32(2);
188 const int width = input.dim32(3);
// batchToSpace uses output_batch (= batch / block_size^2) as a modulus and
// divisor. If batch < block_size^2 it would be 0 (modulo by zero). If batch
// is not a multiple of block_size^2, samples would be silently mis-assigned.
// This mirrors the divisibility checks SpaceToBatch performs on H and W.
const int block_area = this->block_size_ * this->block_size_;
CAFFE_ENFORCE(
    block_area > 0 && batch % block_area == 0,
    "BatchToSpace: input batch (",
    batch,
    ") must be a multiple of block_size^2 (",
    block_area,
    ")");
190 const int output_batch = batch / this->block_size_ / this->block_size_;
// Scale the spatial extents back up, then crop the padding.
191 const int output_height =
192 height * this->block_size_ - this->pad_b_ - this->pad_t_;
193 const int output_width =
194 width * this->block_size_ - this->pad_l_ - this->pad_r_;
195 Output(0)->Resize(output_batch, depth, output_height, output_width);
196 batchToSpace<Context>(
209 #endif // CAFFE2_OPERATORS_SPACE_BATCH_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...