1 #ifndef CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_ 2 #define CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 12 template <
class Context,
bool Forward>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 template <
class... Args>
20 bool RunOnDevice()
override {
25 template <
typename ValT>
26 bool DoRunWithType() {
29 int dim_offset = Forward ? 1 : 2;
30 auto& values =
Input(0);
31 CAFFE_ENFORCE_GT(values.dim(), dim_offset);
34 int64_t block_size = values.size_from_dim(dim_offset);
35 auto values_vec = values.template data<ValT>();
37 auto& lengths =
Input(LENGTHS);
38 CAFFE_ENFORCE_EQ(lengths.dim(), 1);
39 const auto cols = lengths.numel();
40 const int32_t* lengths_vec = lengths.template data<int32_t>();
44 cols ? *std::max_element(lengths_vec, lengths_vec + cols) : 0;
45 CAFFE_ENFORCE_GE(rows, 0);
48 math::Sum<int, Context>(cols, lengths_vec, &length_sum, &context_);
51 vector<int64_t> shape;
55 shape.push_back(rows);
56 shape.push_back(cols);
58 shape.push_back(length_sum);
62 shape.end(), values.sizes().begin() + dim_offset, values.sizes().end());
64 auto* output = Output(OUTPUTVALUE, shape, at::dtype<ValT>());
66 auto output_data = output->template mutable_data<ValT>();
69 math::Set<ValT, Context>(output->numel(), 0, output_data, &context_);
72 for (
int c = 0; c < cols; c++) {
73 for (
int r = 0; r < lengths_vec[c]; r++) {
74 auto input_offset = Forward ? (offset + r) : (r * cols + c);
75 auto output_offset = Forward ? (r * cols + c) : (offset + r);
76 context_.CopyItemsSameDevice(
79 values_vec + input_offset * block_size,
80 output_data + output_offset * block_size);
82 offset += lengths_vec[c];
88 INPUT_TAGS(INPUTVALUE, LENGTHS);
89 OUTPUT_TAGS(OUTPUTVALUE);
93 #endif // CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...