Caffe2 - C++ API
A deep learning, cross-platform ML framework
dense_vector_to_id_list_op.h
1 #ifndef CAFFE2_OPERATORS_DENSE_VECTOR_TO_ID_LIST_OP_H_
2 #define CAFFE2_OPERATORS_DENSE_VECTOR_TO_ID_LIST_OP_H_
3 
4 #include <set>
5 #include <vector>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class DenseVectorToIdListOp : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  USE_SIMPLE_CTOR_DTOR(DenseVectorToIdListOp)
16 
17  template <typename T, typename M>
18  bool DoRunWithType() {
19  auto& input = Input(0);
20  const auto* input_data = input.template data<T>();
21 
22  CAFFE_ENFORCE_EQ(input.dim(), 2, "Sample should be 2-D");
23  const auto batch_size = input.size(0);
24  const auto col_num = input.size(1);
25 
26  auto* out_lengths = Output(0, {batch_size}, at::dtype<int32_t>());
27 
28  auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();
29 
30  auto* out_values = Output(1, {batch_size * col_num}, at::dtype<M>());
31 
32  auto* out_values_data = out_values->template mutable_data<M>();
33 
34  auto v_pos = 0;
35  auto l_pos = 0;
36  for (auto i = 0; i < batch_size; i++) {
37  auto length = 0;
38  for (int j = 0; j < col_num; j++) {
39  if ((int)(input_data[i * col_num + j] + 0.5) != 0) {
40  out_values_data[v_pos++] = j;
41  length++;
42  }
43  }
44  out_lengths_data[l_pos++] = length;
45  }
46  out_values->Resize(v_pos);
47  out_lengths->Resize(l_pos);
48  return true;
49  }
50 
51  bool RunOnDevice() override {
52  if (Input(0).template IsType<float>()) {
53  return DoRunWithType<float, int>();
54  } else {
55  CAFFE_THROW(
56  "DenseVectorToIdList operator only supports 32-bit float, but",
57  " input was of type ",
58  Input(0).dtype().name());
59  }
60  }
61 };
62 
63 } // namespace caffe2
64 
65 #endif // CAFFE2_OPERATORS_DENSE_VECTOR_TO_ID_LIST_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13