Caffe2 - C++ API
A deep learning, cross platform ML framework
h_softmax_op.h
1 
17 #ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
18 #define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/proto/hsm.pb.h"
24 #include "caffe2/utils/math.h"
25 
26 namespace caffe2 {
27 
28 template <typename T, typename Context>
29 class HSoftmaxOpBase : public Operator<Context> {
30  public:
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  HSoftmaxOpBase(const OperatorDef& operator_def, Workspace* ws)
33  : Operator<Context>(operator_def, ws) {
34  HierarchyProto hierarchy;
35  CAFFE_ENFORCE(hierarchy.ParseFromString(
36  OperatorBase::GetSingleArgument<string>("hierarchy", "")));
37  for (const auto& path : hierarchy.paths()) {
38  hierarchy_all_map_.emplace(path.word_id(), path);
39  }
40  }
41 
42  protected:
43  std::unordered_map<int, PathProto> hierarchy_all_map_;
44  Tensor<Context> scale_;
45  Tensor<Context> sum_multiplier_;
46  Tensor<Context> bias_multiplier_;
47  static constexpr T kLOG_THRESHOLD() {
48  return 1e-20f;
49  }
50  static std::unordered_map<int, PathProto> getHierarchyForLabels(
51  int M,
52  const int* labels,
53  const std::unordered_map<int, PathProto>& hierarchy_all_map) {
54  std::unordered_map<int, PathProto> hierarchy_map;
55  std::set<int> label_set = std::set<int>(labels, labels + M);
56  for (const auto& label : label_set) {
57  auto search = hierarchy_all_map.find(label);
58  CAFFE_ENFORCE(search != hierarchy_all_map.end(), "incorrect label.");
59  hierarchy_map.emplace(search->first, search->second);
60  }
61  return hierarchy_map;
62  }
63  int getIntermediateOutputSize(
64  const int* labels,
65  int M,
66  std::unordered_map<int, PathProto>& hierarchy) const {
67  int size = 0;
68  for (int label = 0; label < M; ++label) {
69  int word_id = labels[label];
70  const auto& path = hierarchy[word_id];
71  size += std::accumulate(
72  path.path_nodes().begin(),
73  path.path_nodes().end(),
74  0,
75  // Output of FC + Output of Softmax
76  [](int sz, PathNodeProto node) {
77  return sz + 2 * node.length();
78  });
79  }
80  return size;
81  }
82 };
83 
84 template <typename T, class Context>
85 class HSoftmaxOp : public HSoftmaxOpBase<T, Context> {
86  public:
87  USE_OPERATOR_CONTEXT_FUNCTIONS;
89 
90  bool RunOnDevice() override;
91 
92  protected:
93  float RunForwardSingle(
94  const float* X,
95  const float* W,
96  const float* b,
97  int target,
98  float* output,
99  const float* bias_multiplier,
100  int w_length,
101  int K,
102  int& output_offset);
103 };
104 
105 template <typename T, class Context>
106 class HSoftmaxGradientOp final : public HSoftmaxOpBase<T, Context> {
107  public:
108  USE_OPERATOR_CONTEXT_FUNCTIONS;
110  bool RunOnDevice() override;
111 
112  private:
113  void RunBackwardSingle(
114  const float* X,
115  const float* dY,
116  const float* W,
117  int target,
118  const float* int_output,
119  float* dX,
120  float* dW,
121  float* db,
122  float* dOutput,
123  int dim_in,
124  int w_length,
125  int& output_offset);
126 };
127 
128 template <typename T, class Context>
129 class HSoftmaxSearchOp final : public HSoftmaxOp<T, Context> {
130  public:
131  USE_OPERATOR_CONTEXT_FUNCTIONS;
132  HSoftmaxSearchOp(const OperatorDef& operator_def, Workspace* ws)
133  : HSoftmaxOp<T, Context>(operator_def, ws),
134  top_n_(OperatorBase::GetSingleArgument<int>("topN", 5)),
135  beam_(OperatorBase::GetSingleArgument<float>("beam", 0.01f)) {
136  CAFFE_ENFORCE(tree_.ParseFromString(
137  OperatorBase::GetSingleArgument<string>("tree", "")));
138  }
139  bool RunOnDevice() override;
140 
141  private:
142  int top_n_;
143  float beam_;
144  TreeProto tree_;
145  bool pruning(
146  const float* X,
147  int sample,
148  int K,
149  const float* W,
150  const float* b,
151  const NodeProto& src_node,
152  NodeProto& dst_node,
153  float parent_score,
154  float beam);
155  bool extractNodes(
156  const NodeProto& node,
157  std::vector<std::pair<string, float>>& info);
158 };
159 
160 template <typename T, class Context>
161 class HuffmanTreeHierarchyOp : public Operator<Context> {
162  public:
163  USE_OPERATOR_CONTEXT_FUNCTIONS;
164  HuffmanTreeHierarchyOp(const OperatorDef& operator_def, Workspace* ws)
165  : Operator<Context>(operator_def, ws),
166  num_classes_(OperatorBase::GetSingleArgument<int>("num_classes", -1)) {}
167  bool RunOnDevice() override;
168 
169  private:
170  // Internal huffman tree data.
171  struct Node {
172  Node(T l, int count)
173  : label(l), count(count), left_ch_index(-1), right_ch_index(-1) {}
174  T label;
175  int count;
176  int left_ch_index;
177  int right_ch_index;
178  };
179 
180  struct NodeComparator {
181  bool operator()(const Node& node_a, const Node& node_b) {
182  return node_a.count > node_b.count;
183  }
184  };
185 
186  int num_classes_;
187 };
188 
189 } // namespace caffe2
190 
191 #endif // CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.