1 #ifndef CAFFE2_OPERATORS_LENGTHS_PAD_OP_H_ 2 #define CAFFE2_OPERATORS_LENGTHS_PAD_OP_H_ 4 #include "caffe2/core/operator.h" 5 #include "caffe2/utils/math.h" 9 template <
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
// LengthsPadOp constructor (fragment: the signature and closing brace are
// missing from this excerpt).
// Reads two operator arguments via OP_SINGLE_ARG (default -1 for both):
//   padding_value  (double) -> padding_value_: value used to fill the rows
//                              of each output segment that are not copied
//                              from DATA.
//   target_length  (int)    -> target_length_: number of output rows that
//                              every segment is padded to.
// Because target_length_ must be >= 1, the default of -1 is rejected, so
// callers effectively must supply target_length.
13 template <
class... Args>
16 OP_SINGLE_ARG(
double,
"padding_value", padding_value_, -1),
17 OP_SINGLE_ARG(
int,
"target_length", target_length_, -1) {
18 CAFFE_ENFORCE_GE(target_length_, 1,
"target_length argument must be >= 1");
// Operator entry point.
// NOTE(review): the body is not visible in this excerpt; presumably it
// dispatches on DATA's element type to DoRunWithType<T>() -- confirm.
21 bool RunOnDevice()
override {
// Pads every segment of DATA to exactly target_length_ rows.
//
// DATA is viewed as rows along dim 0, each row holding
// data.size_from_dim(1) elements. LENGTHS is 1-D and splits those rows into
// consecutive segments, so the lengths must sum to data.size(0).
// The output has DATA's shape with dim 0 replaced by
// lengths.numel() * target_length_. Segment i occupies output rows
// [i * target_length_, (i + 1) * target_length_): its first LENGTHS[i] rows
// are copied from DATA, and the rest keep the padding value.
// Output element type is T (the template parameter; its declaration line is
// missing from this excerpt).
// NOTE(review): several original lines (the cpuContext declaration, the name
// of the fill call, the per-segment target-length check and the return) are
// not visible here.
27 bool DoRunWithType() {
28 auto& data =
Input(DATA);
29 auto& lengths =
Input(LENGTHS);
31 CAFFE_ENFORCE_EQ(lengths.dim(), 1,
"LENGTHS must be 1-D");
32 CAFFE_ENFORCE_GE(data.dim(), 1,
"DATA should be at least 1-D");
// Copy LENGTHS into lengths_host_ so the per-segment lengths can be read
// below through a host pointer (presumably lengths_host_ is a CPU tensor;
// its declaration is not visible -- confirm).
37 lengths_host_.CopyFrom(lengths);
39 auto lengths_size = lengths_host_.numel();
// Lengths are read as int32_t.
40 auto* lengths_data = lengths_host_.template data<int32_t>();
// Sum the lengths with the CPU math routine; the total must cover all DATA rows.
// NOTE(review): total_length is int32_t, so a very large sum could overflow
// before the comparison below -- verify expected input sizes.
42 int32_t total_length = 0;
44 math::Sum<int32_t, CPUContext>(
45 lengths_size, lengths_data, &total_length, &cpuContext);
47 CAFFE_ENFORCE_EQ(total_length, data.size(0));
// Output: same trailing dims as DATA, one target_length_-row slot per segment.
49 auto shape = data.sizes().vec();
50 shape[0] = lengths_size * target_length_;
51 auto* output = Output(0, shape, at::dtype<T>());
// Elements per row (product of DATA's dims from 1 onward).
53 auto block_size = data.size_from_dim(1);
54 auto src_data = data.template data<T>();
55 auto out_data = output->template mutable_data<T>();
// Pre-fill the entire output with padding_value_ cast to T. The fill call's
// name is missing from this excerpt (presumably math::Set -- confirm).
58 output->numel(),
static_cast<T>(padding_value_), out_data, &context_);
// Copy each segment into the start of its slot. The output pointer advances
// by a full target_length_ slot and the input pointer by the segment's actual
// length, so the untouched tail of each slot keeps the padding value.
59 for (int64_t i = 0; i < lengths_size; ++i) {
60 auto length = lengths_data[i];
61 CAFFE_ENFORCE_GE(length, 0);
// Tail of the (otherwise missing) check rejecting lengths above target_length_.
67 " is larger than target length");
69 context_.template CopySameDevice<T>(
70 block_size * length, src_data, out_data);
72 out_data += block_size * target_length_;
73 src_data += block_size * length;
// Input indices: DATA is the tensor whose dim-0 rows are padded; LENGTHS is
// the 1-D tensor of per-segment row counts.
78 INPUT_TAGS(DATA, LENGTHS);
// Fill value for padded rows, read from the "padding_value" argument.
81 double padding_value_;
88 #endif // CAFFE2_OPERATORS_LENGTHS_PAD_OP_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...