Caffe2 - C++ API
A deep learning, cross platform ML framework
sparse_to_dense_op.h
1 #ifndef CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
2 #define CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <class Context>
11 class SparseToDenseOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  USE_DISPATCH_HELPER;
15 
16  template <class... Args>
17  explicit SparseToDenseOp(Args&&... args)
18  : Operator<Context>(std::forward<Args>(args)...),
19  output_first_dim_(
20  this->template GetSingleArgument<int>("output_first_dim", 0)) {}
21 
22  bool RunOnDevice() override {
24  this, Input(INDICES));
25  }
26 
27  private:
28  template <typename TInd>
29  int GetOutputFirstDim(
30  const TInd* sparse_indices_vec,
31  const int32_t sparse_indices_len) {
32  if (output_first_dim_ > 0) {
33  CAFFE_ENFORCE_EQ(InputSize(), 2);
34  return output_first_dim_;
35  }
36  if (InputSize() == 3) {
37  auto& data_to_infer_dim = Input(DATA_TO_INFER_DIM);
38  CAFFE_ENFORCE_GE(data_to_infer_dim.dim(), 1);
39  return data_to_infer_dim.dim32(0);
40  }
41  if (sparse_indices_len <= 0) {
42  return 0;
43  }
44 
45  // Awkward way to get the max element to make it work with both CUDA
46  // and CPU.
47  ReinitializeTensor(&max_element_, {1}, at::dtype<TInd>().device(Context::GetDeviceType()));
48  TInd* max_element_ptr = max_element_.template mutable_data<TInd>();
49  math::ReduceMax<TInd>(sparse_indices_len, sparse_indices_vec, max_element_ptr,
50  &scratch_, &context_);
51  max_element_host_.CopyFrom(max_element_);
52  return 1 + max_element_host_.template data<TInd>()[0];
53  }
54 
55  template <typename TInd>
56  bool DoRunWithType() {
57  return DispatchHelper<
59  float,
60  int32_t,
61  int64_t,
63  TInd>::call(this, Input(VALUES));
64  }
65 
66  template <typename TInd, typename TData>
67  bool DoRunWithType2() {
68  auto& sparse_indices = Input(INDICES);
69  CAFFE_ENFORCE_EQ(sparse_indices.dim(), 1);
70  auto& sparse_values = Input(VALUES);
71  CAFFE_ENFORCE_GE(sparse_values.dim(), 1);
72  CAFFE_ENFORCE_EQ(sparse_indices.numel(), sparse_values.size(0));
73 
74  const TInd* sparse_indices_vec = sparse_indices.template data<TInd>();
75  const int32_t sparse_indices_len = sparse_indices.dim32(0);
76  const int output_first_dim =
77  GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
78 
79  auto shape = sparse_values.sizes().vec();
80  shape[0] = output_first_dim;
81 
82  auto* output = Output(0, shape, at::dtype<TData>());
83 
84  TData* output_data = output->template mutable_data<TData>();
85  if (!output_first_dim) {
86  return true;
87  }
88  memset(output_data, 0, output->nbytes());
89  const auto block_nitems = sparse_values.size_from_dim(1);
90  const TData* sparse_values_vec = sparse_values.template data<TData>();
91 
92  for (int32_t i = 0; i < sparse_indices_len; i++) {
93  const TInd idx = sparse_indices_vec[i];
94  CAFFE_ENFORCE_GE(idx, 0);
95  CAFFE_ENFORCE_LT(idx, output_first_dim);
96  math::Add(
97  block_nitems,
98  output_data + idx * block_nitems,
99  sparse_values_vec + i * block_nitems,
100  output_data + idx * block_nitems,
101  &context_);
102  }
103  return true;
104  }
105 
106  template <typename TInd>
107  bool DoRunWithOtherType2() {
108  CAFFE_THROW(
109  "SparseToDense is not implemented on tensor of type ",
110  Input(VALUES).dtype().name(),
111  "consider adding it as a type in the DispatchHelper list or "
112  "implementing a generic version (which won't work for "
113  "duplicated indices though)");
114  }
115 
116  private:
117  int output_first_dim_;
118  Tensor scratch_{Context::GetDeviceType()};
119  Tensor max_element_host_{CPU};
120  Tensor max_element_;
121 
122  INPUT_TAGS(INDICES, VALUES, DATA_TO_INFER_DIM);
123 };
124 
125 } // namespace caffe2
126 
127 #endif // CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
Definition: tensor.cc:127
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13