Caffe2 - C++ API
A deep learning, cross platform ML framework
sparse_funhash_op.h
#ifndef CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
#define CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_

#include <xxhash.h>
#include <array>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

#define HASH_MAGIC 0x9e3779b97f4a7c15

#define USE_SIGN

namespace caffe2 {

template <typename T, class Context>
class SparseFunHashOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseFunHashOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        num_outputs_(
            OperatorBase::GetSingleArgument<int64_t>("num_outputs", -1)),
        num_segments_(
            OperatorBase::GetSingleArgument<int64_t>("num_segments", -1)),
        seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
    CAFFE_ENFORCE(
        OperatorBase::HasArgument("num_outputs"),
        "Argument `num_outputs` is missing.");
    // If alpha is provided, use adaptive hashing parameterized by alpha.
    adaptive_ = (InputSize() == 5);
  }

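  // Inputs: (0) values, (1) keys, and (2) segment ids of the sparse features,
  // (3) the shared weight vector, and optionally (4) alpha for adaptive
  // hashing. Output: a dense (num_segments x num_outputs) matrix.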
  bool RunOnDevice() override {
    const auto& val = Input(0);
    const auto& key = Input(1);
    const auto& seg = Input(2);
    const auto& weight = Input(3);

    int64_t num_alpha = 1;
    if (adaptive_) {
      const auto& alpha = Input(4);
      num_alpha = alpha.size(0);
    }

    const auto* seg_data = seg.template data<int>();

    int64_t num_weight = weight.size(0);
    int64_t num_nz_ent = seg.size(0);

    int64_t n_segments = num_segments_;
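    // When num_segments is not given (-1), infer it as max(seg) + 1;
    // e.g. seg = {0, 0, 3, 1} yields n_segments = 4.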
    if (num_segments_ == -1) {
      for (int64_t i = 0; i < num_nz_ent; ++i) {
        if (seg_data[i] > n_segments) {
          n_segments = seg_data[i];
        }
      }
      ++n_segments;
    }

    auto* output = Output(0, {n_segments, num_outputs_}, at::dtype<T>());

    T* output_data = output->template mutable_data<T>();

    memset(output_data, 0, sizeof(T) * n_segments * num_outputs_);

    const auto* weight_data = weight.template data<T>();
    const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : nullptr;
    const auto* val_data = val.template data<T>();
    const auto* key_data = key.template data<int64_t>();

    for (int64_t j = 0; j < num_nz_ent; ++j) {
      int64_t cur_seg = seg_data[j];
      int64_t cur_key = key_data[j];
      T cur_val = val_data[j];
      int64_t output_stride = cur_seg * num_outputs_;
      for (int64_t i = 0; i < num_outputs_; ++i) {
        T sum = 0;
        for (int64_t k = 0; k < num_alpha; ++k) {
          // The hash input consists of four values:
          // 1. feature index (key)
          // 2. output index
          // 3. alpha index
          // 4. a magic number to improve hashing
          hash_data[0] = cur_key;
          hash_data[1] = i;
          hash_data[2] = k;
          hash_data[3] = HASH_MAGIC;

          uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);

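          // The hash is split below: with USE_SIGN, the low bit picks the
          // sign (an odd hash negates the weight) and the remaining bits
          // pick the index; otherwise the full hash picks the index.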
#ifdef USE_SIGN
          // Use the least significant bit for sign, the rest for weights.
          int64_t index = (hash >> 1) % num_weight;
          T cur_weight = weight_data[index];
          if (hash & 1) {
            cur_weight = -cur_weight;
          }
#else
          int64_t index = hash % num_weight;
          T cur_weight = weight_data[index];
#endif

          if (adaptive_) {
            sum += cur_weight * alpha_data[k];
          } else {
            sum += cur_weight;
          }
        }
        output_data[output_stride + i] += sum * cur_val;
      }
    }

    return true;
  }

 protected:
  int64_t num_outputs_;
  int64_t num_segments_;
  uint64_t seed_;
  std::array<uint64_t, 4> hash_data;
  bool adaptive_;
};

template <typename T, class Context>
class SparseFunHashGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseFunHashGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        num_outputs_(
            OperatorBase::GetSingleArgument<int64_t>("num_outputs", -1)),
        seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
    // If alpha is provided (sixth input), use adaptive hashing.
    adaptive_ = (InputSize() == 6);
  }

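  // Inputs: (0) the gradient w.r.t. the output, (1) values, (2) keys, and
  // (3) segment ids of the sparse features, (4) the weight vector, and
  // optionally (5) alpha. Outputs: (0) the values and (1) the indices of a
  // sparse gradient w.r.t. the weights, plus (2) the gradient w.r.t. alpha
  // when adaptive hashing is used.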
  bool RunOnDevice() override {
    const auto& grad_out = Input(0);
    const auto& val = Input(1);
    const auto& key = Input(2);
    const auto& seg = Input(3);
    const auto& weight = Input(4);

    int64_t num_alpha = 1;
    T* grad_alpha_data = nullptr;

    if (adaptive_) {
      const auto& alpha = Input(5);
      num_alpha = alpha.size(0);

      auto* grad_alpha = Output(2, alpha.sizes(), at::dtype<T>());
      grad_alpha_data = grad_alpha->template mutable_data<T>();
      memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
    }

    const auto* seg_data = seg.template data<int>();

    int64_t num_weight = weight.size(0);
    int64_t num_nz_ent = seg.size(0);

    int64_t grad_weight_size = num_nz_ent * num_outputs_ * num_alpha;

    auto* grad_weight_val = Output(0, {grad_weight_size}, at::dtype<T>());
    T* grad_weight_val_data = grad_weight_val->template mutable_data<T>();

    auto* grad_weight_ind = Output(1, {grad_weight_size}, at::dtype<int64_t>());
    auto* grad_weight_ind_data =
        grad_weight_ind->template mutable_data<int64_t>();

    const auto* grad_out_data = grad_out.template data<T>();
    const auto* weight_data = weight.template data<T>();
    const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : nullptr;
    const auto* val_data = val.template data<T>();
    const auto* key_data = key.template data<int64_t>();

    int64_t w_ind = 0;
    for (int64_t j = 0; j < num_nz_ent; ++j) {
      int64_t cur_seg = seg_data[j];
      int64_t cur_key = key_data[j];
      T cur_val = val_data[j];
      int64_t grad_out_stride = cur_seg * num_outputs_;
      for (int64_t i = 0; i < num_outputs_; ++i) {
        T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
        for (int64_t k = 0; k < num_alpha; ++k) {
          // Recompute the same hash inputs as in the forward pass.
          hash_data[0] = cur_key;
          hash_data[1] = i;
          hash_data[2] = k;
          hash_data[3] = HASH_MAGIC;

          uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);

          T cur_grad_out_scale = grad_out_scale;
#ifdef USE_SIGN
          int64_t index = (hash >> 1) % num_weight;
          if (hash & 1) {
            cur_grad_out_scale = -cur_grad_out_scale;
          }
#else
          int64_t index = hash % num_weight;
#endif

          if (adaptive_) {
            grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
            grad_weight_val_data[w_ind] = alpha_data[k] * cur_grad_out_scale;
          } else {
            grad_weight_val_data[w_ind] = cur_grad_out_scale;
          }
          grad_weight_ind_data[w_ind] = index;
          ++w_ind;
        }
      }
    }
    return true;
  }

 protected:
  int64_t num_outputs_;
  uint64_t seed_;
  std::array<uint64_t, 4> hash_data;
  bool adaptive_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
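
For reference, the index/sign derivation shared by both operators above can be isolated into a small standalone helper. The sketch below is illustrative only and is not part of the header: the helper name funhash_index and the values used in main are assumptions, and the length argument passed to XXH64 deliberately mirrors the header's hash_data.size() call so that it reproduces the same hash values.

#include <xxhash.h>
#include <array>
#include <cstdint>
#include <cstdio>

// Sketch of the USE_SIGN branch: hash (key, output index, alpha index,
// HASH_MAGIC), then use the low bit of the hash as a sign and the remaining
// bits as an index into the shared weight vector.
inline int64_t funhash_index(
    int64_t key,
    int64_t output_idx,
    int64_t alpha_idx,
    int64_t num_weight,
    uint64_t seed,
    bool* negate) {
  std::array<uint64_t, 4> hash_data = {
      static_cast<uint64_t>(key),
      static_cast<uint64_t>(output_idx),
      static_cast<uint64_t>(alpha_idx),
      0x9e3779b97f4a7c15ULL};
  uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed);
  *negate = (hash & 1) != 0;
  return static_cast<int64_t>((hash >> 1) % num_weight);
}

int main() {
  bool negate = false;
  // Hypothetical feature key 42, output 0, alpha 0, 1024 hashed weights.
  int64_t index =
      funhash_index(42, 0, 0, /*num_weight=*/1024, /*seed=*/0, &negate);
  std::printf("index=%lld negate=%d\n", static_cast<long long>(index), negate);
  return 0;
}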