Caffe2 - Python API
A deep learning, cross-platform ML framework
pairwise_dot_product.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package pairwise_dot_product
17 # Module caffe2.python.layers.pairwise_dot_product
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import schema
24 from caffe2.python.layers.layers import (
25  ModelLayer,
26 )
27 
28 
class PairwiseDotProduct(ModelLayer):
    """Layer computing pairwise dot products between two sets of embeddings.

    The input record must be a ``schema.Struct`` holding EITHER
    ``all_embeddings`` (the embeddings are dotted with themselves) OR both
    ``x_embeddings`` and ``y_embeddings``. An optional ``indices_to_gather``
    Scalar selects a subset of the flattened products.

    ``add_ops`` emits ``BatchMatMul([x, y], trans_b=1)``, i.e. ``x @ y^T``
    per batch item, then ``Flatten``. When ``indices_to_gather`` is present,
    it also emits ``BatchGather`` on the flattened result.
    NOTE(review): presumably x/y are (batch, n, dim) blobs, so the output is
    (batch, n_x * n_y) before gathering -- confirm against callers.
    """

    def __init__(self, model, input_record, output_dim,
                 name='pairwise_dot_product', **kwargs):
        """Validate the input record and declare the output schema.

        Args:
            model: model the layer is attached to (forwarded to ModelLayer).
            input_record: ``schema.Struct`` with either ``all_embeddings``,
                or both ``x_embeddings`` and ``y_embeddings``, each a
                ``schema.Scalar``. It may also contain a
                ``schema.Scalar`` named ``indices_to_gather``.
            output_dim: size of the per-example output vector. It is used
                only to declare the output schema shape and is not checked
                against the actual number of products.
            name: layer name.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if the input record does not have the structure
                described above.
        """
        super(PairwiseDotProduct, self).__init__(model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))
        assert (
            ('all_embeddings' in input_record) ^
            ('x_embeddings' in input_record and 'y_embeddings' in input_record)
        ), (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
            "should be given."
        )
        if 'all_embeddings' in input_record:
            # Self-similarity: both operands of the matmul are the same blob.
            x_embeddings = input_record['all_embeddings']
            y_embeddings = input_record['all_embeddings']
        else:
            x_embeddings = input_record['x_embeddings']
            y_embeddings = input_record['y_embeddings']

        assert isinstance(x_embeddings, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, " +
            "but received: {0}".format(x_embeddings))
        assert isinstance(y_embeddings, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, " +
            "but received: {0}".format(y_embeddings)
        )

        if 'indices_to_gather' in input_record:
            indices_to_gather = input_record['indices_to_gather']
            assert isinstance(indices_to_gather, schema.Scalar), (
                "Incorrect type of indices_to_gather. "
                "Expected Scalar, but received: {0}".format(indices_to_gather)
            )
            self.indices_to_gather = indices_to_gather
        else:
            # None is the sentinel add_ops checks for "no gathering".
            self.indices_to_gather = None

        self.x_embeddings = x_embeddings
        self.y_embeddings = y_embeddings

        # The output element type follows the base type of x_embeddings.
        dtype = x_embeddings.field_types()[0].base

        self.output_schema = schema.Scalar(
            (dtype, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Emit the pairwise dot-product operators into ``net``.

        Args:
            net: net that receives the operators (presumably a Caffe2
                ``core.Net``; it must provide BatchMatMul, Flatten and
                BatchGather).
        """
        # x @ y^T per batch item. The intermediate gets a layer-scoped unique
        # name. Deriving it from the x_embeddings blob name made two layers
        # that share the same x input write the same intermediate (and
        # '<...>_flatten') blobs.
        Y = net.BatchMatMul(
            [self.x_embeddings(), self.y_embeddings()],
            [self.get_next_blob_reference('matmul')],
            trans_b=1,
        )

        if self.indices_to_gather is not None:
            # Flatten first so the gather indices address the flattened
            # pairwise products.
            flattened = net.Flatten(Y, Y + '_flatten')
            net.BatchGather(
                [flattened, self.indices_to_gather()],
                self.output_schema(),
            )
        else:
            net.Flatten(Y, self.output_schema())