#ifndef CAFFE2_OPERATORS_SEQUENCE_OPS_H_
#define CAFFE2_OPERATORS_SEQUENCE_OPS_H_

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
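
// GatherPaddingOp accumulates, per feature block, the values stored in the
// padding regions of a padded sequence tensor. The first output receives the
// sum over the start paddings (and over the end paddings as well when only
// one output is requested); an optional second output receives the sum over
// the end paddings. Ranges are described by an optional int32 lengths input;
// without it the whole outer dimension is treated as a single range.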
template <class Context>
class GatherPaddingOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit GatherPaddingOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        startPaddingWidth_(
            this->template GetSingleArgument<int>("padding_width", 1)),
        endPaddingWidth_(
            this->template GetSingleArgument<int>("end_padding_width", -1)) {
    CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
    // A negative end_padding_width means "same as padding_width".
    if (endPaddingWidth_ < 0) {
      endPaddingWidth_ = startPaddingWidth_;
    }
  }
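
  // "padding_width" gives the number of start-padding rows per range;
  // "end_padding_width" gives the end-padding rows and defaults to the start
  // width when negative or omitted. Each padding row is one block of
  // size_from_dim(1) elements.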
  bool RunOnDevice() override {
    if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
      // Nothing to gather: produce empty, zeroed padding outputs.
      Output(0)->Resize(std::vector<int64_t>(0));
      auto output_0_data = Output(0)->template mutable_data<int64_t>();
      math::Set<int64_t, Context>(
          Output(0)->numel(), 0, output_0_data, &context_);
      if (OutputSize() == 2) {
        Output(1)->Resize(std::vector<int64_t>(0));
        auto output_1_data = Output(1)->template mutable_data<int64_t>();
        math::Set<int64_t, Context>(
            Output(1)->numel(), 0, output_1_data, &context_);
      }
      return true;
    }
    return DispatchHelper<TensorTypes<float, double, int, int64_t, bool>>::
        call(this, Input(0));
  }
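
  // Typed implementation: zero-initializes the output padding block(s), then
  // accumulates the padding regions of every range into them via
  // GatherPadding().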
  template <typename T>
  bool DoRunWithType() {
    const auto& in = Input(0);
    CAFFE_ENFORCE_GE(in.dim(), 1);
    const int32_t outer_size = in.sizes()[0];
    const auto block_size = in.size_from_dim(1);
    const auto pad_width = startPaddingWidth_ + endPaddingWidth_;

    // If no lengths input is given, treat the whole input as a single range.
    const int32_t* lengths_ptr = &outer_size;
    int64_t lengths_size = 1;
    if (InputSize() > 1) {
      const auto& lengths = Input(1);
      lengths_ptr = lengths.template data<int32_t>();
      lengths_size = lengths.numel();
    }

    // The outputs hold one padding block each, shaped like a single row of
    // the input.
    std::vector<int64_t> padShape(in.sizes().begin() + 1, in.sizes().end());
    Output(0)->Resize(padShape);
    T* padding_start_ptr = Output(0)->template mutable_data<T>();
    math::Set<T, Context>(block_size, 0.0, padding_start_ptr, &context_);

    // With a single output, end paddings are accumulated into the same block
    // as the start paddings.
    T* padding_end_ptr = padding_start_ptr;
    if (OutputSize() == 2) {
      Output(1)->Resize(padShape);
      padding_end_ptr = Output(1)->template mutable_data<T>();
      math::Set<T, Context>(block_size, 0.0, padding_end_ptr, &context_);
    }

    GatherPadding<T>(
        outer_size,
        lengths_size,
        block_size,
        pad_width,
        in.template data<T>(),
        lengths_ptr,
        padding_start_ptr,
        padding_end_ptr);
    return true;
  }

 private:
  template <typename T>
  void GatherPadding(
      const int outer_size,
      const int lengths_size,
      const int block_size,
      const int pad_width,
      const T* in_ptr,
      const int* lengths_ptr,
      T* padding_start_ptr,
      T* padding_end_ptr);

  int startPaddingWidth_;
  int endPaddingWidth_;

  // Scratch space for length prefix sums, used by the device implementations.
  Tensor lengths_prefix_sum_buffer_{Context::GetDeviceType()};
  Tensor lengths_prefix_sum_{Context::GetDeviceType()};
};
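
// RemovePaddingOp strips startPaddingWidth_ rows from the start and
// endPaddingWidth_ rows from the end of each range of the input; ranges are
// described by an optional int32 lengths input. A second output, when
// requested, is expected to carry the adjusted lengths.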
template <class Context>
class RemovePaddingOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit RemovePaddingOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        startPaddingWidth_(
            this->template GetSingleArgument<int>("padding_width", 1)),
        endPaddingWidth_(
            this->template GetSingleArgument<int>("end_padding_width", -1)) {
    CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
    if (endPaddingWidth_ < 0) {
      endPaddingWidth_ = startPaddingWidth_;
    }
  }
  bool RunOnDevice() override {
    if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
      // Nothing to remove: forward the data (and lengths) unchanged.
      Output(0)->CopyFrom(Input(0), /* async */ true);
      if (OutputSize() == 2) {
        Output(1)->CopyFrom(Input(1), /* async */ true);
      }
      return true;
    }
    return DispatchHelper<TensorTypes<float, double, int, int64_t, bool>>::
        call(this, Input(0));
  }
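
  // Typed implementation; only declared here, the definition is provided per
  // device context.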
  template <typename T>
  bool DoRunWithType();

 private:
  int startPaddingWidth_;
  int endPaddingWidth_;

  // Scratch space for length prefix sums, used by the device implementations.
  Tensor lengths_prefix_sum_buffer_{Context::GetDeviceType()};
  Tensor lengths_prefix_sum_{Context::GetDeviceType()};
};
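
// AddPaddingOp inserts startPaddingWidth_ rows of padding before and
// endPaddingWidth_ rows after each range of the input. Ranges come from an
// optional int32 lengths input; with no lengths the whole input is a single
// range. Padding values may be given as an optional third (start) and fourth
// (end) input of block_size elements each; otherwise zero padding is used.
//
// Shape sketch: for an input of shape (6, 3), lengths = [2, 4] and
// padding_width = 1 (end_padding_width defaulting to the same), the first
// output has shape (6 + 2 * 2, 3) = (10, 3).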
template <class Context>
class AddPaddingOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit AddPaddingOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        startPaddingWidth_(
            this->template GetSingleArgument<int>("padding_width", 1)),
        endPaddingWidth_(
            this->template GetSingleArgument<int>("end_padding_width", -1)) {
    CAFFE_ENFORCE_GE(startPaddingWidth_, 0);
    if (endPaddingWidth_ < 0) {
      endPaddingWidth_ = startPaddingWidth_;
    }
  }
  bool RunOnDevice() override {
    if (startPaddingWidth_ == 0 && endPaddingWidth_ == 0) {
      // Zero-width padding: forward the data (and lengths) unchanged.
      Output(0)->CopyFrom(Input(0), /* async */ true);
      if (OutputSize() == 2) {
        Output(1)->CopyFrom(Input(1), /* async */ true);
      }
      return true;
    }
    return DispatchHelper<TensorTypes<float, double, int, int64_t, bool>>::
        call(this, Input(0));
  }
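
  // Typed implementation: computes the padded output shape, then delegates
  // the element copying and padding to MakePadding().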
  template <typename T>
  bool DoRunWithType() {
    const auto& in = Input(0);
    CAFFE_ENFORCE_GE(in.dim(), 1);
    const int32_t outer_size = in.sizes()[0];
    const auto block_size = in.size_from_dim(1);

    // If no lengths input is given, treat the whole input as a single range.
    const int32_t* lengths_ptr = nullptr;
    int32_t lengths_size = 1;
    if (InputSize() > 1) {
      const auto& lengths = Input(1);
      lengths_ptr = lengths.template data<int32_t>();
      lengths_size = lengths.numel();
    }

    // Optional padding values:
    //   2 inputs - pad with zeros,
    //   3 inputs - the same block is used for start and end padding,
    //   4 inputs - separate start and end padding blocks.
    const T* padding_start_ptr = nullptr;
    const T* padding_end_ptr = nullptr;
    if (InputSize() >= 3) {
      auto& padding_start = Input(2);
      CAFFE_ENFORCE_EQ(block_size, padding_start.numel());
      padding_start_ptr = padding_start.template data<T>();
    }
    if (InputSize() == 4) {
      auto& padding_end = Input(3);
      CAFFE_ENFORCE_EQ(block_size, padding_end.numel());
      padding_end_ptr = padding_end.template data<T>();
    } else {
      padding_end_ptr = padding_start_ptr;
    }

    // Each range grows by (startPaddingWidth_ + endPaddingWidth_) rows.
    auto out_dims = in.sizes().vec();
    out_dims[0] += (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
    auto* out = Output(0, std::move(out_dims), at::dtype<T>());

    const auto* in_ptr = in.template data<T>();
    auto* out_ptr = out->template mutable_data<T>();

    return MakePadding<T>(
        in_ptr,
        out_ptr,
        lengths_ptr,
        lengths_size,
        outer_size,
        padding_start_ptr,
        padding_end_ptr,
        block_size);
  }

 private:
  template <typename T>
  bool MakePadding(
      const T* in_ptr,
      T* out_ptr,
      const int32_t* lengths_ptr,
      int32_t lengths_size,
      int32_t outer_size,
      const T* padding_start_ptr,
      const T* padding_end_ptr,
      int64_t block_size);

  int startPaddingWidth_;
  int endPaddingWidth_;

  // Scratch space for length prefix sums, used by the device implementations.
  Tensor lengths_prefix_sum_buffer_{Context::GetDeviceType()};
  Tensor lengths_prefix_sum_{Context::GetDeviceType()};
};
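
// PadEmptySamplesOp pads samples whose length is zero so that every sample
// contributes at least one entry: the first input is expected to hold
// per-sample lengths and the remaining inputs the corresponding feature
// tensors. The implementation of RunOnDevice is device-specific.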
template <class Context>
class PadEmptySamplesOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit PadEmptySamplesOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_SEQUENCE_OPS_H_