Caffe2 - C++ API
A deep learning, cross platform ML framework
funhash_op.h
1 
17 #ifndef CAFFE2_OPERATORS_FUNHASH_OP_H_
18 #define CAFFE2_OPERATORS_FUNHASH_OP_H_
19 
20 #include <xxhash.h>
21 #include <array>
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/utils/math.h"
25 
// Domain-separation constants for the two hashes computed per
// (feature key, output index, alpha index) triple. The operators below store
// one of them in the fourth word of the hash input before calling XXH64:
//   SIGN_MAGIC  - the parity of the resulting hash decides whether the
//                 selected weight is negated.
//   INDEX_MAGIC - the resulting hash modulo the number of weights selects
//                 the weight index.
26 #define SIGN_MAGIC 0x9e3779b97f4a7c15
27 #define INDEX_MAGIC 0xf39cc0605cedc834
28 
// When defined, the sign-flipping hash is compiled into both the forward and
// the gradient operator (see the `#ifdef USE_SIGN` sections below).
29 #define USE_SIGN
30 
31 namespace caffe2 {
32 
33 template <typename T, class Context>
34 class FunHashOp : public Operator<Context> {
35  public:
36  USE_OPERATOR_CONTEXT_FUNCTIONS;
37  FunHashOp(const OperatorDef& operator_def, Workspace* ws)
38  : Operator<Context>(operator_def, ws),
39  num_outputs_(
40  OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
41  num_segments_(
42  OperatorBase::GetSingleArgument<TIndex>("num_segments", -1)),
43  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
44  CAFFE_ENFORCE(
45  OperatorBase::HasArgument("num_outputs"),
46  "Argument `num_outputs` is missing.");
47  // If alpha is provided, use adaptive hashing parameterized by alpha.
48  adaptive_ = (InputSize() == 5);
49  }
50 
51  bool RunOnDevice() override {
52  const auto& val = Input(0);
53  const auto& key = Input(1);
54  const auto& seg = Input(2);
55  const auto& weight = Input(3);
56 
57  TIndex num_alpha = 1;
58  if (adaptive_) {
59  const auto& alpha = Input(4);
60  num_alpha = alpha.dim(0);
61  }
62 
63  const auto* seg_data = seg.template data<int>();
64 
65  TIndex num_weight = weight.dim(0);
66  TIndex num_nz_ent = seg.dim(0);
67 
68  TIndex n_segments = num_segments_;
69  if (num_segments_ == -1) {
70  for (TIndex i = 0; i < num_nz_ent; ++i) {
71  if (seg_data[i] > n_segments) {
72  n_segments = seg_data[i];
73  }
74  }
75  ++n_segments;
76  }
77 
78  auto* output = Output(0);
79  output->Resize(n_segments, num_outputs_);
80 
81  T* output_data = output->template mutable_data<T>();
82 
83  memset(output_data, 0, sizeof(T) * n_segments * num_outputs_);
84 
85  const auto* weight_data = weight.template data<T>();
86  const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
87  const auto* val_data = val.template data<T>();
88  const auto* key_data = key.template data<TIndex>();
89 
90  for (TIndex j = 0; j < num_nz_ent; ++j) {
91  TIndex cur_seg = seg_data[j];
92  TIndex cur_key = key_data[j];
93  T cur_val = val_data[j];
94  TIndex output_stride = cur_seg * num_outputs_;
95  for (TIndex i = 0; i < num_outputs_; ++i) {
96  T sum = 0;
97  for (TIndex k = 0; k < num_alpha; ++k) {
98  uint64_t hash;
99  // The hash function takes as input four integers:
100  // 1. feature index
101  // 2. output index
102  // 3. alpha index
103  // 4. magic number: SIGN_MAGIC for sign (-1/+1)
104  // INDEX_MAGIC for weight index
105  hash_data[0] = cur_key;
106  hash_data[1] = i;
107  hash_data[2] = k;
108 
109  hash_data[3] = INDEX_MAGIC;
110  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
111  TIndex index = hash % num_weight;
112 
113  T cur_weight = weight_data[index];
114 #ifdef USE_SIGN
115  hash_data[3] = SIGN_MAGIC;
116  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
117  if (hash % 2) {
118  cur_weight = -cur_weight;
119  }
120 #endif // USE_SIGN
121 
122  if (adaptive_) {
123  sum += cur_weight * alpha_data[k];
124  } else {
125  sum += cur_weight;
126  }
127  }
128  output_data[output_stride + i] += sum * cur_val;
129  }
130  }
131 
132  return true;
133  }
134 
135  protected:
136  TIndex num_outputs_;
137  TIndex num_segments_;
138  uint64_t seed_;
139  std::array<uint64_t, 4> hash_data;
140  bool adaptive_;
141 };
142 
143 template <typename T, class Context>
144 class FunHashGradientOp : public Operator<Context> {
145  public:
146  USE_OPERATOR_CONTEXT_FUNCTIONS;
147  FunHashGradientOp(const OperatorDef& operator_def, Workspace* ws)
148  : Operator<Context>(operator_def, ws),
149  num_outputs_(
150  OperatorBase::GetSingleArgument<TIndex>("num_outputs", -1)),
151  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
152  adaptive_ = (InputSize() == 6);
153  }
154 
155  bool RunOnDevice() override {
156  const auto& grad_out = Input(0);
157  const auto& val = Input(1);
158  const auto& key = Input(2);
159  const auto& seg = Input(3);
160  const auto& weight = Input(4);
161 
162  TIndex num_alpha = 1;
163  T* grad_alpha_data = 0;
164 
165  if (adaptive_) {
166  const auto& alpha = Input(5);
167  num_alpha = alpha.dim(0);
168  auto* grad_alpha = Output(1);
169  grad_alpha->ResizeLike(alpha);
170  grad_alpha_data = grad_alpha->template mutable_data<T>();
171  memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
172  }
173 
174  const auto* seg_data = seg.template data<int>();
175 
176  TIndex num_weight = weight.dim(0);
177  TIndex num_nz_ent = seg.dim(0);
178 
179  auto* grad_weight = Output(0);
180  grad_weight->ResizeLike(weight);
181  T* grad_weight_data = grad_weight->template mutable_data<T>();
182 
183  const auto* grad_out_data = grad_out.template data<T>();
184  const auto* weight_data = weight.template data<T>();
185  const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
186  const auto* val_data = val.template data<T>();
187  const auto* key_data = key.template data<TIndex>();
188 
189  memset(grad_weight_data, 0, sizeof(T) * num_weight);
190 
191  for (TIndex j = 0; j < num_nz_ent; ++j) {
192  TIndex cur_seg = seg_data[j];
193  TIndex cur_key = key_data[j];
194  T cur_val = val_data[j];
195  TIndex grad_out_stride = cur_seg * num_outputs_;
196  for (TIndex i = 0; i < num_outputs_; ++i) {
197  T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
198  for (TIndex k = 0; k < num_alpha; ++k) {
199  uint64_t hash;
200  hash_data[0] = cur_key;
201  hash_data[1] = i;
202  hash_data[2] = k;
203 
204  hash_data[3] = INDEX_MAGIC;
205  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
206  TIndex index = hash % num_weight;
207 
208  T cur_grad_out_scale = grad_out_scale;
209 #ifdef USE_SIGN
210  hash_data[3] = SIGN_MAGIC;
211  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
212  if (hash % 2) {
213  cur_grad_out_scale = -cur_grad_out_scale;
214  }
215 #endif // USE_SIGN
216 
217  if (adaptive_) {
218  grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
219  grad_weight_data[index] += alpha_data[k] * cur_grad_out_scale;
220  } else {
221  grad_weight_data[index] += cur_grad_out_scale;
222  }
223  }
224  }
225  }
226  return true;
227  }
228 
229  protected:
230  TIndex num_outputs_;
231  uint64_t seed_;
232  std::array<uint64_t, 4> hash_data;
233  bool adaptive_;
234 };
235 
236 } // namespace caffe2
237 
238 #endif // CAFFE2_OPERATORS_FUNHASH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52