Caffe2 - Python API
A deep-learning, cross-platform ML framework
batch_mse_loss.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package batch_mse_loss
17 # Module caffe2.python.layers.batch_mse_loss
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, schema
24 from caffe2.python.layers.layers import (
25  ModelLayer,
26 )
27 from caffe2.python.layers.tags import (
28  Tags
29 )
30 import numpy as np
31 
32 
class BatchMSELoss(ModelLayer):
    """Batch mean-squared-error loss layer.

    ``input_record`` must contain at least the scalar fields ``label`` and
    ``prediction``. If it also has a ``weight`` field, each example's squared
    L2 distance is multiplied by its weight before averaging. The output is a
    float32 scalar record holding the averaged loss. The layer is tagged
    ``EXCLUDE_FROM_PREDICTION``.
    """

    def __init__(self, model, input_record, name='batch_mse_loss', **kwargs):
        super(BatchMSELoss, self).__init__(model, name, input_record, **kwargs)

        # The record must carry at least scalar label/prediction fields; extra
        # fields (e.g. the optional 'weight' read in add_ops) are allowed.
        assert schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('prediction', schema.Scalar())
            ),
            input_record
        ), "BatchMSELoss requires scalar 'label' and 'prediction' fields"
        # Keep this loss layer out of the prediction net.
        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        # Single float32 blob that add_ops writes the averaged loss into.
        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output'))

    def add_ops(self, net):
        """Append the ops computing the (optionally weighted) MSE loss.

        Args:
            net: net the ops are added to; the averaged loss is written to
                this layer's output blob.
        """
        # Remove dimension 1 of the prediction. Presumably the prediction is
        # shaped (batch_size, 1) while the label is (batch_size,) -- TODO
        # confirm against the producing layer.
        prediction = net.Squeeze(
            self.input_record.prediction(),
            net.NextScopedBlob('squeezed_prediction'),
            dims=[1]
        )

        label = self.input_record.label.field_blobs()
        # Align the label's dtype with the prediction's so both inputs to
        # SquaredL2Distance have the same base type.
        if self.input_record.label.field_type().base != (
                self.input_record.prediction.field_type().base):
            label = net.Cast(
                label,
                net.NextScopedBlob('cast_label'),
                to=schema.data_type_for_dtype(
                    self.input_record.prediction.field_type()
                )
            )

        # Labels are targets: no gradient should flow back into them.
        label = net.StopGradient(
            label,
            net.NextScopedBlob('stopped_label')
        )

        # Per-example squared distance between label and prediction.
        l2dist = net.SquaredL2Distance(
            [label, prediction],
            net.NextScopedBlob('l2')
        )

        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            # Cast non-float32 weights to FLOAT before multiplying them into
            # the distance (presumably float32 as well -- confirm).
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            # Weights scale the loss but are not trained.
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            l2dist = net.Mul(
                [l2dist, weight_blob],
                net.NextScopedBlob('weighted_l2_distance'),
            )

        # Average the (weighted) per-example distances into the output blob.
        net.AveragedLoss(l2dist, self.output_schema.field_blobs())
97  net.AveragedLoss(l2dist, self.output_schema.field_blobs())