Caffe2 - C++ API
A deep learning, cross-platform ML framework
lengths_pad_op.h
1 #ifndef CAFFE2_OPERATORS_LENGTHS_PAD_OP_H_
2 #define CAFFE2_OPERATORS_LENGTHS_PAD_OP_H_
3 
4 #include "caffe2/core/operator.h"
5 #include "caffe2/utils/math.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class LengthsPadOp : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13  template <class... Args>
14  explicit LengthsPadOp(Args&&... args)
15  : Operator<Context>(std::forward<Args>(args)...),
16  OP_SINGLE_ARG(double, "padding_value", padding_value_, -1),
17  OP_SINGLE_ARG(int, "target_length", target_length_, -1) {
18  CAFFE_ENFORCE_GE(target_length_, 1, "target_length argument must be >= 1");
19  }
20 
21  bool RunOnDevice() override {
23  this, Input(DATA));
24  }
25 
26  template <typename T>
27  bool DoRunWithType() {
28  auto& data = Input(DATA);
29  auto& lengths = Input(LENGTHS);
30 
31  CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS must be 1-D");
32  CAFFE_ENFORCE_GE(data.dim(), 1, "DATA should be at least 1-D");
33 
34  // Context::CopyFrom and math::Sum need the same context to avoid race
35  // conditions
36  // why? CPUContext is not used in Sum
37  lengths_host_.CopyFrom(lengths);
38 
39  auto lengths_size = lengths_host_.numel();
40  auto* lengths_data = lengths_host_.template data<int32_t>();
41 
42  int32_t total_length = 0;
43  CPUContext cpuContext;
44  math::Sum<int32_t, CPUContext>(
45  lengths_size, lengths_data, &total_length, &cpuContext);
46 
47  CAFFE_ENFORCE_EQ(total_length, data.size(0));
48 
49  auto shape = data.sizes().vec();
50  shape[0] = lengths_size * target_length_;
51  auto* output = Output(0, shape, at::dtype<T>());
52 
53  auto block_size = data.size_from_dim(1);
54  auto src_data = data.template data<T>();
55  auto out_data = output->template mutable_data<T>();
56 
57  math::Set(
58  output->numel(), static_cast<T>(padding_value_), out_data, &context_);
59  for (int64_t i = 0; i < lengths_size; ++i) {
60  auto length = lengths_data[i];
61  CAFFE_ENFORCE_GE(length, 0);
62  CAFFE_ENFORCE_GE(
63  target_length_,
64  length,
65  "Length at index = ",
66  i,
67  " is larger than target length");
68 
69  context_.template CopySameDevice<T>(
70  block_size * length, src_data, out_data);
71 
72  out_data += block_size * target_length_;
73  src_data += block_size * length;
74  }
75  return true;
76  }
77 
78  INPUT_TAGS(DATA, LENGTHS);
79 
80  private:
81  double padding_value_;
82  int target_length_;
83  Tensor lengths_host_{CPU};
84 };
85 
86 } // namespace caffe2
87 
88 #endif // CAFFE2_OPERATORS_LENGTHS_PAD_OP_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13