Caffe2 - Python API
A deep learning, cross-platform ML framework
merge_id_lists.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from caffe2.python import schema
7 from caffe2.python.layers.layers import (
8  get_categorical_limit,
9  ModelLayer,
10  IdList
11 )
12 
13 import numpy as np
14 
15 
class MergeIdLists(ModelLayer):
    """Merge multiple ID_LISTs into a single ID_LIST.

    Arguments:
        model: A layer model instance
        input_record: Tuple (Struct) of ID_LIST features to be merged.
            Every field must match the ``IdList`` schema and carry
            metadata with a non-None ``categorical_limit``.
        name: Layer name; also used as the prefix for the output blob.

    The merged feature is exposed as ``self.output_schema``: a ``List`` of
    ``np.int64`` scalars whose ``categorical_limit`` is the largest limit
    found among the inputs.
    """

    def __init__(self, model, input_record, name='merged'):
        super(MergeIdLists, self).__init__(model, name, input_record)
        assert all(schema.equal_schemas(x, IdList) for x in input_record), \
            "Inputs to MergeIdLists should all be IdLists."

        assert all(record.items.metadata is not None
                   for record in self.input_record), \
            "Features without metadata are not supported"

        # Gather every limit before reducing: on Python 3, max() raises a
        # TypeError as soon as it compares None with an int (or None with
        # None), which would mask the intended "unbounded" assertion.
        limits = [get_categorical_limit(record)
                  for record in self.input_record]
        # max() of an empty sequence raises an opaque ValueError; fail with
        # a clear message instead.
        assert limits, "MergeIdLists requires at least one input feature"
        assert all(limit is not None for limit in limits), \
            "Unbounded features are not supported"
        merge_dim = max(limits)

        # The output is a list of int64 ids; the scalar's metadata advertises
        # the merged categorical limit to downstream layers.
        self.output_schema = schema.NewRecord(
            model.net, schema.List(
                schema.Scalar(
                    np.int64,
                    blob=model.net.NextBlob(name),
                    metadata=schema.Metadata(categorical_limit=merge_dim)
                )))

    def add_ops(self, net):
        """Add one MergeIdLists operator that reads all input field blobs
        and writes the output schema's field blobs; returns its result."""
        return net.MergeIdLists(self.input_record.field_blobs(),
                                self.output_schema.field_blobs())