Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_mse_loss.py
1 ## @package batch_mse_loss
2 # Module caffe2.python.layers.batch_mse_loss
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 from caffe2.python.layers.tags import (
13  Tags
14 )
15 import numpy as np
16 
17 
class BatchMSELoss(ModelLayer):
    """Batch mean-squared-error loss layer.

    ``input_record`` must contain at least the scalar fields ``label`` and
    ``prediction``. An optional ``weight`` field, if present, multiplies each
    example's squared distance before the batch average is taken. The loss is
    written to a single float32 scalar blob exposed as ``self.output_schema``.
    """

    def __init__(self, model, input_record, name='batch_mse_loss', **kwargs):
        """Validate the input schema and declare the float32 loss output.

        Args:
            model: the model this layer belongs to (forwarded to
                ``ModelLayer.__init__``).
            input_record: schema record containing scalar ``label`` and
                ``prediction`` fields, and optionally ``weight``.
            name: layer name; defaults to ``'batch_mse_loss'``.
            **kwargs: forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``input_record`` is not a superset of
                ``Struct(label: Scalar, prediction: Scalar)``. This is a bare
                ``assert``, so the check is skipped when Python runs with -O.
        """
        super(BatchMSELoss, self).__init__(model, name, input_record, **kwargs)

        assert schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('prediction', schema.Scalar())
            ),
            input_record
        )
        # NOTE(review): the tag name suggests this layer is left out of
        # prediction-only nets. Confirm against the Tags definition.
        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        # Single float32 scalar blob that add_ops() fills via AveragedLoss.
        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output'))

    def add_ops(self, net):
        """Append the operators that compute the averaged squared-L2 loss.

        The ops are added in this order:

        1. Squeeze ``prediction`` along dim 1.
        2. Cast ``label`` to the prediction's dtype when the base dtypes differ.
        3. Block gradients through the label.
        4. Compute the per-example ``SquaredL2Distance``.
        5. If a ``weight`` field exists, cast it to float32 when needed, block
           its gradient and multiply it into the distances.
        6. Reduce with ``AveragedLoss`` into ``self.output_schema``.

        Args:
            net: the Caffe2 net the operators are appended to.
        """
        # Assumes prediction is (batch, 1). Squeezing dim 1 makes it line up
        # with the label for SquaredL2Distance. TODO confirm callers never
        # pass an already 1-D prediction, because Squeeze on dims=[1] would
        # then fail.
        prediction = net.Squeeze(
            self.input_record.prediction(),
            net.NextScopedBlob('squeezed_prediction'),
            dims=[1]
        )

        label = self.input_record.label.field_blobs()
        # Only cast when the element types actually differ, to avoid a no-op
        # Cast op in the common case.
        if self.input_record.label.field_type().base != (
                self.input_record.prediction.field_type().base):

            label = net.Cast(
                label,
                net.NextScopedBlob('cast_label'),
                to=schema.data_type_for_dtype(
                    self.input_record.prediction.field_type()
                )
            )

        # The label is a target, not a trainable input: keep gradients out.
        label = net.StopGradient(
            label,
            net.NextScopedBlob('stopped_label')
        )

        # NOTE(review): Caffe2 documents SquaredL2Distance as including a
        # factor of 1/2. Confirm if the exact loss scale matters downstream.
        l2dist = net.SquaredL2Distance(
            [label, prediction],
            net.NextScopedBlob('l2')
        )

        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                # The cast output name is derived from the weight blob's own
                # name (not NextScopedBlob). It is kept as-is so existing blob
                # names stay stable.
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            # Weights are per-example constants; no gradient flows into them.
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            l2dist = net.Mul(
                [l2dist, weight_blob],
                net.NextScopedBlob('weighted_l2_distance'),
            )

        net.AveragedLoss(l2dist, self.output_schema.field_blobs())