Caffe2 - C++ API
A deep learning, cross-platform ML framework
merge_id_lists_op.h
1 
17 #ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
18 #define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
19 
20 #include <set>
21 #include <vector>
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/operator.h"
24 
25 namespace caffe2 {
26 
27 template <class Context>
28 class MergeIdListsOp : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  USE_SIMPLE_CTOR_DTOR(MergeIdListsOp);
32 
33  template <typename T>
34  bool DoRunWithType() {
35  auto& first_lengths = Input(0);
36  CAFFE_ENFORCE_EQ(first_lengths.ndim(), 1, "LENGTHS should be 1-D");
37  const auto batch_size = first_lengths.size();
38 
39  auto* out_lengths = Output(0);
40  out_lengths->ResizeLike(first_lengths);
41 
42  auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();
43 
48  auto M = 0;
49  for (size_t i = 0; i < InputSize(); i += 2) {
50  auto& lengths = Input(i);
51  CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS should be 1-D");
52  CAFFE_ENFORCE_EQ(lengths.size(), batch_size, "LENGTHS should be equal");
53  auto& values = Input(i + 1);
54  CAFFE_ENFORCE_EQ(values.ndim(), 1, "VALUES should be 1-D");
55  M += values.size();
56  }
57 
58  auto* out_values = Output(1);
59  out_values->Resize(M);
60 
61  T* out_values_data = out_values->template mutable_data<T>();
62  auto pos = 0;
63 
64  // TODO(badri): Use unordered_set if performance is an issue
65  std::set<T> deduped;
66  std::vector<int> offsets(InputSize(), 0);
67  for (auto sample = 0; sample < batch_size; sample++) {
68  for (size_t i = 0; i < InputSize(); i += 2) {
69  auto& lengths = Input(i);
70  const auto* lengths_data = lengths.template data<int32_t>();
71 
72  auto& values = Input(i + 1);
73  const T* values_data = values.template data<T>();
74  const auto length = lengths_data[sample];
75 
76  for (auto j = offsets[i]; j < offsets[i] + length; j++) {
77  deduped.insert(values_data[j]);
78  }
79  offsets[i] += length;
80  }
81  for (auto val : deduped) {
82  out_values_data[pos++] = val;
83  }
84  out_lengths_data[sample] = deduped.size();
85  deduped.clear();
86  }
87  out_values->Resize(pos);
88  return true;
89  }
90 
91  bool RunOnDevice() override {
92  return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
93  }
94 };
95 
96 } // namespace caffe2
97 
98 #endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
Copyright (c) 2016-present, Facebook, Inc.