Caffe2 - C++ API
A deep learning, cross-platform ML framework
pack_rnn_sequence_op.h
1 
17 #ifndef CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
18 #define CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
19 
20 #include <algorithm>
21 #include <vector>
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/utils/math.h"
25 
26 namespace caffe2 {
27 
28 template <class Context, bool Forward>
29 class PackRNNSequenceOpBase : public Operator<Context> {
30  public:
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  PackRNNSequenceOpBase(const OperatorDef& operator_def, Workspace* ws)
33  : Operator<Context>(operator_def, ws) {}
34 
35  bool RunOnDevice() override {
37  this, Input(0));
38  }
39 
40  template <typename ValT>
41  bool DoRunWithType() {
42  // The value is copied from the sequence to the pack
43  // if Forward is true, and vice versa
44  int dim_offset = Forward ? 1 : 2;
45  auto& values = Input(0);
46  CAFFE_ENFORCE_GT(values.ndim(), dim_offset);
47 
48  // block_size is the size for each individual feature
49  TIndex block_size = values.size_from_dim(dim_offset);
50  auto values_vec = values.template data<ValT>();
51 
52  auto& lengths = Input(LENGTHS);
53  CAFFE_ENFORCE_EQ(lengths.ndim(), 1);
54  const auto cols = lengths.size();
55  const int32_t* lengths_vec = lengths.template data<int32_t>();
56  // the total number of rows is defined as the max number from lengths
57  // if when the lengths is empty, we set rows = 0 to support zero lengths
58  const auto rows =
59  cols ? *std::max_element(lengths_vec, lengths_vec + cols) : 0;
60  CAFFE_ENFORCE_GE(rows, 0);
61  int length_sum = 0;
62  if (cols > 0) {
63  math::Sum<int, Context>(cols, lengths_vec, &length_sum, &context_);
64  }
65 
66  vector<TIndex> shape;
67  // the output shape is rows * cols for the pack,
68  // or length_sum for the sequence
69  if (Forward) {
70  shape.push_back(rows);
71  shape.push_back(cols);
72  } else {
73  shape.push_back(length_sum);
74  }
75  // insert the dim for the feature
76  shape.insert(
77  shape.end(), values.dims().begin() + dim_offset, values.dims().end());
78 
79  auto* output = Output(OUTPUTVALUE);
80  output->Resize(shape);
81 
82  auto output_data = output->template mutable_data<ValT>();
83  // initialize output_data with zero, as it is the default value for padding
84  // when certain length is smaller than rows
85  math::Set<ValT, Context>(output->size(), 0, output_data, &context_);
86 
87  int32_t offset = 0;
88  for (int c = 0; c < cols; c++) {
89  for (int r = 0; r < lengths_vec[c]; r++) {
90  auto input_offset = Forward ? (offset + r) : (r * cols + c);
91  auto output_offset = Forward ? (r * cols + c) : (offset + r);
92  context_.template CopyItems<Context, Context>(
93  values.meta(),
94  block_size,
95  values_vec + input_offset * block_size,
96  output_data + output_offset * block_size);
97  }
98  offset += lengths_vec[c];
99  }
100  return true;
101  }
102 
103  private:
104  INPUT_TAGS(INPUTVALUE, LENGTHS);
105  OUTPUT_TAGS(OUTPUTVALUE);
106 };
107 } // namespace caffe2
108 
109 #endif // CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.