Caffe2 - Python API
A deep learning, cross platform ML framework
hsm_util.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package hsm_util
17 # Module caffe2.python.hsm_util
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.proto import hsm_pb2
24 
25 '''
26  Hierarchical softmax utility methods that can be used to:
27  1) build the NodeProtos of a TreeProto from a list of word_ids or of child NodeProtos
28  2) create a HierarchyProto structure from a user-supplied TreeProto
29 '''
30 
31 
def create_node_with_words(words, name='node'):
    """Return a new hsm_pb2.NodeProto holding the given word ids.

    Args:
        words: iterable of word ids; they are appended to ``word_ids`` in
            iteration order.
        name: value stored in the node's ``name`` field.

    Returns:
        The newly created NodeProto.
    """
    leaf = hsm_pb2.NodeProto()
    leaf.name = name
    add_word = leaf.word_ids.append
    for word_id in words:
        add_word(word_id)
    return leaf
38 
39 
def create_node_with_nodes(nodes, name='node'):
    """Return a new hsm_pb2.NodeProto whose children are copies of ``nodes``.

    Each element of ``nodes`` is merged into a freshly added ``children``
    entry, so the caller's protos are copied rather than shared.

    Args:
        nodes: iterable of NodeProto messages, added in iteration order.
        name: value stored in the node's ``name`` field.

    Returns:
        The newly created parent NodeProto.
    """
    parent = hsm_pb2.NodeProto()
    parent.name = name
    for subtree in nodes:
        parent.children.add().MergeFrom(subtree)
    return parent
47 
48 
def create_hierarchy(tree_proto):
    """Build a hsm_pb2.HierarchyProto holding the root-to-leaf path of every word.

    The tree is walked depth-first starting at ``tree_proto.root_node``.

    Index allocation: each node is given a contiguous block of output
    indices, one per child followed by one per word. The node's ``offset``
    field is set to the start of that block, so this function MUTATES
    ``tree_proto`` in place.

    Paths: for every word, a PathProto is appended to the result. Its
    ``path_nodes`` list (root first) records, for each node on the branch:
      * ``index``: the node's offset;
      * ``length``: the node's number of children plus words;
      * ``target``: the position taken inside that node's block. This is
        the child's position for inner hops, or
        ``len(children) + word position`` at the node holding the word.
    Paths for a node's subtrees are emitted before the paths for the
    node's own words.

    Args:
        tree_proto: a hsm_pb2.TreeProto whose ``root_node`` describes the tree.

    Returns:
        A new hsm_pb2.HierarchyProto. Its ``size`` is the total number of
        output indices allocated over all nodes; it is left unset when that
        total is zero.
    """
    hierarchy_proto = hsm_pb2.HierarchyProto()
    # One [index, length, target] triple per node on the current branch.
    path = []

    def emit_path(word):
        # Write the path for `word` straight into the result instead of
        # building a temporary PathProto and copying it with MergeFrom.
        path_proto = hierarchy_proto.paths.add()
        path_proto.word_id = word
        for index, length, target in path:
            path_node = path_proto.path_nodes.add()
            path_node.index = index
            path_node.length = length
            path_node.target = target

    def visit(node_proto, next_index):
        # Allocate this node's index block, recurse into its children, then
        # emit paths for its own words. Returns the first unallocated index
        # after this node's whole subtree.
        num_children = len(node_proto.children)
        num_outputs = len(node_proto.word_ids) + num_children
        node_proto.offset = next_index
        path.append([next_index, num_outputs, 0])
        next_index += num_outputs
        # Children occupy targets [0, num_children) of this node's block...
        for target, child in enumerate(node_proto.children):
            path[-1][2] = target
            next_index = visit(child, next_index)
        # ...and words occupy the targets that follow.
        for target, word in enumerate(node_proto.word_ids, num_children):
            path[-1][2] = target
            emit_path(word)
        path.pop()
        return next_index

    total_size = visit(tree_proto.root_node, 0)
    # Allocated indices only grow, so the final total is the maximum. Assign
    # it only when it exceeds the field's current value, so that an empty
    # tree leaves `size` unset, matching the per-node check it replaces.
    if total_size > hierarchy_proto.size:
        hierarchy_proto.size = total_size
    return hierarchy_proto