3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
# NOTE(review): this span looks like a fragmentary, line-number-prefixed
# listing of a Python method rather than runnable code. The leading numbers
# ("16 def ...", "18 super(...)") are not Python, and several upstream lines
# never appear here (e.g. 19, 21-22, 25, 27-29, 33, 36-37, 40, 43-44, 47,
# 50-57, 59-60). Restore it from the upstream file before relying on it. The
# notes below describe only what the visible lines show.
#
# Constructor for the PairwiseSimilarity class (per the super() call below).
# Visible behaviour:
#   - forwards model, name and input_record (plus **kwargs) to the base-class
#     constructor;
#   - requires either 'all_embeddings', or both 'x_embeddings' and
#     'y_embeddings', in input_record (see the XOR expression and its message);
#   - selects x_embeddings / y_embeddings from those fields;
#   - optionally reads 'indices_to_gather' from input_record;
#   - takes a dtype from x_embeddings.field_types()[0].base and pairs it with
#     the shape (output_dim,) and a blob reference named 'output'.
# Args:
#   model: passed through to the base-class constructor.
#   input_record: record queried by key membership and indexing. The messages
#     below say a Struct is expected, but the type check itself is not visible.
#   output_dim: used as the single dimension of the shape (output_dim,).
#   pairwise_similarity_func: similarity mode, default 'dot'. Where it is
#     stored is not visible in this span.
#   name: passed to the base class, default 'pairwise_similarity'.
16 def __init__(self, model, input_record, output_dim, pairwise_similarity_func='dot',
17 name=
'pairwise_similarity', **kwargs):
18 super(PairwiseSimilarity, self).__init__(model, name, input_record, **kwargs)
# Message for the Struct type check on input_record. The check itself and the
# .format(...) argument are not visible here.
# NOTE(review): "Excpected" is misspelled in this runtime message; it is left
# unchanged because changing it would alter program output.
20 "Incorrect input type. Excpected Struct, but received: {0}".
# The XOR is True when exactly one of these holds: 'all_embeddings' is
# present, or both 'x_embeddings' and 'y_embeddings' are present.
23 (
'all_embeddings' in input_record) ^
24 (
'x_embeddings' in input_record
and 'y_embeddings' in input_record)
26 "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
# With 'all_embeddings', the same field is used as both x and y.
30 if 'all_embeddings' in input_record:
31 x_embeddings = input_record[
'all_embeddings']
32 y_embeddings = input_record[
'all_embeddings']
# Otherwise x and y come from separate fields. The `else:` line itself is not
# visible in this listing.
34 x_embeddings = input_record[
'x_embeddings']
35 y_embeddings = input_record[
'y_embeddings']
# Messages for the Scalar checks on x_embeddings and y_embeddings. The check
# statements themselves are not visible; presumably they are asserts (confirm
# upstream).
38 "Incorrect input type for x. Expected Scalar, " +
39 "but received: {0}".format(x_embeddings))
41 "Incorrect input type for y. Expected Scalar, " +
42 "but received: {0}".format(y_embeddings)
# 'indices_to_gather' is optional. Its message says it must be a Scalar. What
# is done with it afterwards is not visible here.
45 if 'indices_to_gather' in input_record:
46 indices_to_gather = input_record[
'indices_to_gather']
# Two upstream lines (48 and 49) are fused on the next line. Their adjacent
# string literals concatenate implicitly into one message.
48 "Incorrect type of indices_to_gather. " 49 "Expected Scalar, but received: {0}".format(indices_to_gather)
# The dtype is taken from the first field type of x_embeddings.
58 dtype = x_embeddings.field_types()[0].base
# (dtype, (output_dim,)) plus a blob reference named 'output'. These are
# presumably arguments to an output-schema constructor on a missing line;
# confirm against upstream.
61 (dtype, (output_dim,)),
62 self.get_next_blob_reference(
'output')
65 def add_ops(self, net):
67 x_embeddings_norm = net.Normalize(self.
x_embeddings(), axis=1)
68 y_embeddings_norm = net.Normalize(self.
y_embeddings(), axis=1)
70 [x_embeddings_norm, y_embeddings_norm],
71 [self.get_next_blob_reference(x_embeddings_norm +
'_matmul')],
77 [self.get_next_blob_reference(self.
x_embeddings() +
'_matmul')],
81 raise NotImplementedError(
82 "pairwise_similarity_func={} is not valid".format(
88 flattened = net.Flatten(