3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
29 homotopy_weighting=
False,
31 unjoined_lr_loss=
False,
34 super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
38 assert (schema.is_schema_subset(
47 assert jsd_weight >= 0
and jsd_weight <= 1
48 if jsd_weight > 0
or homotopy_weighting:
49 assert 'prediction' in input_record
54 assert pos_label_target <= 1
and pos_label_target >= 0
55 assert neg_label_target <= 1
and neg_label_target >= 0
56 assert pos_label_target >= neg_label_target
60 assert not (log_D_trick
and unjoined_lr_loss)
64 self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
68 self.get_next_blob_reference(
'output')
71 def init_weight(self, jsd_weight, homotopy_weighting):
72 if homotopy_weighting:
73 self.
mutex = self.create_param(
74 param_name=(
'%s_mutex' % self.name),
76 initializer=(
'CreateMutex', ),
77 optimizer=self.model.NoOptim,
79 self.
counter = self.create_param(
80 param_name=(
'%s_counter' % self.name),
85 'dtype': core.DataType.INT64
88 optimizer=self.model.NoOptim,
91 param_name=(
'%s_xent_weight' % self.name),
96 'dtype': core.DataType.FLOAT
99 optimizer=self.model.NoOptim,
102 param_name=(
'%s_jsd_weight' % self.name),
107 'dtype': core.DataType.FLOAT
110 optimizer=self.model.NoOptim,
113 self.
jsd_weight = self.model.add_global_constant(
114 '%s_jsd_weight' % self.name, jsd_weight
117 '%s_xent_weight' % self.name, 1. - jsd_weight
120 def update_weight(self, net):
126 policy=
'inv', gamma=1e-6, power=0.1,)
128 [self.model.global_constants[
'ONE'], self.
xent_weight],
133 def add_ops(self, net):
135 label = self.input_record.label()
141 net.NextScopedBlob(
'label_float32'),
142 to=core.DataType.FLOAT)
143 label = net.ExpandDims(label, net.NextScopedBlob(
'expanded_label'),
146 label = net.StumpFunc(
148 net.NextScopedBlob(
'smoothed_label'),
153 xent = net.SigmoidCrossEntropyWithLogits(
154 [self.input_record.logit(), label],
155 net.NextScopedBlob(
'cross_entropy'),
161 jsd = net.BernoulliJSD(
162 [self.input_record.prediction(), label],
163 net.NextScopedBlob(
'jsd'),
167 loss = net.WeightedSum(
169 net.NextScopedBlob(
'loss'),
173 if 'weight' in self.input_record.fields:
174 weight_blob = self.input_record.weight()
175 if self.input_record.weight.field_type().base != np.float32:
176 weight_blob = net.Cast(
178 weight_blob +
'_float32',
179 to=core.DataType.FLOAT
181 weight_blob = net.StopGradient(
183 [net.NextScopedBlob(
'weight_stop_gradient')],
187 net.NextScopedBlob(
'weighted_cross_entropy'),
191 net.AveragedLoss(loss, self.output_schema.field_blobs())
193 net.ReduceFrontSum(loss, self.output_schema.field_blobs())
def init_weight(self, jsd_weight, homotopy_weighting)
def update_weight(self, net)