3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
# MarginRankLoss.__init__: set up a margin ranking (hinge) loss layer.
# NOTE(review): this excerpt is incomplete. The leading numbers on the code
# lines (20, 21, 22, 23, 26, 33, 36) look like original file line numbers,
# and the jumps between them mean some source lines are not shown here
# (e.g. the arguments to `schema.is_schema_subset(` below). These comments
# describe only what is visible.
#
# Params (visible signature):
#   model, input_record, name -- forwarded unchanged, together with
#       **kwargs, to the parent constructor; name defaults to
#       'margin_rank_loss'.
#   margin -- hinge margin, default 0.1; asserted below to be >= 0.
#   average_loss -- default False; not read in the visible lines.
#       Presumably it chooses between an averaged and a summed loss when
#       the ops are built -- TODO confirm against the full file.
20 def __init__(self, model, input_record, name='margin_rank_loss',
21 margin=0.1, average_loss=
False, **kwargs):
22 super(MarginRankLoss, self).__init__(model, name, input_record, **kwargs)
# Reject negative margins. NOTE(review): `assert` is skipped when Python
# runs with -O; raising ValueError would make this check unconditional.
23 assert margin >= 0, (
'For hinge loss, margin should be no less than 0')
# Schema check; its arguments are on lines missing from this excerpt
# (presumably input_record vs. an expected pos/neg prediction schema --
# TODO confirm).
26 assert schema.is_schema_subset(
# Tag this layer with Tags.EXCLUDE_FROM_PREDICTION.
33 self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
# Requests the next blob reference for 'output'; the target of this value
# is on lines missing from this excerpt (presumably the layer's
# output_schema -- TODO confirm).
36 self.get_next_blob_reference(
'output'))
# add_ops: emit the margin ranking loss operators into `net`.
# NOTE(review): this excerpt is incomplete. The leading numbers (38, 39,
# 41, 43, ...) look like original file line numbers, and lines in the gaps
# (e.g. 40, 42, 45, 47, 49, 51, 53, 57-59, 61) are not shown. These
# comments describe only what is visible.
38 def add_ops(self, net):
# Negative scores: the 'values' field of input_record.neg_prediction.
39 neg_score = self.input_record.neg_prediction[
'values']()
# Tile pos_prediction by neg_prediction's 'lengths' into the blob
# 'pos_score_repeated' -- presumably so each negative score is paired with
# its corresponding positive score (TODO confirm LengthsTile semantics).
41 pos_score = net.LengthsTile(
43 self.input_record.pos_prediction(),
44 self.input_record.neg_prediction[
'lengths']()
46 net.NextScopedBlob(
'pos_score_repeated')
# INT32 constant blob 'const_1'. Its input blob and fill value are on the
# missing lines 49/51 (presumably 1s shaped like neg_score -- confirm).
48 const_1 = net.ConstantFill(
50 net.NextScopedBlob(
'const_1'),
52 dtype=core.DataType.INT32
# Margin ranking criterion over (pos_score, neg_score, const_1), written to
# the blob 'rank_loss'. Any further arguments (presumably the margin) are
# on the missing lines 57-59.
54 rank_loss = net.MarginRankingCriterion(
55 [pos_score, neg_score, const_1],
56 net.NextScopedBlob(
'rank_loss'),
# Reduce rank_loss into the layer's output blobs, via AveragedLoss or
# ReduceFrontSum (presumably mean vs. sum). NOTE(review): as shown, both
# calls are unconditional and write the same output blobs; the missing
# lines 57-59 and 61 presumably contain an `if self.average_loss:` /
# `else:` that selects one of them -- TODO confirm.
60 net.AveragedLoss(rank_loss, self.output_schema.field_blobs())
62 net.ReduceFrontSum(rank_loss, self.output_schema.field_blobs())