Caffe2 - C++ API
A deep learning, cross-platform ML framework
pack_rnn_sequence_op.h
1 #ifndef CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
2 #define CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
3 
4 #include <algorithm>
5 #include <vector>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <class Context, bool Forward>
13 class PackRNNSequenceOpBase : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  template <class... Args>
17  explicit PackRNNSequenceOpBase(Args&&... args)
18  : Operator<Context>(std::forward<Args>(args)...) {}
19 
20  bool RunOnDevice() override {
22  this, Input(0));
23  }
24 
25  template <typename ValT>
26  bool DoRunWithType() {
27  // The value is copied from the sequence to the pack
28  // if Forward is true, and vice versa
29  int dim_offset = Forward ? 1 : 2;
30  auto& values = Input(0);
31  CAFFE_ENFORCE_GT(values.dim(), dim_offset);
32 
33  // block_size is the size for each individual feature
34  int64_t block_size = values.size_from_dim(dim_offset);
35  auto values_vec = values.template data<ValT>();
36 
37  auto& lengths = Input(LENGTHS);
38  CAFFE_ENFORCE_EQ(lengths.dim(), 1);
39  const auto cols = lengths.numel();
40  const int32_t* lengths_vec = lengths.template data<int32_t>();
41  // the total number of rows is defined as the max number from lengths
42  // if when the lengths is empty, we set rows = 0 to support zero lengths
43  const auto rows =
44  cols ? *std::max_element(lengths_vec, lengths_vec + cols) : 0;
45  CAFFE_ENFORCE_GE(rows, 0);
46  int length_sum = 0;
47  if (cols > 0) {
48  math::Sum<int, Context>(cols, lengths_vec, &length_sum, &context_);
49  }
50 
51  vector<int64_t> shape;
52  // the output shape is rows * cols for the pack,
53  // or length_sum for the sequence
54  if (Forward) {
55  shape.push_back(rows);
56  shape.push_back(cols);
57  } else {
58  shape.push_back(length_sum);
59  }
60  // insert the dim for the feature
61  shape.insert(
62  shape.end(), values.sizes().begin() + dim_offset, values.sizes().end());
63 
64  auto* output = Output(OUTPUTVALUE, shape, at::dtype<ValT>());
65 
66  auto output_data = output->template mutable_data<ValT>();
67  // initialize output_data with zero, as it is the default value for padding
68  // when certain length is smaller than rows
69  math::Set<ValT, Context>(output->numel(), 0, output_data, &context_);
70 
71  int32_t offset = 0;
72  for (int c = 0; c < cols; c++) {
73  for (int r = 0; r < lengths_vec[c]; r++) {
74  auto input_offset = Forward ? (offset + r) : (r * cols + c);
75  auto output_offset = Forward ? (r * cols + c) : (offset + r);
76  context_.CopyItemsSameDevice(
77  values.dtype(),
78  block_size,
79  values_vec + input_offset * block_size,
80  output_data + output_offset * block_size);
81  }
82  offset += lengths_vec[c];
83  }
84  return true;
85  }
86 
87  private:
88  INPUT_TAGS(INPUTVALUE, LENGTHS);
89  OUTPUT_TAGS(OUTPUTVALUE);
90 };
91 } // namespace caffe2
92 
93 #endif // CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13