Caffe2 - C++ API
A deep learning, cross-platform ML framework
variable_length_sequence_padding.h
1 #pragma once
2 
#include <algorithm>
#include <cstdint>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 namespace detail {
10 
/// Overwrites the tail of each sequence in a time-major batch with padValue.
///
/// X is addressed as a contiguous [N][B][M] array: element (i, j, m) lives at
/// X[(i * B + j) * M + m]. For every sequence j, all time steps i in
/// [seqLengths[j], N) are set to padValue; steps before seqLengths[j] are left
/// untouched. The write happens in place.
///
/// @param N          number of time steps (extent of the leading axis).
/// @param B          number of sequences (extent of the middle axis).
/// @param M          number of values per (time step, sequence) pair.
/// @param X          data to pad in place; must hold at least N * B * M values.
/// @param seqLengths B lengths, one per sequence. A length <= 0 pads the whole
///                   sequence; a length >= N pads nothing.
/// @param padValue   value written into the padding region.
/// @param context    unused by this generic implementation.
template <typename T, typename Context>
void VariableLengthSequencePadding(
    int N,
    int B,
    int M,
    T* X,
    const int32_t* seqLengths,
    const T padValue,
    Context* /*context*/) {
  // Offsets are computed in 64-bit: N * B * M can exceed INT_MAX for large
  // tensors even when each individual extent fits in an int.
  const std::int64_t stepStride = static_cast<std::int64_t>(B) * M;
  const std::int64_t seqStride = static_cast<std::int64_t>(M);
  for (int j = 0; j < B; ++j) {
    // Clamp negative lengths to 0 so they pad the whole sequence instead of
    // indexing before the start of X.
    const int firstPadded = std::max<int>(seqLengths[j], 0);
    for (int i = firstPadded; i < N; ++i) {
      T* const slot = X + stepStride * i + seqStride * j;
      std::fill_n(slot, M, padValue);
    }
  }
}
26 
27 } // namespace detail
28 
29 template <typename T, typename Context>
30 class VariableLengthSequencePaddingOp : public Operator<Context> {
31  public:
32  template <class... Args>
33  explicit VariableLengthSequencePaddingOp(Args&&... args)
34  : Operator<Context>(std::forward<Args>(args)...) {}
35  USE_OPERATOR_CONTEXT_FUNCTIONS;
36 
37  bool RunOnDevice() override {
38  const auto N = Input(INPUT).size(0);
39  const auto B = Input(INPUT).size(1);
40  const auto M = Input(INPUT).size(2);
41 
42  auto X = Output(OUTPUT)->template mutable_data<T>();
43 
44  auto seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
45 
46  detail::VariableLengthSequencePadding<T, Context>(
47  N, B, M, X, seqLengths, 0, &context_);
48  return true;
49  }
50 
51  protected:
52  INPUT_TAGS(INPUT, SEQ_LENGTHS);
53  OUTPUT_TAGS(OUTPUT);
54 };
55 
56 } // namespace caffe2
Definition: any.cpp:108
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:58