Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_distill_lr_loss.py
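
The layer below implements a distillation loss for binary logistic regression: it blends the sigmoid cross-entropy against the true label with the sigmoid cross-entropy against a teacher model's label, with the blend controlled by teacher_weight.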
## @package batch_distill_lr_loss
# Module caffe2.python.layers.batch_distill_lr_loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    ModelLayer,
)
from caffe2.python.layers.tags import (
    Tags
)
import numpy as np


class BatchDistillLRLoss(ModelLayer):
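    """Blends the logistic-regression loss on the true label with the loss
    on a teacher model's label:

        loss = mean((1 - w) * xent(logit, label))
             + mean(w * xent(logit, teacher_label))

    where w is teacher_weight and xent is the sigmoid cross-entropy.
    With filter_invalid_teacher_label=True, examples whose teacher label
    is not positive get w = 0. An optional per-example 'weight' field
    scales both terms.
    """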

    def __init__(
            self, model, input_record,
            name='batch_distill_lr_loss', teacher_weight=0.0,
            filter_invalid_teacher_label=False, **kwargs):

        super(BatchDistillLRLoss, self).__init__(model, name, input_record, **kwargs)

        assert teacher_weight >= 0 and teacher_weight <= 1, (
            'teacher_weight=%0.2f should be in [0, 1]' % teacher_weight
        )

        self._teacher_weight = teacher_weight
        self._filter_invalid_teacher_label = filter_invalid_teacher_label
        # hyper-parameter that determines whether to filter out bad teacher
        # labels, i.e., teacher labels that are zero.
        if self._filter_invalid_teacher_label:
            self.threshold = model.add_global_constant(
                str(model.net.NextScopedBlob('threshold')),
                [0.0],  # threshold for filtering teacher weight.
                dtype=np.float
            )
            self.neg_ONE = model.add_global_constant(
                str(model.net.NextScopedBlob('neg_ONE')),
                [-1.0],
                dtype=np.float
            )
            self.ONE = model._GetOne()
        assert schema.is_schema_subset(
            schema.Struct(
                ('teacher_label', schema.Scalar()),
                ('label', schema.Scalar()),
                ('logit', schema.Scalar()),
            ),
            input_record
        )
        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        label = self.input_record.label()
        if self.input_record.label.field_type() != np.float32:
            label = net.Cast(
                label,
                net.NextScopedBlob('float_label'),
                to=core.DataType.FLOAT,
            )

        # Assuming 1-D input
        label = net.ExpandDims(label, net.NextScopedBlob('expanded_label'),
                               dims=[1])

        teacher_label = self.input_record.teacher_label()

        if self.input_record.teacher_label.field_type() != np.float32:
            teacher_label = net.Cast(
                teacher_label,
                net.NextScopedBlob('float_teacher_label'),
                to=core.DataType.FLOAT,
            )
        teacher_label = net.ExpandDims(
            teacher_label, net.NextScopedBlob('expanded_teacher_label'),
            dims=[1])

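        # Per-example sigmoid cross-entropy against the true label and
        # against the teacher label; the two terms are blended below.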
        true_xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy')
        )

        teacher_xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), teacher_label],
            net.NextScopedBlob('teacher_cross_entropy')
        )
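        # Optionally zero out the teacher term for examples whose teacher
        # label is non-positive (treated as invalid).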
        if self._filter_invalid_teacher_label:
            squeezed_teacher_label = net.Squeeze(
                teacher_label,
                net.NextScopedBlob('squeezed_teacher_label'),
                dims=[1]
            )
            # blob used to contain the original teacher weights
            keep_weights = net.ConstantFill(
                [squeezed_teacher_label],
                net.NextScopedBlob('keep_weights'),
                value=self._teacher_weight,
                dtype=core.DataType.FLOAT
            )
            # blob used to zero out the teacher weights
            zero_weights = net.ConstantFill(
                [squeezed_teacher_label],
                net.NextScopedBlob('zero_weights'),
                value=0.0,
                dtype=core.DataType.FLOAT
            )

            # Indicates which teacher labels are bad, i.e., are zero.
            judge = net.GT(
                [squeezed_teacher_label, self.threshold],
                net.NextScopedBlob('judge'),
                broadcast=1
            )
            # Zero out the teacher weights that correspond to bad teacher labels.
            screened_teacher_weights = net.Conditional(
                [judge, keep_weights, zero_weights],
                net.NextScopedBlob('screened_teacher_weights')
            )
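            # 1 - screened_teacher_weights, built from a Mul by the -1
            # constant and an Add of the ONE constant from __init__.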
            neg_screened_teacher_weights = net.Mul(
                [screened_teacher_weights, self.neg_ONE],
                net.NextScopedBlob('neg_screened_teacher_weights'),
                broadcast=1
            )
            one_minus_screened_teacher_weights = net.Add(
                [neg_screened_teacher_weights, self.ONE],
                net.NextScopedBlob('one_minus_screened_teacher_weights'),
                broadcast=1
            )
            scaled_true_xent = net.Mul(
                [true_xent, one_minus_screened_teacher_weights],
                net.NextScopedBlob('scaled_cross_entropy'),
                broadcast=1
            )
            scaled_teacher_xent = net.Mul(
                [teacher_xent, screened_teacher_weights],
                net.NextScopedBlob('scaled_teacher_cross_entropy'),
                broadcast=1
            )
        else:
            scaled_true_xent = net.Scale(
                true_xent,
                net.NextScopedBlob('scaled_cross_entropy'),
                scale=float(1.0 - self._teacher_weight),
            )
            scaled_teacher_xent = net.Scale(
                teacher_xent,
                net.NextScopedBlob('scaled_teacher_cross_entropy'),
                scale=float(self._teacher_weight),
            )
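        # Scale both terms by the optional per-example 'weight' field;
        # StopGradient keeps the weights out of the backward pass.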
        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            scaled_true_xent = net.Mul(
                [scaled_true_xent, weight_blob],
                net.NextScopedBlob('weighted_xent_label'),
            )
            scaled_teacher_xent = net.Mul(
                [scaled_teacher_xent, weight_blob],
                net.NextScopedBlob('weighted_xent_teacher'),
            )

        true_loss = net.AveragedLoss(
            scaled_true_xent,
            net.NextScopedBlob('true_loss')
        )
        teacher_loss = net.AveragedLoss(
            scaled_teacher_xent,
            net.NextScopedBlob('teacher_loss')
        )
        net.Add(
            [true_loss, teacher_loss],
            self.output_schema.field_blobs()
        )
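
A minimal usage sketch, for illustration only. The model name, field
construction, and helper setup below are assumptions in the usual
layer-model style, not part of this file:

    import numpy as np
    from caffe2.python import schema
    from caffe2.python.layer_model_helper import LayerModelHelper

    # The three fields that the layer's schema assert requires; an optional
    # 'weight' field would additionally scale the loss per example.
    input_schema = schema.Struct(
        ('label', schema.Scalar(np.float32)),
        ('teacher_label', schema.Scalar(np.float32)),
        ('logit', schema.Scalar(np.float32)),
    )
    model = LayerModelHelper('distill_example', input_schema, schema.Struct())

    # Registered layers are available on the model helper by class name; the
    # result carries the scalar 'output' loss blob defined in __init__ above.
    loss = model.BatchDistillLRLoss(
        model.input_feature_schema, teacher_weight=0.5)

Materializing the layers into executable training nets is then normally done
through caffe2.python.layer_model_instantiator.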