Caffe2 - Python API
A deep learning, cross platform ML framework
All Classes Namespaces Functions
pairwise_similarity.py
1 ## @package pairwise_similarity
2 # Module caffe2.python.layers.pairwise_similarity
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 
13 
class PairwiseSimilarity(ModelLayer):
    """Layer computing pairwise similarities between two sets of embeddings.

    The input record must contain either ``all_embeddings`` (which is then
    compared against itself) or both ``x_embeddings`` and ``y_embeddings``,
    but not both forms. It may optionally contain ``indices_to_gather``,
    which selects entries of the flattened similarity result.

    Args:
        model: the model the layer belongs to (forwarded to ``ModelLayer``).
        input_record: a ``schema.Struct`` with the fields described above;
            each embedding field and ``indices_to_gather`` must be a
            ``schema.Scalar``.
        output_dim: size of the single dimension declared in the output
            schema. NOTE(review): presumably this must equal the number of
            flattened (or gathered) similarity entries per example; confirm
            with callers.
        pairwise_similarity_func: ``'dot'`` (default) or
            ``'cosine_similarity'``. Any other value is only rejected, with
            ``NotImplementedError``, when ``add_ops`` is called.
        name: layer name.
        **kwargs: forwarded to ``ModelLayer.__init__``.
    """

    def __init__(self, model, input_record, output_dim,
                 pairwise_similarity_func='dot',
                 name='pairwise_similarity', **kwargs):
        super(PairwiseSimilarity, self).__init__(
            model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))
        # Exactly one input mode is allowed: a single embedding set compared
        # with itself, or an explicit (x, y) pair.
        assert (
            ('all_embeddings' in input_record) ^
            ('x_embeddings' in input_record and 'y_embeddings' in input_record)
        ), (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) "
            "should be given."
        )
        self.pairwise_similarity_func = pairwise_similarity_func
        if 'all_embeddings' in input_record:
            # Self-similarity: both operands read the same field.
            x_embeddings = input_record['all_embeddings']
            y_embeddings = input_record['all_embeddings']
        else:
            x_embeddings = input_record['x_embeddings']
            y_embeddings = input_record['y_embeddings']

        assert isinstance(x_embeddings, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, "
            "but received: {0}".format(x_embeddings))
        assert isinstance(y_embeddings, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, "
            "but received: {0}".format(y_embeddings)
        )

        if 'indices_to_gather' in input_record:
            indices_to_gather = input_record['indices_to_gather']
            assert isinstance(indices_to_gather, schema.Scalar), (
                "Incorrect type of indices_to_gather. "
                "Expected Scalar, but received: {0}".format(indices_to_gather)
            )
            self.indices_to_gather = indices_to_gather
        else:
            # None means "no gather"; add_ops checks for it explicitly.
            self.indices_to_gather = None

        self.x_embeddings = x_embeddings
        self.y_embeddings = y_embeddings

        # The output takes its base dtype from the x embeddings.
        dtype = x_embeddings.field_types()[0].base

        self.output_schema = schema.Scalar(
            (dtype, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Add the similarity computation to ``net``.

        Both operands are multiplied with ``BatchMatMul(trans_b=1)``. The
        result is flattened and, if ``indices_to_gather`` was supplied,
        gathered with ``BatchGather``. The final blob is written to the
        output schema.

        Args:
            net: the net the operators are added to.

        Raises:
            NotImplementedError: if ``pairwise_similarity_func`` is neither
                ``'dot'`` nor ``'cosine_similarity'``.
        """
        if self.pairwise_similarity_func == "cosine_similarity":
            # Normalize both sides along axis 1 before the product, so the
            # matmul yields cosine similarities (assuming Normalize is an L2
            # normalization over the embedding axis -- TODO confirm).
            lhs = net.Normalize(self.x_embeddings(), axis=1)
            rhs = net.Normalize(self.y_embeddings(), axis=1)
        elif self.pairwise_similarity_func == "dot":
            lhs = self.x_embeddings()
            rhs = self.y_embeddings()
        else:
            raise NotImplementedError(
                "pairwise_similarity_func={} is not valid".format(
                    self.pairwise_similarity_func
                )
            )

        # trans_b=1 multiplies lhs by the transpose of rhs, giving every
        # lhs/rhs pairing in one product.
        Y = net.BatchMatMul(
            [lhs, rhs],
            [self.get_next_blob_reference(lhs + '_matmul')],
            trans_b=1,
        )

        if self.indices_to_gather is not None:
            flattened = net.Flatten(Y, Y + '_flatten')
            net.BatchGather(
                [flattened, self.indices_to_gather()],
                self.output_schema(),
            )
        else:
            net.Flatten(Y, self.output_schema())