17 #ifndef CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_ 18 #define CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_ 22 #include "caffe2/core/context.h" 23 #include "caffe2/core/operator.h" 24 #include "caffe2/utils/math.h" 26 #define HASH_MAGIC 0x9e3779b97f4a7c15 32 template <
typename T,
class Context>
35 USE_OPERATOR_CONTEXT_FUNCTIONS;
39 OperatorBase::GetSingleArgument<int64_t>(
"num_outputs", -1)),
41 OperatorBase::GetSingleArgument<int64_t>(
"num_segments", -1)),
42 seed_(OperatorBase::GetSingleArgument<uint64_t>(
"seed", 0)) {
45 "Argument `num_outputs` is missing.");
47 adaptive_ = (InputSize() == 5);
// Forward pass. For every sparse entry j and every output column i, hashes
// (key, ...) into the `weight` vector, optionally mixes the looked-up weights
// with `alpha`, and accumulates `sum * val[j]` into row seg[j] of Output(0).
//
// NOTE(review): this text is a numbered listing with gaps; where the embedded
// listing numbers skip, the corresponding statements (conditions, closing
// braces, the declaration of `sum`) are not present in this excerpt.
50 bool RunOnDevice()
override {
// Inputs 0..3 are bound to val, key, seg and weight respectively.
51 const auto& val =
Input(0);
52 const auto& key =
Input(1);
53 const auto& seg =
Input(2);
54 const auto& weight =
Input(3);
// num_alpha stays 1 unless it is overwritten from Input(4) below.
56 int64_t num_alpha = 1;
// NOTE(review): listing line 57 is missing; presumably the alpha lookup is
// guarded by `adaptive_` -- confirm against the full file.
58 const auto& alpha =
Input(4);
59 num_alpha = alpha.size(0);
// Segment ids are read as `int`, not int64_t.
62 const auto* seg_data = seg.template data<int>();
// Sizes are taken from the leading dimension of weight and seg.
64 int64_t num_weight = weight.size(0);
65 int64_t num_nz_ent = seg.size(0);
// When num_segments_ is -1 the segment count is derived from the data: the
// scan below tracks the largest id in seg_data, starting from -1.
67 int64_t n_segments = num_segments_;
68 if (num_segments_ == -1) {
69 for (int64_t i = 0; i < num_nz_ent; ++i) {
70 if (seg_data[i] > n_segments) {
71 n_segments = seg_data[i];
// NOTE(review): listing lines 72-76 are missing; presumably they close the
// scopes and turn the maximum id into a count -- TODO confirm.
// Output(0) is requested with shape {n_segments, num_outputs_} and dtype T.
77 auto* output = Output(0, {n_segments, num_outputs_}, at::dtype<T>());
79 T* output_data = output->template mutable_data<T>();
// Zero all n_segments * num_outputs_ elements, since the loop below only
// accumulates with +=.
81 memset(output_data, 0,
sizeof(
T) * n_segments * num_outputs_);
83 const auto* weight_data = weight.template data<T>();
// alpha_data is 0 (null) when adaptive_ is false.
84 const auto* alpha_data = adaptive_ ?
Input(4).template data<T>() : 0;
85 const auto* val_data = val.template data<T>();
// Keys are read as int64_t.
86 const auto* key_data = key.template data<int64_t>();
// Outer loop: one iteration per sparse entry.
88 for (int64_t j = 0; j < num_nz_ent; ++j) {
89 int64_t cur_seg = seg_data[j];
90 int64_t cur_key = key_data[j];
91 T cur_val = val_data[j];
// Offset of row cur_seg in output_data (num_outputs_ elements per row).
92 int64_t output_stride = cur_seg * num_outputs_;
93 for (int64_t i = 0; i < num_outputs_; ++i) {
95 for (int64_t k = 0; k < num_alpha; ++k) {
// hash_data slots 0 and 3 are filled here; the assignments to slots 1 and 2
// are not visible in this listing.
101 hash_data[0] = cur_key;
104 hash_data[3] = HASH_MAGIC;
// NOTE(review): hash_data.size() is the element count of the
// std::array<uint64_t, 4> member (4), not its size in bytes (32). If XXH64's
// second argument is a byte length, only the first 4 bytes of the buffer are
// hashed -- confirm against the xxHash API. Changing it would alter every
// hash, so it is flagged here rather than modified.
106 uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);
// Two lookup variants follow; the condition choosing between them is not
// visible. This one drops the lowest hash bit before the modulo, then negates
// the looked-up weight (any guard on the negation is also not visible).
110 int64_t index = (hash >> 1) % num_weight;
111 T cur_weight = weight_data[index];
113 cur_weight = -cur_weight;
// Second variant: plain modulo on the full hash, no negation shown.
116 int64_t index = hash % num_weight;
117 T cur_weight = weight_data[index];
// NOTE(review): alpha_data may be null (see above); the guard that keeps this
// line from running in the non-adaptive case is not visible -- confirm.
121 sum += cur_weight * alpha_data[k];
// Accumulate the contribution of entry j into output element (cur_seg, i).
126 output_data[output_stride + i] += sum * cur_val;
134 int64_t num_outputs_;
135 int64_t num_segments_;
137 std::array<uint64_t, 4> hash_data;
141 template <
typename T,
class Context>
144 USE_OPERATOR_CONTEXT_FUNCTIONS;
148 OperatorBase::GetSingleArgument<int64_t>(
"num_outputs", -1)),
149 seed_(OperatorBase::GetSingleArgument<uint64_t>(
"seed", 0)) {
150 adaptive_ = (InputSize() == 6);
// Gradient pass. Walks the same (entry j, output i, alpha k) triple loop as
// the forward computation and emits the weight gradient in sparse form
// (values in Output(0), weight indices in Output(1)); when alpha is present,
// also accumulates a dense alpha gradient into Output(2).
//
// NOTE(review): this text is a numbered listing with gaps; where the embedded
// listing numbers skip, the corresponding statements (conditions, closing
// braces, the declaration of `w_ind`) are not present in this excerpt.
153 bool RunOnDevice()
override {
// Inputs 0..4 are bound to grad_out, val, key, seg and weight respectively.
154 const auto& grad_out =
Input(0);
155 const auto& val =
Input(1);
156 const auto& key =
Input(2);
157 const auto& seg =
Input(3);
158 const auto& weight =
Input(4);
// Defaults used when no alpha input is read: a single alpha slot and a null
// grad_alpha_data pointer.
160 int64_t num_alpha = 1;
161 T* grad_alpha_data = 0;
// NOTE(review): listing lines 162-163 are missing; presumably the block below
// is guarded by `adaptive_` -- confirm against the full file.
164 const auto& alpha =
Input(5);
165 num_alpha = alpha.size(0);
// Output(2) takes alpha's sizes and is zeroed because it is accumulated with
// += in the loop below.
167 auto* grad_alpha = Output(2, alpha.sizes(), at::dtype<T>());
168 grad_alpha_data = grad_alpha->template mutable_data<T>();
169 memset(grad_alpha_data, 0,
sizeof(
T) * num_alpha);
// Segment ids are read as `int`, not int64_t.
172 const auto* seg_data = seg.template data<int>();
174 int64_t num_weight = weight.size(0);
175 int64_t num_nz_ent = seg.size(0);
// One gradient slot per (entry, output column, alpha) combination.
177 int64_t grad_weight_size = num_nz_ent * num_outputs_ * num_alpha;
// Output(0): gradient values (dtype T); Output(1): the weight index each
// value belongs to (dtype int64_t). Both are 1-D of length grad_weight_size.
179 auto* grad_weight_val = Output(0, {grad_weight_size}, at::dtype<T>());
180 T* grad_weight_val_data = grad_weight_val->template mutable_data<T>();
182 auto* grad_weight_ind = Output(1, {grad_weight_size}, at::dtype<int64_t>());
183 auto* grad_weight_ind_data =
184 grad_weight_ind->template mutable_data<int64_t>();
186 const auto* grad_out_data = grad_out.template data<T>();
187 const auto* weight_data = weight.template data<T>();
// alpha_data is 0 (null) when adaptive_ is false.
188 const auto* alpha_data = adaptive_ ?
Input(5).template data<T>() : 0;
189 const auto* val_data = val.template data<T>();
// Keys are read as int64_t.
190 const auto* key_data = key.template data<int64_t>();
// NOTE(review): `w_ind` is used below but its declaration and update are not
// visible (listing lines 191-192 and 225+ are missing); presumably a running
// write position into the two sparse outputs -- TODO confirm.
193 for (int64_t j = 0; j < num_nz_ent; ++j) {
194 int64_t cur_seg = seg_data[j];
195 int64_t cur_key = key_data[j];
196 T cur_val = val_data[j];
// Offset of row cur_seg in grad_out_data (num_outputs_ elements per row).
197 int64_t grad_out_stride = cur_seg * num_outputs_;
198 for (int64_t i = 0; i < num_outputs_; ++i) {
// Upstream gradient for element (cur_seg, i), scaled by the entry's value.
199 T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
200 for (int64_t k = 0; k < num_alpha; ++k) {
// hash_data slots 0 and 3 are filled here; the assignments to slots 1 and 2
// are not visible in this listing.
201 hash_data[0] = cur_key;
204 hash_data[3] = HASH_MAGIC;
// NOTE(review): hash_data.size() is the element count of the
// std::array<uint64_t, 4> member (4), not its size in bytes (32). If XXH64's
// second argument is a byte length, only the first 4 bytes of the buffer are
// hashed -- confirm against the xxHash API. Any change must be mirrored in
// the forward RunOnDevice, which uses the identical call.
206 uint64_t hash = XXH64(hash_data.data(), hash_data.size(), seed_);
208 T cur_grad_out_scale = grad_out_scale;
// Two index variants follow; the condition choosing between them is not
// visible. This one drops the lowest hash bit before the modulo and flips the
// sign of the gradient scale (any guard on the flip is also not visible).
210 int64_t index = (hash >> 1) % num_weight;
212 cur_grad_out_scale = -cur_grad_out_scale;
// Second variant: plain modulo on the full hash, no sign flip shown.
215 int64_t index = hash % num_weight;
// With alpha: accumulate the alpha gradient and scale the weight gradient by
// alpha[k]. NOTE(review): grad_alpha_data and alpha_data can both be null
// (see above); the guard on these two lines is not visible -- confirm.
219 grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
220 grad_weight_val_data[w_ind] = alpha_data[k] * cur_grad_out_scale;
// Alternative write of the unscaled gradient to the same slot; presumably the
// non-adaptive branch -- confirm.
222 grad_weight_val_data[w_ind] = cur_grad_out_scale;
// Record which weight entry the value written at w_ind refers to.
224 grad_weight_ind_data[w_ind] = index;
233 int64_t num_outputs_;
235 std::array<uint64_t, 4> hash_data;
241 #endif // CAFFE2_OPERATORS_SPARSE_FUNHASH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.