Caffe2 - C++ API
A deep learning, cross-platform ML framework
sparse_funhash_op.h
// Copyright (c) 2016-present, Facebook, Inc.
#ifndef CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
#define CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_

#include <xxhash.h>
#include <array>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

#define HASH_MAGIC 0x9e3779b97f4a7c15

#define USE_SIGN

namespace caffe2 {

template <typename T, class Context>
class SparseFunHashOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseFunHashOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        num_outputs_(
            OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
        num_segments_(
            OperatorBase::GetSingleArgument<TIndex>("num_segments", -1)),
        seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
    CAFFE_ENFORCE(
        OperatorBase::HasArgument("num_outputs"),
        "Argument `num_outputs` is missing.");
    // If alpha is provided, use adaptive hashing parameterized by alpha.
    adaptive_ = (InputSize() == 5);
  }

  bool RunOnDevice() override {
    const auto& val = Input(0);
    const auto& key = Input(1);
    const auto& seg = Input(2);
    const auto& weight = Input(3);

    TIndex num_alpha = 1;
    if (adaptive_) {
      const auto& alpha = Input(4);
      num_alpha = alpha.dim(0);
    }

    const auto* seg_data = seg.template data<int>();

    TIndex num_weight = weight.dim(0);
    TIndex num_nz_ent = seg.dim(0);

    TIndex n_segments = num_segments_;
    if (num_segments_ == -1) {
      for (TIndex i = 0; i < num_nz_ent; ++i) {
        if (seg_data[i] > n_segments) {
          n_segments = seg_data[i];
        }
      }
      ++n_segments;
    }

    auto* output = Output(0);
    output->Resize(n_segments, num_outputs_);

    T* output_data = output->template mutable_data<T>();

    memset(output_data, 0, sizeof(T) * n_segments * num_outputs_);

    const auto* weight_data = weight.template data<T>();
    const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
    const auto* val_data = val.template data<T>();
    const auto* key_data = key.template data<TIndex>();

    for (TIndex j = 0; j < num_nz_ent; ++j) {
      TIndex cur_seg = seg_data[j];
      TIndex cur_key = key_data[j];
      T cur_val = val_data[j];
      TIndex output_stride = cur_seg * num_outputs_;
      for (TIndex i = 0; i < num_outputs_; ++i) {
        T sum = 0;
        for (TIndex k = 0; k < num_alpha; ++k) {
          // The hash function takes as input four integers:
          // 1. feature index
          // 2. output index
          // 3. alpha index
          // 4. magic number to improve hashing
          hash_data[0] = cur_key;
          hash_data[1] = i;
          hash_data[2] = k;
          hash_data[3] = HASH_MAGIC;

          uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);

#ifdef USE_SIGN
          // Use the least significant bit for sign, the rest for weights.
          TIndex index = (hash >> 1) % num_weight;
          T cur_weight = weight_data[index];
          if (hash & 1) {
            cur_weight = -cur_weight;
          }
#else
          TIndex index = hash % num_weight;
          T cur_weight = weight_data[index];
#endif

          if (adaptive_) {
            sum += cur_weight * alpha_data[k];
          } else {
            sum += cur_weight;
          }
        }
        output_data[output_stride + i] += sum * cur_val;
      }
    }

    return true;
  }

 protected:
  TIndex num_outputs_;
  TIndex num_segments_;
  uint64_t seed_;
  std::array<uint64_t, 4> hash_data;
  bool adaptive_;
};

template <typename T, class Context>
class SparseFunHashGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseFunHashGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        num_outputs_(
            OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
        seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
    adaptive_ = (InputSize() == 6);
  }

  bool RunOnDevice() override {
    const auto& grad_out = Input(0);
    const auto& val = Input(1);
    const auto& key = Input(2);
    const auto& seg = Input(3);
    const auto& weight = Input(4);

    TIndex num_alpha = 1;
    T* grad_alpha_data = 0;

    if (adaptive_) {
      const auto& alpha = Input(5);
      num_alpha = alpha.dim(0);
      auto* grad_alpha = Output(2);
      grad_alpha->ResizeLike(alpha);
      grad_alpha_data = grad_alpha->template mutable_data<T>();
      memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
    }

    const auto* seg_data = seg.template data<int>();

    TIndex num_weight = weight.dim(0);
    TIndex num_nz_ent = seg.dim(0);

    TIndex grad_weight_size = num_nz_ent * num_outputs_ * num_alpha;
    auto* grad_weight_val = Output(0);
    grad_weight_val->Resize(grad_weight_size);
    T* grad_weight_val_data = grad_weight_val->template mutable_data<T>();

    auto* grad_weight_ind = Output(1);
    grad_weight_ind->Resize(grad_weight_size);
    auto* grad_weight_ind_data =
        grad_weight_ind->template mutable_data<TIndex>();

    const auto* grad_out_data = grad_out.template data<T>();
    const auto* weight_data = weight.template data<T>();
    const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
    const auto* val_data = val.template data<T>();
    const auto* key_data = key.template data<TIndex>();

    TIndex w_ind = 0;
    for (TIndex j = 0; j < num_nz_ent; ++j) {
      TIndex cur_seg = seg_data[j];
      TIndex cur_key = key_data[j];
      T cur_val = val_data[j];
      TIndex grad_out_stride = cur_seg * num_outputs_;
      for (TIndex i = 0; i < num_outputs_; ++i) {
        T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
        for (TIndex k = 0; k < num_alpha; ++k) {
          hash_data[0] = cur_key;
          hash_data[1] = i;
          hash_data[2] = k;
          hash_data[3] = HASH_MAGIC;

          uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);

          T cur_grad_out_scale = grad_out_scale;
#ifdef USE_SIGN
          TIndex index = (hash >> 1) % num_weight;
          if (hash & 1) {
            cur_grad_out_scale = -cur_grad_out_scale;
          }
#else
          TIndex index = hash % num_weight;
#endif

          if (adaptive_) {
            grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
            grad_weight_val_data[w_ind] = alpha_data[k] * cur_grad_out_scale;
          } else {
            grad_weight_val_data[w_ind] = cur_grad_out_scale;
          }
          grad_weight_ind_data[w_ind] = index;
          ++w_ind;
        }
      }
    }
    return true;
  }

 protected:
  TIndex num_outputs_;
  uint64_t seed_;
  std::array<uint64_t, 4> hash_data;
  bool adaptive_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
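For readers skimming the listing above: the (val, key, seg) inputs are a COO-style encoding of a sparse batch, where entry j carries value val[j], hashed feature id key[j], and the output row (segment) seg[j] it contributes to. The core of both operators is the hashing trick that replaces a dense weight matrix with a single weight buffer: the tuple (feature index, output index, alpha index, HASH_MAGIC) is hashed with XXH64, the least significant bit of the hash supplies a sign, and the remaining bits pick a slot modulo the buffer size. The standalone sketch below restates just that step outside of Caffe2. It is illustrative only: the names hashed_weight and HashedWeight are not part of the header, it assumes xxhash is available (as the header itself does), and it hashes the full 32-byte buffer rather than reproducing the operator byte for byte.

// Minimal standalone sketch of the hashing trick (illustrative only; not the
// Caffe2 operator). Assumes xxhash is available, as in the header above.
#include <xxhash.h>
#include <array>
#include <cstdint>
#include <cstdio>

struct HashedWeight {
  int64_t index;  // position in the compressed weight buffer
  int sign;       // +1 or -1, taken from the least significant hash bit
};

// Map a (feature, output, alpha) triple to a signed slot in a weight buffer
// of size num_weight, mirroring the USE_SIGN branch of SparseFunHashOp.
HashedWeight hashed_weight(
    int64_t feature, int64_t output, int64_t alpha,
    int64_t num_weight, uint64_t seed) {
  std::array<uint64_t, 4> buf = {
      static_cast<uint64_t>(feature),
      static_cast<uint64_t>(output),
      static_cast<uint64_t>(alpha),
      0x9e3779b97f4a7c15ull};  // HASH_MAGIC
  // Hash the full 32-byte buffer with the operator's seed.
  const uint64_t h = XXH64(buf.data(), sizeof(buf), seed);
  return {static_cast<int64_t>((h >> 1) % num_weight), (h & 1) ? -1 : 1};
}

int main() {
  // A 16-slot weight buffer stands in for the "weight" input blob.
  const int64_t num_weight = 16;
  HashedWeight w = hashed_weight(/*feature=*/42, /*output=*/3, /*alpha=*/0,
                                 num_weight, /*seed=*/0);
  std::printf("index=%lld sign=%d\n", static_cast<long long>(w.index), w.sign);
  return 0;
}

Because many (feature, output, alpha) triples collide in the same slot, the sign bit keeps colliding contributions from systematically biasing the accumulated sum, which is the usual feature-hashing argument for spending one hash bit on sign.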
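On the gradient side, SparseFunHashGradientOp never materializes a dense weight gradient: Output(0) (grad_weight_val) and Output(1) (grad_weight_ind) hold one signed contribution and one buffer index per (nonzero entry, output, alpha) triple, and Output(2) is the dense alpha gradient in the adaptive case. A consumer that wants a dense weight gradient would scatter-add the pairs; the helper below is a hypothetical sketch of that step, not part of the Caffe2 API.

// Hypothetical consumer of the SparseFunHashGradient outputs: scatter-add the
// sparse (value, index) pairs into a dense gradient over the weight buffer.
#include <cstddef>
#include <cstdint>
#include <vector>

// grad_weight_val / grad_weight_ind correspond to the operator's Output(0) /
// Output(1); num_weight is the length of the hashed weight buffer
// (weight.dim(0) in the operator).
std::vector<float> densify_weight_gradient(
    const std::vector<float>& grad_weight_val,
    const std::vector<int64_t>& grad_weight_ind,
    int64_t num_weight) {
  std::vector<float> dense(num_weight, 0.0f);
  for (std::size_t p = 0; p < grad_weight_val.size(); ++p) {
    // Several triples typically hash to the same slot; contributions add up.
    dense[grad_weight_ind[p]] += grad_weight_val[p];
  }
  return dense;
}

Keeping the gradient in (value, index) form avoids touching the full weight buffer when only a few slots are hit by a sparse batch, which is presumably why the operator emits index/value pairs rather than a dense tensor.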