from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto import hsm_pb2
# Hierarchical softmax utility methods that can be used to:
# 1) create TreeProto structure given list of word_ids or NodeProtos
# 2) create HierarchyProto structure using the user-inputted TreeProto


def create_node_with_words(words, name='node'):
    """Create a NodeProto whose ``word_ids`` are the given words.

    Args:
        words: iterable of word ids to store on the node, in order.
        name: label stored on the node; defaults to ``'node'``.

    Returns:
        A new ``hsm_pb2.NodeProto`` with ``name`` set and ``word_ids``
        populated from ``words``. The node has no children.
    """
    node = hsm_pb2.NodeProto()
    node.name = name
    # Repeated scalar fields accept any iterable via extend(); this is the
    # bulk form of appending each word id one by one.
    node.word_ids.extend(words)
    return node
def create_node_with_nodes(nodes, name='node'):
    """Create a NodeProto whose children are copies of the given nodes.

    Args:
        nodes: iterable of ``hsm_pb2.NodeProto`` messages to attach as
            children, in order.
        name: label stored on the new parent node; defaults to ``'node'``.

    Returns:
        A new ``hsm_pb2.NodeProto`` with ``name`` set and one child per
        entry in ``nodes``. The node has no word ids of its own.
    """
    node = hsm_pb2.NodeProto()
    node.name = name
    for child_node in nodes:
        # add() appends a fresh, empty child message; MergeFrom then copies
        # the caller's node into it, so the inputs themselves are not
        # re-parented or otherwise modified.
        new_child_node = node.children.add()
        new_child_node.MergeFrom(child_node)
    return node
def create_hierarchy(tree_proto):
    """Build a HierarchyProto describing the root-to-word path of every word.

    The tree under ``tree_proto.root_node`` is walked depth-first. Each
    visited node is assigned a contiguous block of indices, one slot per
    child and per word id, starting at the running ``max_index``. For every
    word id a ``PathProto`` is emitted listing, for each node from the root
    down to the word's owner, that node's ``(index, length, target)``.

    Note: this mutates ``tree_proto`` in place, because every visited node
    gets its ``offset`` field set.

    Args:
        tree_proto: message exposing a ``root_node`` NodeProto.

    Returns:
        A new ``hsm_pb2.HierarchyProto`` whose ``paths`` holds one entry per
        word id found in the tree and whose ``size`` is the total number of
        index slots allocated.
    """

    def create_path(path, word):
        """Snapshot the current ``path`` stack as a PathProto for ``word``."""
        path_proto = hsm_pb2.PathProto()
        path_proto.word_id = word
        for entry in path:
            # Each entry is [index, length, target] as pushed by
            # recursive_path_builder.
            new_path_node = path_proto.path_nodes.add()
            new_path_node.index = entry[0]
            new_path_node.length = entry[1]
            new_path_node.target = entry[2]
        return path_proto

    def recursive_path_builder(node_proto, path, hierarchy_proto, max_index):
        """Visit ``node_proto`` and its subtree; return the next free index."""
        num_children = len(node_proto.children)
        fan_out = len(node_proto.word_ids) + num_children

        # This node owns the index range [max_index, max_index + fan_out).
        node_proto.offset = max_index
        path.append([max_index, fan_out, 0])
        max_index += fan_out
        if hierarchy_proto.size < max_index:
            hierarchy_proto.size = max_index

        # Children occupy targets 0 .. num_children - 1 within this node.
        for target, node in enumerate(node_proto.children):
            path[-1][2] = target
            max_index = recursive_path_builder(node, path, hierarchy_proto,
                                               max_index)

        # Word ids occupy the targets that follow the children.
        for target, word in enumerate(node_proto.word_ids):
            path[-1][2] = target + num_children
            path_entry = create_path(path, word)
            new_path_entry = hierarchy_proto.paths.add()
            new_path_entry.MergeFrom(path_entry)

        # Pop this node's entry so sibling subtrees do not inherit it.
        del path[-1]
        return max_index

    node = tree_proto.root_node
    hierarchy_proto = hsm_pb2.HierarchyProto()
    path = []
    recursive_path_builder(node, path, hierarchy_proto, 0)
    return hierarchy_proto