Caffe2 - C++ API
A deep learning, cross-platform ML framework
reverse_packed_segs_op.h
1 #ifndef CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_
2 #define CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class ReversePackedSegsOp final : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13  USE_SIMPLE_CTOR_DTOR(ReversePackedSegsOp);
14  USE_DISPATCH_HELPER;
15 
16  bool RunOnDevice() override {
18  this, Input(DATA));
19  }
20 
21  template <typename T>
22  bool DoRunWithType() {
23  if (Input(LENGTHS).template IsType<int>()) {
24  DoRunWithLengthType<T, int>();
25  } else {
26  DoRunWithLengthType<T, long>();
27  }
28  return true;
29  }
30 
31  private:
32  INPUT_TAGS(DATA, LENGTHS);
33 
34  template <typename T, typename LengthType>
35  void DoRunWithLengthType() {
36  const auto& data = Input(DATA);
37  const auto& lengths = Input(LENGTHS);
38 
39  CAFFE_ENFORCE(
40  data.dim() == 3,
41  "DATA should be 3-D tensor <lengths, "
42  "segments, embeddings>");
43  CAFFE_ENFORCE(lengths.dim() == 1, "LENGTH should be 1-D");
44 
45  const auto shape = data.sizes();
46  auto* output = Output(0, shape, at::dtype<T>());
47 
48  const auto max_length = data.sizes()[0];
49  const auto batch_size = data.sizes()[1];
50  const auto block_size = data.sizes()[2];
51  CAFFE_ENFORCE(
52  lengths.sizes()[0] == batch_size,
53  "lenths size should be"
54  " equal to batch size");
55 
56  const T* data_ptr = data.template data<T>();
57  const LengthType* lengths_ptr = lengths.template data<LengthType>();
58 
59  vector<LengthType> lengths_host(batch_size);
60  context_.template CopyToCPU<LengthType>(
61  batch_size, lengths_ptr, &lengths_host[0]);
62  context_.FinishDeviceComputation();
63 
64  T* rev_data_ptr = output->template mutable_data<T>();
65  for (int64_t i = 0; i < batch_size; i++) {
66  const auto& seg_length = lengths_host[i];
67  CAFFE_ENFORCE_LE(seg_length, max_length);
68  int64_t j = 0;
69  for (; j < seg_length; j++) {
70  const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
71  T* rev_data_block_ptr =
72  rev_data_ptr + ((seg_length - 1 - j) * batch_size + i) * block_size;
73  context_.template CopySameDevice<T>(
74  block_size, data_block_ptr, rev_data_block_ptr);
75  }
76  for (; j < max_length; j++) {
77  const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
78  T* rev_data_block_ptr =
79  rev_data_ptr + (j * batch_size + i) * block_size;
80  context_.template CopySameDevice<T>(
81  block_size, data_block_ptr, rev_data_block_ptr);
82  }
83  }
84  }
85 };
86 
87 } // namespace caffe2
88 
89 #endif // CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13