Caffe2 - Python API
A deep learning, cross-platform ML framework
hsm_util.py
1 ## @package hsm_util
2 # Module caffe2.python.hsm_util
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.proto import hsm_pb2
9 
10 '''
11  Hierarchical softmax utility methods that can be used to:
12  1) create a TreeProto structure from a given list of word_ids or NodeProtos
13  2) create a HierarchyProto structure from a user-supplied TreeProto
14 '''
15 
16 
def create_node_with_words(words, name='node'):
    """Return a new NodeProto named ``name`` whose word ids are ``words``.

    Args:
        words: iterable of word ids; each one is appended, in order, to the
            node's ``word_ids`` field.
        name: value stored in the node's ``name`` field.

    Returns:
        The newly created ``hsm_pb2.NodeProto``.
    """
    leaf = hsm_pb2.NodeProto(name=name)
    for word_id in words:
        leaf.word_ids.append(word_id)
    return leaf
23 
24 
def create_node_with_nodes(nodes, name='node'):
    """Return a new NodeProto named ``name`` with ``nodes`` as its children.

    Args:
        nodes: iterable of NodeProto messages. Each one is merged into a
            freshly added entry of the new node's ``children`` field, so the
            parent holds copies rather than the objects passed in.
        name: value stored in the node's ``name`` field.

    Returns:
        The newly created ``hsm_pb2.NodeProto``.
    """
    parent = hsm_pb2.NodeProto(name=name)
    for child in nodes:
        # add() appends an empty child message; MergeFrom fills it in.
        parent.children.add().MergeFrom(child)
    return parent
32 
33 
def create_hierarchy(tree_proto):
    """Build a HierarchyProto from the tree under ``tree_proto.root_node``.

    The tree is walked depth-first, children before words. Every visited
    node has its ``offset`` field overwritten with the running index at the
    time of the visit, so ``tree_proto`` is modified in place. One PathProto
    is added to the result for each word id encountered, and the result's
    ``size`` is raised to the largest running index reached.

    Args:
        tree_proto: message exposing a ``root_node`` NodeProto.

    Returns:
        The populated ``hsm_pb2.HierarchyProto``.
    """

    def build_path_proto(trail, word_id):
        # Snapshot the current root-to-node trail as a PathProto for word_id.
        proto = hsm_pb2.PathProto()
        proto.word_id = word_id
        for index, length, target in trail:
            step = proto.path_nodes.add()
            step.index = index
            step.length = length
            step.target = target
        return proto

    def walk(current, trail, hierarchy, next_index):
        # Visit ``current`` and its subtree; return the updated running index.
        child_count = len(current.children)
        fanout = len(current.word_ids) + child_count

        current.offset = next_index
        # Kept as a list because the target slot (position 2) is rewritten
        # for every branch taken out of this node.
        step = [next_index, fanout, 0]
        trail.append(step)

        next_index += fanout
        if hierarchy.size < next_index:
            hierarchy.size = next_index

        for position, child in enumerate(current.children):
            step[2] = position
            next_index = walk(child, trail, hierarchy, next_index)

        # Words are numbered after the children within the same node.
        for position, word_id in enumerate(current.word_ids,
                                           start=child_count):
            step[2] = position
            snapshot = build_path_proto(trail, word_id)
            hierarchy.paths.add().MergeFrom(snapshot)

        trail.pop()
        return next_index

    root = tree_proto.root_node
    hierarchy_proto = hsm_pb2.HierarchyProto()
    walk(root, [], hierarchy_proto, 0)
    return hierarchy_proto