Caffe2 - C++ API
A deep learning, cross-platform ML framework
funhash_op.h
1 
17 #ifndef CAFFE2_OPERATORS_FUNHASH_OP_H_
18 #define CAFFE2_OPERATORS_FUNHASH_OP_H_
19 
20 #include <xxhash.h>
21 #include <array>
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/utils/math.h"
25 
// Salt written into the fourth (last) word of the hash input buffer when
// computing the sign hash; only referenced inside `#ifdef USE_SIGN` sections.
26 #define SIGN_MAGIC 0x9e3779b97f4a7c15
// Salt written into the fourth (last) word of the hash input buffer when
// computing the hash that selects an index into the weight vector.
27 #define INDEX_MAGIC 0xf39cc0605cedc834
28 
// When defined, FunHashOp and FunHashGradientOp negate a hashed weight (and
// its gradient contribution) whenever the sign hash is odd.
29 #define USE_SIGN
30 
31 namespace caffe2 {
32 
33 template <typename T, class Context>
34 class FunHashOp : public Operator<Context> {
35  public:
36  USE_OPERATOR_CONTEXT_FUNCTIONS;
37  FunHashOp(const OperatorDef& operator_def, Workspace* ws)
38  : Operator<Context>(operator_def, ws),
39  num_outputs_(
40  OperatorBase::GetSingleArgument<int64_t>("num_outputs", -1)),
41  num_segments_(
42  OperatorBase::GetSingleArgument<int64_t>("num_segments", -1)),
43  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
44  CAFFE_ENFORCE(
45  OperatorBase::HasArgument("num_outputs"),
46  "Argument `num_outputs` is missing.");
47  // If alpha is provided, use adaptive hashing parameterized by alpha.
48  adaptive_ = (InputSize() == 5);
49  }
50 
51  bool RunOnDevice() override {
52  const auto& val = Input(0);
53  const auto& key = Input(1);
54  const auto& seg = Input(2);
55  const auto& weight = Input(3);
56 
57  int64_t num_alpha = 1;
58  if (adaptive_) {
59  const auto& alpha = Input(4);
60  num_alpha = alpha.size(0);
61  }
62 
63  const auto* seg_data = seg.template data<int>();
64 
65  int64_t num_weight = weight.size(0);
66  int64_t num_nz_ent = seg.size(0);
67 
68  int64_t n_segments = num_segments_;
69  if (num_segments_ == -1) {
70  for (int64_t i = 0; i < num_nz_ent; ++i) {
71  if (seg_data[i] > n_segments) {
72  n_segments = seg_data[i];
73  }
74  }
75  ++n_segments;
76  }
77 
78  auto* output = Output(0, {n_segments, num_outputs_}, at::dtype<T>());
79 
80  T* output_data = output->template mutable_data<T>();
81 
82  memset(output_data, 0, sizeof(T) * n_segments * num_outputs_);
83 
84  const auto* weight_data = weight.template data<T>();
85  const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
86  const auto* val_data = val.template data<T>();
87  const auto* key_data = key.template data<int64_t>();
88 
89  for (int64_t j = 0; j < num_nz_ent; ++j) {
90  int64_t cur_seg = seg_data[j];
91  int64_t cur_key = key_data[j];
92  T cur_val = val_data[j];
93  int64_t output_stride = cur_seg * num_outputs_;
94  for (int64_t i = 0; i < num_outputs_; ++i) {
95  T sum = 0;
96  for (int64_t k = 0; k < num_alpha; ++k) {
97  uint64_t hash;
98  // The hash function takes as input four integers:
99  // 1. feature index
100  // 2. output index
101  // 3. alpha index
102  // 4. magic number: SIGN_MAGIC for sign (-1/+1)
103  // INDEX_MAGIC for weight index
104  hash_data[0] = cur_key;
105  hash_data[1] = i;
106  hash_data[2] = k;
107 
108  hash_data[3] = INDEX_MAGIC;
109  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
110  int64_t index = hash % num_weight;
111 
112  T cur_weight = weight_data[index];
113 #ifdef USE_SIGN
114  hash_data[3] = SIGN_MAGIC;
115  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
116  if (hash % 2) {
117  cur_weight = -cur_weight;
118  }
119 #endif // USE_SIGN
120 
121  if (adaptive_) {
122  sum += cur_weight * alpha_data[k];
123  } else {
124  sum += cur_weight;
125  }
126  }
127  output_data[output_stride + i] += sum * cur_val;
128  }
129  }
130 
131  return true;
132  }
133 
134  protected:
135  int64_t num_outputs_;
136  int64_t num_segments_;
137  uint64_t seed_;
138  std::array<uint64_t, 4> hash_data;
139  bool adaptive_;
140 };
141 
142 template <typename T, class Context>
143 class FunHashGradientOp : public Operator<Context> {
144  public:
145  USE_OPERATOR_CONTEXT_FUNCTIONS;
146  FunHashGradientOp(const OperatorDef& operator_def, Workspace* ws)
147  : Operator<Context>(operator_def, ws),
148  num_outputs_(
149  OperatorBase::GetSingleArgument<int64_t>("num_outputs", -1)),
150  seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
151  adaptive_ = (InputSize() == 6);
152  }
153 
154  bool RunOnDevice() override {
155  const auto& grad_out = Input(0);
156  const auto& val = Input(1);
157  const auto& key = Input(2);
158  const auto& seg = Input(3);
159  const auto& weight = Input(4);
160 
161  int64_t num_alpha = 1;
162  T* grad_alpha_data = 0;
163 
164  if (adaptive_) {
165  const auto& alpha = Input(5);
166  num_alpha = alpha.size(0);
167 
168  auto* grad_alpha = Output(1, alpha.sizes(), at::dtype<T>());
169  grad_alpha_data = grad_alpha->template mutable_data<T>();
170  memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
171  }
172 
173  const auto* seg_data = seg.template data<int>();
174 
175  int64_t num_weight = weight.size(0);
176  int64_t num_nz_ent = seg.size(0);
177 
178  auto* grad_weight = Output(0, weight.sizes(), at::dtype<T>());
179  T* grad_weight_data = grad_weight->template mutable_data<T>();
180 
181  const auto* grad_out_data = grad_out.template data<T>();
182  const auto* weight_data = weight.template data<T>();
183  const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
184  const auto* val_data = val.template data<T>();
185  const auto* key_data = key.template data<int64_t>();
186 
187  memset(grad_weight_data, 0, sizeof(T) * num_weight);
188 
189  for (int64_t j = 0; j < num_nz_ent; ++j) {
190  int64_t cur_seg = seg_data[j];
191  int64_t cur_key = key_data[j];
192  T cur_val = val_data[j];
193  int64_t grad_out_stride = cur_seg * num_outputs_;
194  for (int64_t i = 0; i < num_outputs_; ++i) {
195  T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
196  for (int64_t k = 0; k < num_alpha; ++k) {
197  uint64_t hash;
198  hash_data[0] = cur_key;
199  hash_data[1] = i;
200  hash_data[2] = k;
201 
202  hash_data[3] = INDEX_MAGIC;
203  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
204  int64_t index = hash % num_weight;
205 
206  T cur_grad_out_scale = grad_out_scale;
207 #ifdef USE_SIGN
208  hash_data[3] = SIGN_MAGIC;
209  hash = XXH64(hash_data.data(), hash_data.size(), seed_);
210  if (hash % 2) {
211  cur_grad_out_scale = -cur_grad_out_scale;
212  }
213 #endif // USE_SIGN
214 
215  if (adaptive_) {
216  grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
217  grad_weight_data[index] += alpha_data[k] * cur_grad_out_scale;
218  } else {
219  grad_weight_data[index] += cur_grad_out_scale;
220  }
221  }
222  }
223  }
224  return true;
225  }
226 
227  protected:
228  int64_t num_outputs_;
229  uint64_t seed_;
230  std::array<uint64_t, 4> hash_data;
231  bool adaptive_;
232 };
233 
234 } // namespace caffe2
235 
236 #endif // CAFFE2_OPERATORS_FUNHASH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70