Caffe2 - C++ API
A deep learning, cross-platform ML framework
h_softmax_op.h
1 #ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
2 #define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
3 
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/util/Optional.h>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/proto/hsm.pb.h"
#include "caffe2/utils/math.h"
10 
11 namespace caffe2 {
12 
13 template <typename T, typename Context>
14 class HSoftmaxOpBase : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17  template <class... Args>
18  explicit HSoftmaxOpBase(Args&&... args)
19  : Operator<Context>(std::forward<Args>(args)...) {
20  HierarchyProto hierarchy;
21  CAFFE_ENFORCE(hierarchy.ParseFromString(
22  this->template GetSingleArgument<string>("hierarchy", "")));
23  for (const auto& path : hierarchy.paths()) {
24  hierarchy_all_map_.emplace(path.word_id(), path);
25  }
26  }
27 
28  protected:
29  std::unordered_map<int, PathProto> hierarchy_all_map_;
30  c10::optional<Tensor> scale_;
31  c10::optional<Tensor> sum_multiplier_;
32  c10::optional<Tensor> bias_multiplier_;
33  static constexpr T kLOG_THRESHOLD() {
34  return 1e-20f;
35  }
36  static std::unordered_map<int, PathProto> getHierarchyForLabels(
37  int M,
38  const int* labels,
39  const std::unordered_map<int, PathProto>& hierarchy_all_map) {
40  std::unordered_map<int, PathProto> hierarchy_map;
41  std::set<int> label_set = std::set<int>(labels, labels + M);
42  for (const auto& label : label_set) {
43  auto search = hierarchy_all_map.find(label);
44  CAFFE_ENFORCE(search != hierarchy_all_map.end(), "incorrect label.");
45  hierarchy_map.emplace(search->first, search->second);
46  }
47  return hierarchy_map;
48  }
49  int getIntermediateOutputSize(
50  const int* labels,
51  int M,
52  std::unordered_map<int, PathProto>& hierarchy) const {
53  int size = 0;
54  for (int label = 0; label < M; ++label) {
55  int word_id = labels[label];
56  const auto& path = hierarchy[word_id];
57  size += std::accumulate(
58  path.path_nodes().begin(),
59  path.path_nodes().end(),
60  0,
61  // Output of FC + Output of Softmax
62  [](int sz, PathNodeProto node) {
63  return sz + 2 * node.length();
64  });
65  }
66  return size;
67  }
68 };
69 
70 template <typename T, class Context>
71 class HSoftmaxOp : public HSoftmaxOpBase<T, Context> {
72  public:
73  USE_OPERATOR_CONTEXT_FUNCTIONS;
75 
76  bool RunOnDevice() override;
77 
78  protected:
79  float RunForwardSingle(
80  const float* X,
81  const float* W,
82  const float* b,
83  int target,
84  float* output,
85  const float* bias_multiplier,
86  int w_length,
87  int K,
88  int& output_offset);
89 };
90 
91 template <typename T, class Context>
92 class HSoftmaxGradientOp final : public HSoftmaxOpBase<T, Context> {
93  public:
94  USE_OPERATOR_CONTEXT_FUNCTIONS;
96  bool RunOnDevice() override;
97 
98  private:
99  void RunBackwardSingle(
100  const float* X,
101  const float* dY,
102  const float* W,
103  int target,
104  const float* int_output,
105  float* dX,
106  float* dW,
107  float* db,
108  float* dOutput,
109  int dim_in,
110  int w_length,
111  int& output_offset);
112 };
113 
114 template <typename T, class Context>
115 class HSoftmaxSearchOp final : public HSoftmaxOp<T, Context> {
116  public:
117  USE_OPERATOR_CONTEXT_FUNCTIONS;
118  template <class... Args>
119  explicit HSoftmaxSearchOp(Args&&... args)
120  : HSoftmaxOp<T, Context>(std::forward<Args>(args)...),
121  top_n_(this->template GetSingleArgument<int>("topN", 5)),
122  beam_(this->template GetSingleArgument<float>("beam", 0.01f)) {
123  CAFFE_ENFORCE(tree_.ParseFromString(
124  this->template GetSingleArgument<string>("tree", "")));
125  }
126  bool RunOnDevice() override;
127 
128  private:
129  int top_n_;
130  float beam_;
131  TreeProto tree_;
132  bool pruning(
133  const float* X,
134  int sample,
135  int K,
136  const float* W,
137  const float* b,
138  const NodeProto& src_node,
139  NodeProto& dst_node,
140  float parent_score,
141  float beam);
142  bool extractNodes(
143  const NodeProto& node,
144  std::vector<std::pair<string, float>>& info);
145 };
146 
147 template <typename T, class Context>
148 class HuffmanTreeHierarchyOp : public Operator<Context> {
149  public:
150  USE_OPERATOR_CONTEXT_FUNCTIONS;
151  template <class... Args>
152  explicit HuffmanTreeHierarchyOp(Args&&... args)
153  : Operator<Context>(std::forward<Args>(args)...),
154  num_classes_(this->template GetSingleArgument<int>("num_classes", -1)) {
155  }
156  bool RunOnDevice() override;
157 
158  private:
159  // Internal huffman tree data.
160  struct Node {
161  Node(T l, int count)
162  : label(l), count(count), left_ch_index(-1), right_ch_index(-1) {}
163  T label;
164  int count;
165  int left_ch_index;
166  int right_ch_index;
167  };
168 
169  struct NodeComparator {
170  bool operator()(const Node& node_a, const Node& node_b) {
171  return node_a.count > node_b.count;
172  }
173  };
174 
175  int num_classes_;
176 };
177 
178 } // namespace caffe2
179 
#endif // CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
Definition: any.cpp:108
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13