from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
from caffe2.python.layers.tags import Tags
import numpy as np
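
# BatchDistillLRLoss blends the standard logistic-regression loss against the
# true label with a distillation loss against a teacher model's label:
#
#   loss = (1 - w) * xent(logit, label) + w * xent(logit, teacher_label)
#
# where w = teacher_weight. With filter_invalid_teacher_label set, w is
# zeroed per example whenever the teacher label is not above a small
# threshold, so invalid (near-zero) teacher labels contribute nothing.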
class BatchDistillLRLoss(ModelLayer):

    def __init__(
            self, model, input_record,
            name='batch_distill_lr_loss', teacher_weight=0.0,
            filter_invalid_teacher_label=False, **kwargs):

        super(BatchDistillLRLoss, self).__init__(
            model, name, input_record, **kwargs)

        assert 0.0 <= teacher_weight <= 1.0, (
            'teacher_weight=%0.2f should be in [0, 1]' % teacher_weight
        )

        self._teacher_weight = teacher_weight
        self._filter_invalid_teacher_label = filter_invalid_teacher_label
        # When filtering invalid (near-zero) teacher labels, precompute the
        # comparison threshold plus the constants -1 and 1, which add_ops uses
        # to turn a weight w into (1 - w) with broadcasted Mul/Add ops.
        if self._filter_invalid_teacher_label:
            self.threshold = model.add_global_constant(
                str(model.net.NextScopedBlob('threshold')),
                [0.1],  # hyper-parameter
                dtype=np.float32,
            )
            self.neg_ONE = model.add_global_constant(
                str(model.net.NextScopedBlob('neg_ONE')),
                [-1.],
                dtype=np.float32,
            )
            self.ONE = model._GetOne()
        assert schema.is_schema_subset(
            schema.Struct(
                ('teacher_label', schema.Scalar()),
                ('label', schema.Scalar()),
                ('logit', schema.Scalar()),
            ),
            input_record
        )
        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        # The layer's output is a single scalar: the combined distillation loss.
        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )
    def add_ops(self, net):
        label = self.input_record.label()
        if self.input_record.label.field_type() != np.float32:
            label = net.Cast(
                label,
                net.NextScopedBlob('float_label'),
                to=core.DataType.FLOAT,
            )

        # Assuming a 1-D label input; expand to 2-D to match the logit.
        label = net.ExpandDims(
            label, net.NextScopedBlob('expanded_label'),
            dims=[1],
        )
        teacher_label = self.input_record.teacher_label()

        if self.input_record.teacher_label.field_type() != np.float32:
            teacher_label = net.Cast(
                teacher_label,
                net.NextScopedBlob('float_teacher_label'),
                to=core.DataType.FLOAT,
            )
        teacher_label = net.ExpandDims(
            teacher_label, net.NextScopedBlob('expanded_teacher_label'),
            dims=[1],
        )
        true_xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy')
        )

        teacher_xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), teacher_label],
            net.NextScopedBlob('teacher_cross_entropy')
        )
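        # The two cross entropies are blended below as
        #   (1 - w) * true_xent + w * teacher_xent,
        # either with per-example screened weights (when filtering invalid
        # teacher labels) or with the plain scalar teacher weight via Scale.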
        if self._filter_invalid_teacher_label:
            squeezed_teacher_label = net.Squeeze(
                teacher_label,
                net.NextScopedBlob('squeezed_teacher_label'),
                dims=[1],
            )
            # Per-example weight candidates: the teacher weight everywhere...
            keep_weights = net.ConstantFill(
                [squeezed_teacher_label],
                net.NextScopedBlob('keep_weights'),
                value=self._teacher_weight,
                dtype=core.DataType.FLOAT,
            )
            # ...and zero everywhere, for examples with invalid teacher labels.
            zero_weights = net.ConstantFill(
                [squeezed_teacher_label],
                net.NextScopedBlob('zero_weights'),
                value=0.0,
                dtype=core.DataType.FLOAT,
            )
            # Flag which teacher labels are valid, i.e. above the threshold.
            judge = net.GT(
                [squeezed_teacher_label, self.threshold],
                net.NextScopedBlob('judge'),
                broadcast=1,
            )
            # Keep the teacher weight where the label is valid, zero elsewhere.
            screened_teacher_weights = net.Conditional(
                [judge, keep_weights, zero_weights],
                net.NextScopedBlob('screened_teacher_weights')
            )
            # Compute (1 - w) as (-1 * w) + 1 with broadcasted Mul/Add.
            neg_screened_teacher_weights = net.Mul(
                [screened_teacher_weights, self.neg_ONE],
                net.NextScopedBlob('neg_screened_teacher_weights'),
                broadcast=1,
            )
            one_minus_screened_teacher_weights = net.Add(
                [neg_screened_teacher_weights, self.ONE],
                net.NextScopedBlob('one_minus_screened_teacher_weights'),
                broadcast=1,
            )
            scaled_true_xent = net.Mul(
                [true_xent, one_minus_screened_teacher_weights],
                net.NextScopedBlob('scaled_cross_entropy'),
                broadcast=1,
            )
            scaled_teacher_xent = net.Mul(
                [teacher_xent, screened_teacher_weights],
                net.NextScopedBlob('scaled_teacher_cross_entropy'),
                broadcast=1,
            )
        else:
            # No filtering: apply the scalar blend weights directly.
            scaled_true_xent = net.Scale(
                true_xent,
                net.NextScopedBlob('scaled_cross_entropy'),
                scale=float(1 - self._teacher_weight),
            )
            scaled_teacher_xent = net.Scale(
                teacher_xent,
                net.NextScopedBlob('scaled_teacher_cross_entropy'),
                scale=float(self._teacher_weight),
            )
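        # Both branches above yield scaled_true_xent and scaled_teacher_xent.
        # An optional per-example weight from the input record is applied to
        # both terms next.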
        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            # The weight is data, not a parameter; stop gradients through it.
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            scaled_true_xent = net.Mul(
                [scaled_true_xent, weight_blob],
                net.NextScopedBlob('weighted_xent_label'),
            )
            scaled_teacher_xent = net.Mul(
                [scaled_teacher_xent, weight_blob],
                net.NextScopedBlob('weighted_xent_teacher'),
            )
        true_loss = net.AveragedLoss(
            scaled_true_xent,
            net.NextScopedBlob('true_loss')
        )
        teacher_loss = net.AveragedLoss(
            scaled_teacher_xent,
            net.NextScopedBlob('teacher_loss')
        )
        # The final output is the sum of the two already-scaled loss terms.
        net.Add(
            [true_loss, teacher_loss],
            self.output_schema.field_blobs()
        )
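
# A minimal usage sketch, kept as a comment. The record fields and the model
# object here are hypothetical; the exact LayerModelHelper plumbing varies
# across Caffe2 versions:
#
#   loss = model.BatchDistillLRLoss(
#       schema.Struct(
#           ('label', label_record),
#           ('teacher_label', teacher_label_record),
#           ('logit', logit_record),
#       ),
#       teacher_weight=0.5,
#       filter_invalid_teacher_label=True,
#   )
#   model.add_loss(loss)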