// NOTE(review): this text appears to be a source listing whose original line
// numbers were fused into the content; the numbering has gaps (3-5, 8-10,
// 12-13, ...), so some lines are absent. Names used below without a visible
// declaration (M, pos, deduped, T) and all closing braces are presumably on
// those absent lines -- confirm against the real header before editing code.
1 #ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_ 2 #define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 11 template <
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Input 0 is read as the first lengths tensor. It must be 1-D, and its element
// count becomes batch_size, which every other lengths input is checked against.
19 auto& first_lengths =
Input(0);
20 CAFFE_ENFORCE_EQ(first_lengths.dim(), 1,
"LENGTHS should be 1-D");
21 const auto batch_size = first_lengths.numel();
// Output 0 takes the same sizes as the first lengths input, with int32_t
// elements; out_lengths_data is the writable pointer into it.
23 auto* out_lengths = Output(0, first_lengths.sizes(), at::dtype<int32_t>());
25 auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();
// Validation pass. Inputs are walked two at a time: Input(i) is treated as a
// lengths tensor (must be 1-D with batch_size elements) and Input(i + 1) as the
// matching values tensor (must be 1-D).
// NOTE(review): Input(i + 1) is accessed without a visible check that
// InputSize() is even -- presumably enforced elsewhere; confirm.
32 for (
size_t i = 0; i < InputSize(); i += 2) {
33 auto& lengths =
Input(i);
34 CAFFE_ENFORCE_EQ(lengths.dim(), 1,
"LENGTHS should be 1-D");
35 CAFFE_ENFORCE_EQ(lengths.numel(), batch_size,
"LENGTHS should be equal");
36 auto& values =
Input(i + 1);
37 CAFFE_ENFORCE_EQ(values.dim(), 1,
"VALUES should be 1-D");
// Output 1 is created 1-D with M elements of type T. M is not defined in the
// visible text -- presumably the total element count of all values inputs,
// accumulated in the loop above; TODO confirm.
41 auto* out_values = Output(1, {M}, at::dtype<T>());
43 T* out_values_data = out_values->template mutable_data<T>();
// offsets has one int slot per input, all starting at 0. The visible code only
// indexes it with the even (lengths) index i, using offsets[i] as the start
// position inside the values tensor paired with that lengths tensor.
48 std::vector<int> offsets(InputSize(), 0);
// Main pass: for each sample, every (lengths, values) pair contributes
// lengths_data[sample] consecutive values starting at offsets[i]; each one is
// inserted into deduped.
// NOTE(review): `sample` and `j` are deduced as int (from the literal 0 and from
// offsets' element type), while batch_size comes from numel() -- check for
// narrowing / sign-compare issues with very large inputs.
// NOTE(review): no statement advancing offsets[i] or emptying deduped between
// samples is visible here -- presumably on the absent lines; confirm.
49 for (
auto sample = 0; sample < batch_size; sample++) {
50 for (
size_t i = 0; i < InputSize(); i += 2) {
51 auto& lengths =
Input(i);
52 const auto* lengths_data = lengths.template data<int32_t>();
54 auto& values =
Input(i + 1);
55 const T* values_data = values.template data<T>();
56 const auto length = lengths_data[sample];
58 for (
auto j = offsets[i]; j < offsets[i] + length; j++) {
59 deduped.insert(values_data[j]);
// Every element of deduped is written to the values output at position pos
// (post-incremented), and deduped.size() is stored as this sample's output
// length. The type of deduped is not visible -- presumably a set-like
// container; confirm.
63 for (
auto val : deduped) {
64 out_values_data[pos++] = val;
66 out_lengths_data[sample] = deduped.size();
// Resizes the values output to pos, the running count incremented once per
// element written above.
69 out_values->Resize(pos);
// Operator entry point override; its body is not present in this text.
73 bool RunOnDevice()
override {
80 #endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...