# Caffe2 - Python API
# A deep learning, cross platform ML framework
# File: batch_distill_lr_loss.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package batch_distill_lr_loss
# Module caffe2.python.layers.batch_distill_lr_loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    ModelLayer,
)
from caffe2.python.layers.tags import (
    Tags
)
import numpy as np


class BatchDistillLRLoss(ModelLayer):
    """Logistic-regression loss with knowledge distillation.

    Produces a scalar loss that is a convex combination of two sigmoid
    cross-entropy terms computed against the same ``logit``:

    * a "true" term against the hard binary ``label``, weighted by
      ``1 - teacherWeight``;
    * a "teacher" term against the soft ``teacher_label`` produced by a
      teacher model, weighted by ``teacherWeight``.

    Setting ``teacherWeight=0.0`` (the default) reduces this to a plain
    batch LR loss.
    """

    def __init__(
            self, model, input_record,
            name='batch_distill_lr_loss', teacherWeight=0.0, **kwargs):
        """
        Args:
            model: model helper this layer is added to.
            input_record: a schema record that must contain at least the
                scalar fields 'teacher_label', 'label' and 'logit'.
            name: layer name used for blob scoping.
            teacherWeight: float in [0, 1]; weight of the teacher
                cross-entropy term (the true-label term gets the
                complementary weight).

        Raises:
            AssertionError: if teacherWeight is outside [0, 1] or
                input_record lacks a required field.
        """
        super(BatchDistillLRLoss, self).__init__(
            model, name, input_record, **kwargs)

        assert 0 <= teacherWeight <= 1, (
            'teacherWeight=%0.2f should be in [0, 1]' % teacherWeight
        )
        self._teacherWeight = teacherWeight

        # Validate up front that every field read in add_ops() is present.
        assert schema.is_schema_subset(
            schema.Struct(
                ('teacher_label', schema.Scalar()),
                ('label', schema.Scalar()),
                ('logit', schema.Scalar()),
            ),
            input_record
        )
        # Loss computation is training-only; keep it out of the
        # prediction net.
        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        # Output schema: a single float32 scalar loss blob.
        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Add the distillation-loss operators to ``net``.

        Writes the combined loss into ``self.output_schema``'s blob.
        """
        label = self.input_record.label()
        # SigmoidCrossEntropyWithLogits requires float labels; cast if the
        # record stores them as another type (e.g. int).
        if self.input_record.label.field_type() != np.float32:
            label = net.Cast(
                label,
                net.NextScopedBlob('float_label'),
                to=core.DataType.FLOAT,
            )

        # Assuming 1-D input: expand labels to shape (N, 1) so they line
        # up with the logit blob.
        label = net.ExpandDims(label, net.NextScopedBlob('expanded_label'),
                               dims=[1])

        teacher_label = self.input_record.teacher_label()
        if self.input_record.teacher_label.field_type() != np.float32:
            teacher_label = net.Cast(
                teacher_label,
                net.NextScopedBlob('float_teacher_label'),
                to=core.DataType.FLOAT,
            )
        teacher_label = net.ExpandDims(
            teacher_label, net.NextScopedBlob('expanded_teacher_label'),
            dims=[1])

        # Cross entropy of the logit against each label source.
        true_xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy')
        )
        teacher_xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), teacher_label],
            net.NextScopedBlob('teacher_cross_entropy')
        )

        # Convex combination: (1 - w) * true + w * teacher.
        scaled_true_xent = net.Scale(
            true_xent,
            net.NextScopedBlob('scaled_cross_entropy'),
            scale=1.0 - self._teacherWeight,
        )
        scaled_teacher_xent = net.Scale(
            teacher_xent,
            net.NextScopedBlob('scaled_teacher_cross_entropy'),
            scale=self._teacherWeight,
        )

        # Average each term over the batch, then sum into the output blob.
        true_loss = net.AveragedLoss(
            scaled_true_xent,
            net.NextScopedBlob('true_loss')
        )
        teacher_loss = net.AveragedLoss(
            scaled_teacher_xent,
            net.NextScopedBlob('teacher_loss')
        )

        net.Add(
            [true_loss, teacher_loss],
            self.output_schema.field_blobs()
        )
118  )