Caffe2 - C++ API
A deep learning, cross platform ML framework
h_softmax_op.cc
1 #include "caffe2/operators/h_softmax_op.h"
2 
3 #include <queue>
4 #include <stack>
5 
6 namespace caffe2 {
7 
8 template <>
9 float HSoftmaxOp<float, CPUContext>::RunForwardSingle(const float* X,
10  const float* W, const float* b, int target, float* int_output,
11  const float* bias_multiplier, int dim_out, int dim_in,
12  int& int_output_offset) {
13 
14  // W * x
15  float* fc_output_data = int_output + int_output_offset;
16 
17  math::Gemm<float, CPUContext>(CblasNoTrans, CblasTrans, 1, dim_out, dim_in, 1,
18  X, W, 0, fc_output_data, &context_);
19  math::Gemv<float, CPUContext>(CblasNoTrans, dim_out, 1, 1,
20  b, bias_multiplier, 1, fc_output_data, &context_);
21 
22  int_output_offset += dim_out;
23 
24  //Softmax
25  float* softmax_output_data = int_output + int_output_offset;
26 
27  if (!scale_.has_value()) {
28  scale_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
29  }
30 
31  if (!sum_multiplier_.has_value()) {
32  sum_multiplier_ = caffe2::empty({dim_out}, at::dtype<float>().device(CPU));
33  math::Set<float, CPUContext>(dim_out, 1.f,
34  sum_multiplier_->mutable_data<float>(), &context_);
35  } else if (sum_multiplier_->numel() != dim_out) {
36  sum_multiplier_->Resize(dim_out);
37  math::Set<float, CPUContext>(dim_out, 1.f,
38  sum_multiplier_->mutable_data<float>(), &context_);
39  }
40  math::RowwiseMax<float, CPUContext>(1, dim_out, fc_output_data,
41  scale_->mutable_data<float>(), &context_);
42 
43  // Put the intermediate result X - max(X) into Y
44  context_.template CopyFromCPU<float>(
45  dim_out, fc_output_data, softmax_output_data);
46  // Subtract the scale
47  math::Gemv<float, CPUContext>(CblasNoTrans, dim_out, 1, -1,
48  sum_multiplier_->data<float>(), scale_->data<float>(), 1, softmax_output_data,
49  &context_);
50 
51  // Exponentiation
52  math::Exp<float, CPUContext>(dim_out, softmax_output_data,
53  softmax_output_data, &context_);
54  math::Gemv<float, CPUContext>(CblasNoTrans, 1, dim_out, 1,
55  softmax_output_data, sum_multiplier_->data<float>(), 0,
56  scale_->mutable_data<float>(), &context_);
57 
58  // Do division
59  const float scale = *(scale_->data<float>());
60  for (int j = 0; j < dim_out; ++j) {
61  softmax_output_data[j] /= scale;
62  }
63 
64  int_output_offset += dim_out;
65 
66  if (target < 0) {
67  return -1;
68  }
69  //Return cross entropy loss
70  return -log(std::max(softmax_output_data[target], kLOG_THRESHOLD()));
71 }
72 
73 // Implementation for the CPU context.
74 template <>
75 bool HSoftmaxOp<float, CPUContext>::RunOnDevice() {
76  auto& X = Input(0);
77  const auto& W = Input(1);
78  const auto& b = Input(2);
79  auto& label = Input(3);
80 
81  // Batch size
82  int M = X.dim() > 1 ? X.dim32(0) : 1;
83  // Input feature dimension
84  int K = X.numel() / M;
85  CAFFE_ENFORCE_GE(W.dim(), 2); // N*K
86  CAFFE_ENFORCE_EQ(b.dim(), 1); // N
87  CAFFE_ENFORCE_EQ(K, W.numel() / (W.dim32(0)));
88  // Sum of output dimensions of all hierarchy nodes
89  int N = W.dim32(0);
90  CAFFE_ENFORCE_EQ(N, b.dim32(0));
91  auto* Y = Output(0, {M}, at::dtype<float>());
92  auto* Ydata = Y->template mutable_data<float>();
93  math::Set<float, CPUContext>(M, 0.f, Ydata, &context_);
94  const auto* labeldata = label.data<int>();
95 
96  auto hierarchy = getHierarchyForLabels(M, labeldata, hierarchy_all_map_);
97  int int_output_size = getIntermediateOutputSize(labeldata, M, hierarchy);
98  auto* intermediate_output = Output(1, {int_output_size}, at::dtype<float>());
99  float* int_output_data = intermediate_output->template mutable_data<float>();
100  int int_output_offset = 0;
101 
102  if (!bias_multiplier_.has_value()) {
103  bias_multiplier_ = caffe2::empty({M}, at::dtype<float>().device(CPU));
104  math::Set<float, CPUContext>(M, static_cast<float>(1),
105  bias_multiplier_->mutable_data<float>(), &context_);
106  } else if (bias_multiplier_->numel() != M) {
107  bias_multiplier_->Resize(M);
108  math::Set<float, CPUContext>(M, static_cast<float>(1),
109  bias_multiplier_->mutable_data<float>(), &context_);
110  }
111 
112  for (int sample = 0; sample < M; ++sample) {
113  int word_id = labeldata[sample];
114  const PathProto& path = hierarchy[word_id];
115  for (const PathNodeProto& node : path.path_nodes()) {
116  //Offset of node's weight matrix in W
117  int w_offset = node.index();
118  //Number of output dimensions in node's weight matrix
119  int w_length = node.length();
120  int target = node.target();
121  //Adding log probabilities
122  Ydata[sample] += RunForwardSingle(X.data<float>() + sample*K,
123  W.data<float>() + w_offset*K, b.data<float>() + w_offset, target,
124  int_output_data, bias_multiplier_->data<float>()+sample, w_length, K,
125  int_output_offset);
126  }
127  }
128  return true;
129 }
130 
131 template <>
132 void HSoftmaxGradientOp<float, CPUContext>::RunBackwardSingle(const float* X,
133  const float* dY, const float* W, int target,
134  const float* int_output, float* dX, float* dW, float* db, float* dint_output,
135  int dim_in, int dim_out, int& int_output_offset) {
136 
137  //Cross entropy
138  // dX_entropy is the dX for the cross entropy layer
139  float* dX_entropy = dint_output + int_output_offset - dim_out;
140  // X_entropy is the X for the cross entropy layer and Y for the softmax layer
141  const float* X_entropy = int_output + int_output_offset - dim_out;
142 
143  math::Set<float, CPUContext>(dim_out, 0.f, dX_entropy, &context_);
144  dX_entropy[target] = - (*dY) / std::max(X_entropy[target], kLOG_THRESHOLD());
145 
146  int_output_offset -= dim_out;
147 
148  //Softmax
149  if (!scale_.has_value()) {
150  scale_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
151  }
152  float* scaledata = scale_->mutable_data<float>();
153 
154  if (!sum_multiplier_.has_value()) {
155  sum_multiplier_ = caffe2::empty({dim_out}, at::dtype<float>().device(CPU));
156  math::Set<float, CPUContext>(dim_out, 1.f,
157  sum_multiplier_->mutable_data<float>(), &context_);
158  } else if (sum_multiplier_->numel() != dim_out) {
159  sum_multiplier_->Resize(dim_out);
160  math::Set<float, CPUContext>(dim_out, 1.f,
161  sum_multiplier_->mutable_data<float>(), &context_);
162  }
163 
164  float* dX_softmax = dint_output + int_output_offset - dim_out;
165  context_.CopyFromCPU<float>(dim_out, dX_entropy, dX_softmax);
166 
167  math::Dot<float, CPUContext>(dim_out, X_entropy, dX_entropy, scaledata,
168  &context_);
169  math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, -1,
170  sum_multiplier_->data<float>(), scaledata , 1, dX_softmax, &context_);
171  math::Mul<float, CPUContext>(dim_out, dX_softmax, X_entropy, dX_softmax,
172  &context_);
173 
174  int_output_offset -= dim_out;
175 
176  //FC
177  if (!bias_multiplier_.has_value()) {
178  // If the helper bias multiplier has not been created, reshape and fill
179  // it with 1
180  bias_multiplier_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
181  math::Set<float, CPUContext>(1, static_cast<float>(1),
182  bias_multiplier_->template mutable_data<float>(), &context_);
183  }
184 
185  // Compute dW and add incrementally
186  // dW = dW + dX_softmax'*X
187  math::Gemm<float, CPUContext>(CblasTrans, CblasNoTrans, dim_out, dim_in, 1, 1,
188  dX_softmax, X, 1, dW, &context_);
189 
190  // Compute dB and add incrementally
191  // db = db + dX_softmax*bias_multiplier_
192  math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, 1, dX_softmax,
193  bias_multiplier_->template data<float>(), 1, db, &context_);
194 
195  // Compute dX and add incrementally
196  // dX = dX + W'dX_softmax
197  math::Gemv<float, CPUContext>(CblasTrans, dim_out, dim_in,
198  1, W, dX_softmax, 1, dX, &context_);
199 }
200 
201 // Implementation for the CPU context.
202 template <>
203 bool HSoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
204  auto& X = Input(0);
205  const auto& W = Input(1);
206  const auto& b = Input(2);
207  auto& label = Input(3);
208  auto& intermediate_output = Input(4);
209  auto& dY = Input(5);
210 
211  auto* dX = Output(0, X.sizes(), at::dtype<float>());
212  auto* dW = Output(1, W.sizes(), at::dtype<float>());
213  auto* db = Output(2, b.sizes(), at::dtype<float>());
214  auto* dX_intermediate_output =
215  Output(3, intermediate_output.sizes(), at::dtype<float>());
216 
217  float* dX_data = dX->template mutable_data<float>();
218  float* dW_data = dW->template mutable_data<float>();
219  float* db_data = db->template mutable_data<float>();
220  float* dOutput_data = dX_intermediate_output->template mutable_data<float>();
221 
222  math::Set<float, CPUContext>(X.numel(), 0.f, dX_data, &context_);
223  math::Set<float, CPUContext>(W.numel(), 0.f, dW_data, &context_);
224  math::Set<float, CPUContext>(b.numel(), 0.f, db_data, &context_);
225  math::Set<float, CPUContext>(
226  intermediate_output.numel(), 0.f, dOutput_data, &context_);
227 
228  // Batch size
229  int M = X.dim() > 1 ? X.dim32(0) : 1;
230  // Input feature dimension
231  int K = X.numel() / M;
232  const auto* labeldata = label.data<int>();
233 
234  auto hierarchy = getHierarchyForLabels(M, labeldata, hierarchy_all_map_);
235  int output_offset = getIntermediateOutputSize(labeldata, M, hierarchy);
236 
237  //Traverse backward to access intermediate_output generated by HSoftmaxOp
238  // sequentially in reverse order
239  for (int sample = M-1; sample >= 0; sample--) {
240  int word_id = labeldata[sample];
241  PathProto path = hierarchy[word_id];
242  for (auto node = path.path_nodes().rbegin();
243  node != path.path_nodes().rend(); node++) {
244  int w_offset = node->index();
245  int w_length = node->length();
246  int target = node->target();
247  RunBackwardSingle(X.data<float>() + sample*K, dY.data<float>() + sample,
248  W.data<float>() + w_offset*K, target, intermediate_output.data<float>(),
249  dX_data + sample*K, dW_data + w_offset*K, db_data + w_offset,
250  dOutput_data, K, w_length, output_offset);
251  }
252  }
253  return true;
254 }
255 
256 // Implementation for the CPU context.
257 template <>
258 bool HSoftmaxSearchOp<float, CPUContext>::pruning(
259  const float* X,
260  int sample,
261  int K,
262  const float* W,
263  const float* b,
264  const NodeProto& src_node,
265  NodeProto& dst_node,
266  float parent_score,
267  float beam) {
268  int w_length = src_node.children_size() + src_node.word_ids_size();
269  Tensor intermediate_data{CPU};
270  intermediate_data.Resize(2 * w_length);
271  float* int_output_data = intermediate_data.template mutable_data<float>();
272  int int_output_offset = 0;
273  int w_offset = src_node.offset();
274 
275  RunForwardSingle(
276  X + K * sample,
277  W + w_offset * K,
278  b + w_offset,
279  -1,
280  int_output_data,
281  bias_multiplier_->template data<float>() + sample,
282  w_length,
283  K,
284  int_output_offset);
285 
286  float* softmax_output_data = int_output_data + w_length;
287  // real probabilities
288  for (int i = 0; i < w_length; i++) {
289  softmax_output_data[i] =
290  -log(std::max(softmax_output_data[i], kLOG_THRESHOLD())) + parent_score;
291  }
292  for (int i = 0; i < src_node.children_size(); i++) {
293  if (softmax_output_data[i] < parent_score + beam) {
294  dst_node.add_children();
295  int idx = dst_node.children_size() - 1;
296  CAFFE_ENFORCE(
297  src_node.children(i).has_offset(),
298  "HSM Search require the field offset in NodeProte");
299  dst_node.mutable_children(idx)->set_offset(src_node.children(i).offset());
300  CAFFE_ENFORCE(
301  src_node.children(i).has_name(),
302  "HSM Search require the field name in NodeProte");
303  dst_node.mutable_children(idx)->set_name(src_node.children(i).name());
304  dst_node.add_scores(softmax_output_data[i]);
305  pruning(
306  X,
307  sample,
308  K,
309  W,
310  b,
311  src_node.children(i),
312  *dst_node.mutable_children(idx),
313  softmax_output_data[i],
314  beam);
315  }
316  }
317 
318  for (int i = src_node.children_size(); i < w_length; i++) {
319  if (softmax_output_data[i] < parent_score + beam) {
320  dst_node.add_word_ids(src_node.word_ids(i - src_node.children_size()));
321  dst_node.add_scores(softmax_output_data[i]);
322  }
323  }
324 
325  return true;
326 }
327 
328 template <>
329 bool HSoftmaxSearchOp<float, CPUContext>::extractNodes(
330  const NodeProto& node,
331  std::vector<std::pair<string, float>>& info) {
332  int i = 0;
333 
334  for (const auto& n : node.children()) {
335  info.emplace_back(std::make_pair(n.name(), node.scores(i++)));
336  }
337  for (const int n : node.word_ids()) {
338  info.emplace_back(std::make_pair(c10::to_string(n), node.scores(i++)));
339  }
340 
341  for (const auto& n : node.children()) {
342  extractNodes(n, info);
343  }
344  return true;
345 }
346 
347 // Implementation for the CPU context.
348 template <>
349 bool HSoftmaxSearchOp<float, CPUContext>::RunOnDevice() {
350  auto& X = Input(0);
351  const auto& W = Input(1);
352  const auto& b = Input(2);
353 
354  // Batch size
355  int M = X.dim() > 1 ? X.dim32(0) : 1;
356  // Input feature dimension
357  int K = X.numel() / M;
358  CAFFE_ENFORCE(W.dim() == 2, "Weight must be a matrix."); // N*K
359  CAFFE_ENFORCE(b.dim() == 1, "Bias must be a vector."); // N
360  CAFFE_ENFORCE(K == W.numel() / (W.dim32(0)), "feature dimension mismatch.");
361  // Sum of output dimensions of all hierarchy nodes
362  int N = W.dim32(0);
363  CAFFE_ENFORCE(N == b.dim32(0), "mismatch between Weight and Bias.");
364  auto* Y_names = Output(0, {M, top_n_}, at::dtype<string>());
365  auto* Y_scores = Output(1, {M, top_n_}, at::dtype<float>());
366 
367  if (!bias_multiplier_.has_value()) {
368  bias_multiplier_ = caffe2::empty({M}, at::dtype<float>().device(CPU));
369  math::Set<float, CPUContext>(M, static_cast<float>(1),
370  bias_multiplier_->mutable_data<float>(), &context_);
371  } else if (bias_multiplier_->numel() != M) {
372  bias_multiplier_->Resize(M);
373  math::Set<float, CPUContext>(M, static_cast<float>(1),
374  bias_multiplier_->mutable_data<float>(), &context_);
375  }
376 
377  for (int sample = 0; sample < M; ++sample) {
378  CAFFE_ENFORCE(
379  tree_.root_node().has_offset(),
380  "HSM Search require the field offset in NodeProte");
381  CAFFE_ENFORCE(
382  tree_.root_node().has_name(),
383  "HSM Search require the field name in NodeProte");
384 
385  NodeProto dst_node;
386  dst_node.set_offset(tree_.root_node().offset());
387  dst_node.set_name(tree_.root_node().name());
388 
389  pruning(
390  X.data<float>(),
391  sample,
392  K,
393  W.data<float>(),
394  b.data<float>(),
395  tree_.root_node(),
396  dst_node,
397  0,
398  beam_);
399 
400  std::vector<std::pair<string, float>> info;
401  extractNodes(dst_node, info);
402  // saving the results for each sample.
403  std::partial_sort(
404  info.begin(),
405  info.begin() + (top_n_ < info.size() ? top_n_ : info.size() - 1),
406  info.end(),
407  [&](std::pair<string, float> a, std::pair<string, float> b) {
408  return a.second < b.second;
409  });
410  auto* y_name_data =
411  Y_names->template mutable_data<string>() + sample * top_n_;
412  auto* y_score_data =
413  Y_scores->template mutable_data<float>() + sample * top_n_;
414  for (int i = 0; i < top_n_; i++) {
415  if (i < info.size()) {
416  y_name_data[i] = info[i].first;
417  y_score_data[i] = info[i].second;
418  } else {
419  y_score_data[i] = 0;
420  }
421  }
422  }
423 
424  return true;
425 }
426 
427 template <typename T, class Context>
428 bool HuffmanTreeHierarchyOp<T, Context>::RunOnDevice() {
429  const auto& Y = Input(0);
430 
431  CAFFE_ENFORCE_EQ(Y.dim(), 1, "Input labels must be a vector.");
432  const auto y_data = Y.template data<T>();
433  auto treeOutput = Output(0, {1}, at::dtype<string>());
434  std::vector<int> labelCounts;
435  labelCounts.resize(num_classes_, 0);
436  for (int i = 0; i < Y.dim32(0); ++i) {
437  // Labels are in range [0, num_classes]
438  const int label_index = y_data[i];
439  CAFFE_ENFORCE_LT(
440  label_index,
441  num_classes_,
442  "Found an input label ",
443  label_index,
444  " not in range [",
445  0,
446  ",",
447  num_classes_,
448  "]");
449  labelCounts[label_index]++;
450  }
451 
452  std::priority_queue<Node, std::vector<Node>, NodeComparator> nodes;
453  std::vector<Node> huffmanTree;
454  std::vector<int> labelIndices;
455  labelIndices.resize(num_classes_);
456 
457  int current_node_index = 0;
458  for (int i = 0; i < num_classes_; ++i) {
459  Node node(i, labelCounts[i]);
460  nodes.push(node);
461  }
462 
463  // Extract node with minimum count and insert it in the tree array.
464  auto get_next_node = [&nodes, &huffmanTree, &labelIndices]() {
465  auto node = nodes.top();
466  int node_index = huffmanTree.size();
467  if (node.label != -1) {
468  labelIndices[node.label] = node_index;
469  }
470  nodes.pop();
471  huffmanTree.push_back(node);
472  return std::pair<int, Node>(node_index, node);
473  };
474 
475  // Merge two nodes and insert the results in the queue.
476  auto merge_nodes = [&nodes](
477  const std::pair<int, Node>& node_l, const std::pair<int, Node>& node_r) {
478  Node node(-1, node_l.second.count + node_r.second.count);
479  node.left_ch_index = node_l.first;
480  node.right_ch_index = node_r.first;
481  nodes.push(node);
482  };
483 
484  // Main loop for buttom up huffman tree construction.
485  while (!nodes.empty()) {
486  auto lNode = get_next_node();
487  if (!nodes.empty()) {
488  auto rNode = get_next_node();
489  merge_nodes(lNode, rNode);
490  }
491  }
492 
493  auto is_leaf_node = [&huffmanTree](const int node_index) {
494  return huffmanTree[node_index].left_ch_index == -1 &&
495  huffmanTree[node_index].right_ch_index == -1;
496  };
497 
498  auto get_node_label = [&huffmanTree](const int node_index) {
499  return huffmanTree[node_index].label;
500  };
501 
502  // Build huffman tree.
503  int current_offset = 0;
504  std::function<void(int, NodeProto*)> build_tree = [&](
505  const int node_index, NodeProto* node) {
506  if (is_leaf_node(node_index) || node_index == -1) {
507  return;
508  }
509  const int left_ch_index = huffmanTree[node_index].left_ch_index;
510  const int right_ch_index = huffmanTree[node_index].right_ch_index;
511  if (left_ch_index != -1) {
512  if (is_leaf_node(left_ch_index)) {
513  node->add_word_ids(get_node_label(left_ch_index));
514  } else {
515  auto* ch_node = node->add_children();
516  ch_node->set_offset(current_offset);
517  current_offset += 2;
518  build_tree(left_ch_index, ch_node);
519  }
520  }
521  if (right_ch_index != -1) {
522  if (is_leaf_node(right_ch_index)) {
523  node->add_word_ids(get_node_label(right_ch_index));
524  current_offset++;
525  } else {
526  auto* ch_node = node->add_children();
527  ch_node->set_offset(current_offset);
528  current_offset += 2;
529  build_tree(right_ch_index, ch_node);
530  }
531  }
532  };
533 
534  // The last element inserted in the tree is the root.
535  const int rootNodeIndex = huffmanTree.size() - 1;
536  NodeProto rootNode;
537  rootNode.set_offset(current_offset);
538  current_offset += 2;
539  build_tree(rootNodeIndex, &rootNode);
540  TreeProto treeProto;
541  *treeProto.mutable_root_node() = rootNode;
542 
543  treeProto.SerializeToString(treeOutput->template mutable_data<string>());
544  return true;
545 }
546 
547 namespace {
548 REGISTER_CPU_OPERATOR(HSoftmax, HSoftmaxOp<float, CPUContext>);
549 REGISTER_CPU_OPERATOR(HSoftmaxGradient,
550  HSoftmaxGradientOp<float, CPUContext>);
551 REGISTER_CPU_OPERATOR(HSoftmaxSearch, HSoftmaxSearchOp<float, CPUContext>);
552 REGISTER_CPU_OPERATOR(
553  HuffmanTreeHierarchy,
554  HuffmanTreeHierarchyOp<int64_t, CPUContext>);
555 
556 OPERATOR_SCHEMA(HSoftmax)
557  .NumInputs(4)
558  .NumOutputs(2)
559  .SetDoc(R"DOC(
560 Hierarchical softmax is an operator which approximates the softmax operator
561 while giving significant training speed gains and reasonably comparable
562 performance. In this operator, instead of calculating the probabilities of all
563 the classes, we calculate the probability of each step in the path from root to
564 the target word in the hierarchy.
565 
566 The operator takes a 2-D tensor (Tensor) containing a batch of layers, a
567 set of parameters represented by the weight matrix and bias terms, and a 1-D
568 tensor (Tensor) holding labels, or the indices of the target class. The
569 hierarchy has to be specified as an argument to the operator.
570 
571 The operator returns a 1-D tensor holding the computed log probability of the
572 target class and a 2-D tensor of intermediate outputs (from the weight matrix
573 and softmax from each step in the path from root to target class) which will be
574 used by the gradient operator to compute gradients for all samples in the batch.
575 )DOC")
576  .Arg(
577  "hierarchy",
578  "Serialized HierarchyProto string containing list of "
579  "vocabulary words and their paths from root of hierarchy to the leaf")
580  .Input(0, "X", "Input data from previous layer")
581  .Input(
582  1,
583  "W",
584  "2D blob containing 'stacked' fully connected weight "
585  "matrices. Each node in the hierarchy contributes one FC weight matrix if "
586  "it has children nodes. Dimension is N*D, D is input dimension of data (X), "
587  "N is sum of all output dimensions, or total number of nodes (excl root)")
588  .Input(2, "b", "1D blob with N parameters")
589  .Input(3, "labels", "int word_id of the target word")
590  .Output(0, "Y", "1-D of log probability outputs, one per sample")
591  .Output(
592  1,
593  "intermediate_output",
594  "Extra blob to store the intermediate "
595  "FC and softmax outputs for each node in the hierarchical path of a word. "
596  "The outputs from samples are stored in consecutive blocks in the forward "
597  "pass and are used in reverse order in the backward gradientOp pass");
598 
599 OPERATOR_SCHEMA(HSoftmaxGradient).NumInputs(6).NumOutputs(4);
600 
601 class GetHSoftmaxGradient : public GradientMakerBase {
602  using GradientMakerBase::GradientMakerBase;
603  vector<OperatorDef> GetGradientDefs() override {
604  return SingleGradientDef(
605  "HSoftmaxGradient", "",
606  //X, W, b, label, intermediate output, dY
607  vector<string>{I(0), I(1), I(2), I(3), O(1), GO(0)},
608  //dX, dW, db, dintermediate_output
609  vector<string>{GI(0), GI(1), GI(2), GO(1)});
610  }
611 };
612 REGISTER_GRADIENT(HSoftmax, GetHSoftmaxGradient);
613 
614 OPERATOR_SCHEMA(HSoftmaxSearch)
615  .NumInputs(3)
616  .NumOutputs(2)
617  .SetDoc(R"DOC(
618 HSoftmaxSearch is an operator to generate the most possible paths given a
619 well-trained model and input vector. Greedy algorithm is used for pruning the
620 search tree.
621 )DOC")
622  .Arg(
623  "tree",
624  "Serialized TreeProto string containing a tree "
625  "including all intermidate nodes and leafs. All nodes must have names "
626  "for correct outputs")
627  .Arg(
628  "beam",
629  "beam used for pruning tree. The pruning algorithm is that "
630  "only children, whose score is smaller than parent's score puls beam, "
631  "will be propagated. ")
632  .Arg("topN", "Number of nodes in outputs")
633  .Input(0, "X", "Input data from previous layer")
634  .Input(1, "W", "The matrix trained from Softmax Ops")
635  .Input(2, "b", "The bias traiend from Softmax Ops")
636  .Output(
637  0,
638  "Y_names",
639  "The name of selected nodes and leafs. "
640  "For nodes, it will be the name defined in the tree. "
641  "For leafs, it will be the index of the word in the tree.")
642  .Output(1, "Y_scores", "The corresponding scores of Y_names");
643 SHOULD_NOT_DO_GRADIENT(HSoftmaxSearch);
644 
645 OPERATOR_SCHEMA(HuffmanTreeHierarchy)
646  .NumInputs(1)
647  .NumOutputs(1)
648  .SetDoc(R"DOC(
649 HuffmanTreeHierarchy is an operator to generate huffman tree hierarchy given
650 the input labels. It returns the tree as seralized HierarchyProto
651 )DOC")
652  .Arg("num_classes", "The number of classes used to build the hierarchy.")
653  .Input(0, "Labels", "The labels vector")
654  .Output(0, "Hierarch", "Huffman coding hierarchy of the labels");
655 
656 SHOULD_NOT_DO_GRADIENT(HuffmanTreeHierarchyOp);
657 } // namespace
658 } // namespace caffe2
Definition: any.cpp:108
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13