1 #ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_ 2 #define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_ 4 #include <c10/util/Optional.h> 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/proto/hsm.pb.h" 9 #include "caffe2/utils/math.h" 13 template <
typename T,
typename Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
17 template <
class... Args>
20 HierarchyProto hierarchy;
21 CAFFE_ENFORCE(hierarchy.ParseFromString(
22 this->template GetSingleArgument<string>(
"hierarchy",
"")));
23 for (
const auto& path : hierarchy.paths()) {
24 hierarchy_all_map_.emplace(path.word_id(), path);
29 std::unordered_map<int, PathProto> hierarchy_all_map_;
33 static constexpr
T kLOG_THRESHOLD() {
36 static std::unordered_map<int, PathProto> getHierarchyForLabels(
39 const std::unordered_map<int, PathProto>& hierarchy_all_map) {
40 std::unordered_map<int, PathProto> hierarchy_map;
41 std::set<int> label_set = std::set<int>(labels, labels + M);
42 for (
const auto& label : label_set) {
43 auto search = hierarchy_all_map.find(label);
44 CAFFE_ENFORCE(search != hierarchy_all_map.end(),
"incorrect label.");
45 hierarchy_map.emplace(search->first, search->second);
49 int getIntermediateOutputSize(
52 std::unordered_map<int, PathProto>& hierarchy)
const {
54 for (
int label = 0; label < M; ++label) {
55 int word_id = labels[label];
56 const auto& path = hierarchy[word_id];
57 size += std::accumulate(
58 path.path_nodes().begin(),
59 path.path_nodes().end(),
62 [](
int sz, PathNodeProto node) {
63 return sz + 2 * node.length();
70 template <
typename T,
class Context>
73 USE_OPERATOR_CONTEXT_FUNCTIONS;
76 bool RunOnDevice()
override;
79 float RunForwardSingle(
85 const float* bias_multiplier,
91 template <
typename T,
class Context>
94 USE_OPERATOR_CONTEXT_FUNCTIONS;
96 bool RunOnDevice()
override;
99 void RunBackwardSingle(
104 const float* int_output,
114 template <
typename T,
class Context>
117 USE_OPERATOR_CONTEXT_FUNCTIONS;
118 template <
class... Args>
121 top_n_(this->
template GetSingleArgument<int>(
"topN", 5)),
122 beam_(this->
template GetSingleArgument<float>(
"beam", 0.01f)) {
123 CAFFE_ENFORCE(tree_.ParseFromString(
124 this->template GetSingleArgument<string>(
"tree",
"")));
126 bool RunOnDevice()
override;
138 const NodeProto& src_node,
143 const NodeProto& node,
144 std::vector<std::pair<string, float>>& info);
147 template <
typename T,
class Context>
150 USE_OPERATOR_CONTEXT_FUNCTIONS;
151 template <
class... Args>
154 num_classes_(this->
template GetSingleArgument<int>(
"num_classes", -1)) {
156 bool RunOnDevice()
override;
162 : label(l), count(count), left_ch_index(-1), right_ch_index(-1) {}
169 struct NodeComparator {
170 bool operator()(
const Node& node_a,
const Node& node_b) {
171 return node_a.count > node_b.count;
180 #endif // CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...