Caffe2 - C++ API
A deep learning, cross-platform ML framework
sparse_to_dense_op.h
1 
17 #ifndef CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
18 #define CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
19 
#include <cstring>
#include <limits>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class SparseToDenseOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  USE_DISPATCH_HELPER;
31 
32  SparseToDenseOp(const OperatorDef& operator_def, Workspace* ws)
33  : Operator<Context>(operator_def, ws),
34  output_first_dim_(
35  OperatorBase::GetSingleArgument<int>("output_first_dim", 0)) {}
36 
37  bool RunOnDevice() override {
39  this, Input(INDICES));
40  }
41 
42  private:
43  template <typename TInd>
44  int GetOutputFirstDim(
45  const TInd* sparse_indices_vec,
46  const int32_t sparse_indices_len) {
47  if (output_first_dim_ > 0) {
48  CAFFE_ENFORCE_EQ(InputSize(), 2);
49  return output_first_dim_;
50  }
51  if (InputSize() == 3) {
52  auto& data_to_infer_dim = Input(DATA_TO_INFER_DIM);
53  CAFFE_ENFORCE_GE(data_to_infer_dim.ndim(), 1);
54  return data_to_infer_dim.dim32(0);
55  }
56  if (sparse_indices_len <= 0) {
57  return 0;
58  }
59 
60  // Awkward way to get the max element to make it work with both CUDA
61  // and CPU.
62  max_element_.Resize(1);
63  TInd* max_element_ptr = max_element_.template mutable_data<TInd>();
64  math::ReduceMax<TInd>(sparse_indices_len, sparse_indices_vec, max_element_ptr,
65  &scratch_, &context_);
66  max_element_host_.CopyFrom(max_element_);
67  return 1 + max_element_host_.template data<TInd>()[0];
68  }
69 
70  template <typename TInd>
71  bool DoRunWithType() {
72  return DispatchHelper<
74  float,
75  int32_t,
76  int64_t,
78  TInd>::call(this, Input(VALUES));
79  }
80 
81  template <typename TInd, typename TData>
82  bool DoRunWithType2() {
83  auto& sparse_indices = Input(INDICES);
84  CAFFE_ENFORCE_EQ(sparse_indices.ndim(), 1);
85  auto& sparse_values = Input(VALUES);
86  CAFFE_ENFORCE_GE(sparse_values.ndim(), 1);
87  CAFFE_ENFORCE_EQ(sparse_indices.size(), sparse_values.dim(0));
88 
89  const TInd* sparse_indices_vec = sparse_indices.template data<TInd>();
90  const int32_t sparse_indices_len = sparse_indices.dim32(0);
91  const int output_first_dim =
92  GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
93 
94  auto shape = sparse_values.dims();
95  shape[0] = output_first_dim;
96  auto* output = Output(0);
97  output->Resize(shape);
98 
99  TData* output_data = output->template mutable_data<TData>();
100  memset(output_data, 0, output->nbytes());
101  const auto block_nitems = sparse_values.size_from_dim(1);
102  const TData* sparse_values_vec = sparse_values.template data<TData>();
103 
104  for (int32_t i = 0; i < sparse_indices_len; i++) {
105  const TInd idx = sparse_indices_vec[i];
106  CAFFE_ENFORCE_GE(idx, 0);
107  CAFFE_ENFORCE_LT(idx, output_first_dim);
108  math::Add(
109  block_nitems,
110  output_data + idx * block_nitems,
111  sparse_values_vec + i * block_nitems,
112  output_data + idx * block_nitems,
113  &context_);
114  }
115  return true;
116  }
117 
118  template <typename TInd>
119  bool DoRunWithOtherType2() {
120  CAFFE_THROW(
121  "SparseToDense is not implemented on tensor of type ",
122  Input(VALUES).meta().name(),
123  "Consider adding it a type in the list DispatchHelper or implementing "
124  "a generic version (which won't work for duplicated indices though)");
125  }
126 
127  private:
128  int output_first_dim_;
129  Tensor<Context> scratch_;
130  Tensor<CPUContext> max_element_host_;
131  Tensor<Context> max_element_;
132 
133  INPUT_TAGS(INDICES, VALUES, DATA_TO_INFER_DIM);
134 };
135 
136 } // namespace caffe2
137 
138 #endif // CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory along with its shape information...
Definition: tensor.h:109
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation...
Definition: tensor.h:182
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.