3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
19 name=
'batch_sigmoid_cross_entropy_loss',
22 super(BatchSigmoidCrossEntropyLoss, self).__init__(
23 model, name, input_record, **kwargs)
25 assert schema.is_schema_subset(
32 assert input_record.prediction.field_type().shape == \
33 input_record.label.field_type().shape, \
34 "prediction and label must have the same shape" 36 self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
42 def add_ops(self, net):
43 sigmoid_cross_entropy = net.SigmoidCrossEntropyWithLogits(
44 [self.input_record.prediction(), self.input_record.label()],
45 net.NextScopedBlob(
'sigmoid_cross_entropy')
49 sigmoid_cross_entropy, self.output_schema.field_blobs())
def get_next_blob_reference(self, name)