3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
# BatchMSELoss.__init__: configure a batch mean-squared-error loss layer.
#
# NOTE(review): this extract is garbled. Every line carries a stray original
# line number, and several original lines are missing (e.g. the expected
# schema passed to is_schema_subset and the head of the output-blob
# assignment), so the code below is not runnable as shown.
#
# Args:
#     model: forwarded unchanged to the parent class initializer.
#     input_record: forwarded to the parent initializer and validated below
#         against an expected schema (not visible here). add_ops reads its
#         `label` and `prediction` fields and, if present, `weight`.
#     name: layer name; defaults to 'batch_mse_loss'.
#     **kwargs: forwarded unchanged to the parent initializer.
20 def __init__(self, model, input_record, name='batch_mse_loss', **kwargs):
21 super(BatchMSELoss, self).__init__(model, name, input_record, **kwargs)
# Check that input_record contains the expected fields; the expected schema
# argument is missing from this extract.
# NOTE(review): `assert` is stripped under `python -O`, so this validation
# silently disappears in optimized runs; raising an exception would be sturdier.
23 assert schema.is_schema_subset(
# Tag the layer with EXCLUDE_FROM_PREDICTION (presumably so it is left out of
# the prediction-only net -- TODO confirm against the Tags definition).
30 self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
# Presumably the tail of the `self.output_schema = ...` assignment: add_ops
# writes the final loss into self.output_schema.field_blobs(). The assignment
# head is missing from this extract -- TODO confirm.
34 self.get_next_blob_reference(
'output'))
# BatchMSELoss.add_ops: emit the ops that compute the averaged (optionally
# weighted) squared L2 distance between label and prediction into `net`.
#
# NOTE(review): this extract is garbled and incomplete. Stray original line
# numbers prefix each line, and several call heads, call inputs and closing
# parentheses are missing, so the code below is not runnable as shown.
#
# Args:
#     net: the net the ops are added to; intermediate blob names are created
#         with net.NextScopedBlob(...).
# Side effects:
#     Writes the final loss into self.output_schema.field_blobs().
36 def add_ops(self, net):
# Squeeze the prediction blob into 'squeezed_prediction' (any extra Squeeze
# arguments, e.g. which dims, are missing from this extract).
37 prediction = net.Squeeze(
38 self.input_record.prediction(),
39 net.NextScopedBlob(
'squeezed_prediction'),
43 label = self.input_record.label.field_blobs()
# When label and prediction have different base dtypes, convert the label to
# the prediction's data type. The op head is missing here; the `to=` keyword
# suggests a Cast op -- TODO confirm.
44 if self.input_record.label.field_type().base != (
45 self.input_record.prediction.field_type().base):
49 net.NextScopedBlob(
'cast_label'),
50 to=schema.data_type_for_dtype(
51 self.input_record.prediction.field_type()
# Block gradients from flowing back into the label.
55 label = net.StopGradient(
57 net.NextScopedBlob(
'stopped_label')
# Per-example squared L2 distance written to 'l2'; the inputs (presumably
# label and prediction) are missing from this extract.
60 l2dist = net.SquaredL2Distance(
62 net.NextScopedBlob(
'l2')
# Optional per-example weighting when the input record has a `weight` field.
65 if 'weight' in self.input_record.fields:
66 weight_blob = self.input_record.weight()
# Convert non-float32 weights to FLOAT.
# NOTE(review): unlike every other intermediate blob here, the output name is
# built as weight_blob + '_float32' rather than via net.NextScopedBlob, so it
# is not uniquely scoped. Also, weights are always cast to FLOAT while the
# label is cast to the prediction's dtype -- confirm these agree when the
# prediction is not float32.
67 if self.input_record.weight.field_type().base != np.float32:
68 weight_blob = net.Cast(
70 weight_blob +
'_float32',
71 to=core.DataType.FLOAT
# Weights are treated as constants: no gradient flows back into them.
73 weight_blob = net.StopGradient(
75 [net.NextScopedBlob(
'weight_stop_gradient')],
# Combine l2dist with the weights into 'weighted_l2_distance'. The op name is
# missing from this extract; presumably an element-wise Mul -- TODO confirm.
78 [l2dist, weight_blob],
79 net.NextScopedBlob(
'weighted_l2_distance'),
# Average the (possibly weighted) distances into the layer's output blob.
82 net.AveragedLoss(l2dist, self.output_schema.field_blobs())