Caffe2 - C++ API
A deep learning, cross platform ML framework
merge_id_lists_op.h
1 #ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
2 #define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
3 
4 #include <set>
5 #include <vector>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class MergeIdListsOp : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  USE_SIMPLE_CTOR_DTOR(MergeIdListsOp);
16 
17  template <typename T>
18  bool DoRunWithType() {
19  auto& first_lengths = Input(0);
20  CAFFE_ENFORCE_EQ(first_lengths.dim(), 1, "LENGTHS should be 1-D");
21  const auto batch_size = first_lengths.numel();
22 
23  auto* out_lengths = Output(0, first_lengths.sizes(), at::dtype<int32_t>());
24 
25  auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();
26 
31  auto M = 0;
32  for (size_t i = 0; i < InputSize(); i += 2) {
33  auto& lengths = Input(i);
34  CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS should be 1-D");
35  CAFFE_ENFORCE_EQ(lengths.numel(), batch_size, "LENGTHS should be equal");
36  auto& values = Input(i + 1);
37  CAFFE_ENFORCE_EQ(values.dim(), 1, "VALUES should be 1-D");
38  M += values.numel();
39  }
40 
41  auto* out_values = Output(1, {M}, at::dtype<T>());
42 
43  T* out_values_data = out_values->template mutable_data<T>();
44  auto pos = 0;
45 
46  // TODO(badri): Use unordered_set if performance is an issue
47  std::set<T> deduped;
48  std::vector<int> offsets(InputSize(), 0);
49  for (auto sample = 0; sample < batch_size; sample++) {
50  for (size_t i = 0; i < InputSize(); i += 2) {
51  auto& lengths = Input(i);
52  const auto* lengths_data = lengths.template data<int32_t>();
53 
54  auto& values = Input(i + 1);
55  const T* values_data = values.template data<T>();
56  const auto length = lengths_data[sample];
57 
58  for (auto j = offsets[i]; j < offsets[i] + length; j++) {
59  deduped.insert(values_data[j]);
60  }
61  offsets[i] += length;
62  }
63  for (auto val : deduped) {
64  out_values_data[pos++] = val;
65  }
66  out_lengths_data[sample] = deduped.size();
67  deduped.clear();
68  }
69  out_values->Resize(pos);
70  return true;
71  }
72 
73  bool RunOnDevice() override {
75  }
76 };
77 
78 } // namespace caffe2
79 
80 #endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
Definition: any.cpp:108
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13