Caffe2 - C++ API
A deep learning, cross platform ML framework
reverse_packed_segs_op.h
1 
17 #ifndef CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_
18 #define CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 
23 namespace caffe2 {
24 
25 template <class Context>
26 class ReversePackedSegsOp final : public Operator<Context> {
27  public:
28  USE_OPERATOR_CONTEXT_FUNCTIONS;
29  USE_SIMPLE_CTOR_DTOR(ReversePackedSegsOp);
30  USE_DISPATCH_HELPER;
31 
32  bool RunOnDevice() override {
34  this, Input(DATA));
35  }
36 
37  template <typename T>
38  bool DoRunWithType() {
39  if (Input(LENGTHS).template IsType<int>()) {
40  DoRunWithLengthType<T, int>();
41  } else {
42  DoRunWithLengthType<T, long>();
43  }
44  return true;
45  }
46 
47  private:
48  INPUT_TAGS(DATA, LENGTHS);
49 
50  template <typename T, typename LengthType>
51  void DoRunWithLengthType() {
52  const auto& data = Input(DATA);
53  const auto& lengths = Input(LENGTHS);
54 
55  CAFFE_ENFORCE(
56  data.ndim() == 3,
57  "DATA should be 3-D tensor <lengths, "
58  "segments, embeddings>");
59  CAFFE_ENFORCE(lengths.ndim() == 1, "LENGTH should be 1-D");
60 
61  auto* output = Output(0);
62  const auto& shape = data.dims();
63  output->Resize(shape);
64 
65  const auto& max_length = data.dims()[0];
66  const auto& batch_size = data.dims()[1];
67  const auto& block_size = data.dims()[2];
68  CAFFE_ENFORCE(
69  lengths.dims()[0] == batch_size,
70  "lenths size should be"
71  " equal to batch size");
72 
73  const T* data_ptr = data.template data<T>();
74  const LengthType* lengths_ptr = lengths.template data<LengthType>();
75 
76  vector<LengthType> lengths_host(batch_size);
77  context_.template Copy<LengthType, Context, CPUContext>(
78  batch_size, lengths_ptr, &lengths_host[0]);
79  context_.FinishDeviceComputation();
80 
81  T* rev_data_ptr = output->template mutable_data<T>();
82  for (TIndex i = 0; i < batch_size; i++) {
83  const auto& seg_length = lengths_host[i];
84  CAFFE_ENFORCE_LE(seg_length, max_length);
85  TIndex j = 0;
86  for (; j < seg_length; j++) {
87  const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
88  T* rev_data_block_ptr =
89  rev_data_ptr + ((seg_length - 1 - j) * batch_size + i) * block_size;
90  context_.template Copy<T, Context, Context>(
91  block_size, data_block_ptr, rev_data_block_ptr);
92  }
93  for (; j < max_length; j++) {
94  const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
95  T* rev_data_block_ptr =
96  rev_data_ptr + (j * batch_size + i) * block_size;
97  context_.template Copy<T, Context, Context>(
98  block_size, data_block_ptr, rev_data_block_ptr);
99  }
100  }
101  }
102 };
103 
104 } // namespace caffe2
105 
106 #endif // CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_
Copyright (c) 2016-present, Facebook, Inc.