1 #ifndef CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_ 2 #define CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 9 template <
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 bool RunOnDevice()
override {
// Type-dispatch step for the data element type T: picks the integer type used
// to read the LENGTHS input and forwards to DoRunWithLengthType<T, LengthType>.
// NOTE(review): the enclosing `template <typename T>` header, the `else`
// between the two calls, and the return statement are presumably present in
// the full file; confirm there.
22 bool DoRunWithType() {
23 if (
Input(LENGTHS).
template IsType<int>()) {
// LENGTHS holds `int` values: read them as int.
24 DoRunWithLengthType<T, int>();
// Otherwise LENGTHS is read as `long`. No further type check is visible, so
// any non-int LENGTHS tensor appears to take this path - TODO confirm.
26 DoRunWithLengthType<T, long>();
32 INPUT_TAGS(DATA, LENGTHS);
// Reverses the leading `lengths[i]` blocks of every batch column i of DATA and
// copies the remaining (padding) blocks unchanged, writing to output 0.
//
// Template parameters:
//   T          - element type of DATA and of the output tensor.
//   LengthType - element type used to read the LENGTHS tensor.
//
// DATA is indexed as [max_length, batch_size, block_size] (sizes()[0..2]);
// LENGTHS must be 1-D with batch_size entries. Violations are reported through
// CAFFE_ENFORCE / CAFFE_ENFORCE_LE.
34 template <
typename T,
typename LengthType>
35 void DoRunWithLengthType() {
36 const auto& data =
Input(DATA);
37 const auto& lengths =
Input(LENGTHS);
// Error text for the DATA rank check. The opening of this enforce call is not
// visible here; presumably it tests data.dim() == 3 - TODO confirm.
41 "DATA should be 3-D tensor <lengths, " 42 "segments, embeddings>");
43 CAFFE_ENFORCE(lengths.dim() == 1,
"LENGTH should be 1-D");
// The output has exactly the same shape as DATA, with element type T.
45 const auto shape = data.sizes();
46 auto* output = Output(0, shape, at::dtype<T>());
// Axis 0: (padded) sequence length; axis 1: number of segments in the batch;
// axis 2: number of T elements in one block copied as a unit.
48 const auto max_length = data.sizes()[0];
49 const auto batch_size = data.sizes()[1];
50 const auto block_size = data.sizes()[2];
// One length per batch column is required.
// NOTE(review): the message below spells "lenths"; it is a runtime string and
// is left untouched here.
52 lengths.sizes()[0] == batch_size,
53 "lenths size should be" 54 " equal to batch size");
56 const T* data_ptr = data.template data<T>();
57 const LengthType* lengths_ptr = lengths.template data<LengthType>();
// The loop below reads the lengths through a host-side copy rather than
// through lengths_ptr directly. Presumably this is because lengths_ptr may
// refer to device memory, and FinishDeviceComputation() waits for the copy to
// finish - confirm against the Context API.
// NOTE(review): &lengths_host[0] is taken even when batch_size is 0.
59 vector<LengthType> lengths_host(batch_size);
60 context_.template CopyToCPU<LengthType>(
61 batch_size, lengths_ptr, &lengths_host[0]);
62 context_.FinishDeviceComputation();
64 T* rev_data_ptr = output->template mutable_data<T>();
// Process each batch column independently.
65 for (int64_t i = 0; i < batch_size; i++) {
66 const auto& seg_length = lengths_host[i];
// Requires seg_length <= max_length, which keeps the mirrored row index
// (seg_length - 1 - j) used below under max_length.
67 CAFFE_ENFORCE_LE(seg_length, max_length);
// First pass: rows [0, seg_length). Block (j, i) of the input is written to
// block (seg_length - 1 - j, i) of the output, i.e. the order is reversed.
// NOTE(review): the declaration of `j` is not visible in this excerpt;
// presumably it starts at 0 - confirm.
69 for (; j < seg_length; j++) {
// Offset of block (row, i): (row * batch_size + i) * block_size elements.
70 const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
71 T* rev_data_block_ptr =
72 rev_data_ptr + ((seg_length - 1 - j) * batch_size + i) * block_size;
73 context_.template CopySameDevice<T>(
74 block_size, data_block_ptr, rev_data_block_ptr);
// Second pass: `j` continues from where the first loop stopped, so rows
// [seg_length, max_length) are copied to the same position in the output.
76 for (; j < max_length; j++) {
77 const T* data_block_ptr = data_ptr + (j * batch_size + i) * block_size;
78 T* rev_data_block_ptr =
79 rev_data_ptr + (j * batch_size + i) * block_size;
80 context_.template CopySameDevice<T>(
81 block_size, data_block_ptr, rev_data_block_ptr);
89 #endif // CAFFE2_OPERATORS_REVERSE_PACKED_SEGS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...