Caffe2 - Python API
A deep learning, cross-platform ML framework
merge_id_lists.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20 
21 from caffe2.python import schema
22 from caffe2.python.layers.layers import (
23  get_categorical_limit,
24  ModelLayer,
25  IdList
26 )
27 
28 import numpy as np
29 
30 
class MergeIdLists(ModelLayer):
    """Merge multiple ID_LISTs into a single ID_LIST

    Arguments:
        model: A layer model instance
        input_record: Tuple (Struct) of ID_LIST features to be
                      merged
        name: Name passed to the base layer and used to request the
              output blob from ``model.net`` (default ``'merged'``)

    Returns:
        the merged ID_LIST feature

    Raises:
        AssertionError: if any input is not an IdList, lacks metadata on
            its items, or has no categorical limit (is unbounded).
    """
    def __init__(self, model, input_record, name='merged'):
        super(MergeIdLists, self).__init__(model, name, input_record)
        assert all(schema.equal_schemas(x, IdList) for x in input_record), \
            "Inputs to MergeIdLists should all be IdLists."

        assert all(record.items.metadata is not None
                   for record in self.input_record), \
            "Features without metadata are not supported"

        # Validate every limit *before* taking the max: a None limit mixed
        # with integers would either raise TypeError inside max() (Python 3)
        # or be silently ignored (Python 2), so checking only the result of
        # max() never reliably detects an unbounded feature.
        limits = [get_categorical_limit(record)
                  for record in self.input_record]
        assert all(limit is not None for limit in limits), \
            "Unbounded features are not supported"
        merge_dim = max(limits)

        # The merged output is a List of int64 Scalars whose categorical
        # limit is the largest limit among the inputs.
        self.output_schema = schema.NewRecord(
            model.net, schema.List(
                schema.Scalar(
                    np.int64,
                    blob=model.net.NextBlob(name),
                    metadata=schema.Metadata(categorical_limit=merge_dim)
                )))

    def add_ops(self, net):
        """Add the MergeIdLists operator to ``net``.

        The operator reads the field blobs of ``self.input_record`` and
        writes to the field blobs of ``self.output_schema``.
        """
        return net.MergeIdLists(self.input_record.field_blobs(),
                                self.output_schema.field_blobs())
65  self.output_schema.field_blobs())