3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
18 name=
'batch_softmax_loss',
19 label_smoothing_matrix=
None,
23 super(BatchSoftmaxLoss, self).__init__(
24 model, name, input_record, **kwargs)
26 assert schema.is_schema_subset(
45 input_record.prediction.field_type(),
56 def initialize_label_smoothing_constants(self):
60 assert len(self.label_smoothing_matrix.shape) == 2
61 label_dim = self.label_smoothing_matrix.shape[0]
62 assert label_dim == self.label_smoothing_matrix.shape[1]
65 '%s_label_smoothing_matrix' % self.
name,
67 dtype=np.dtype(np.float32),
69 self.
label_dim = self.model.add_global_constant(
70 '%s_label_dim' % self.
name,
72 dtype=np.dtype(np.int64),
78 def compute_smoothed_label(self, net):
80 label = self.input_record.label()
81 original_label_type = self.input_record.label.field_type()
82 if original_label_type.base != np.int64:
83 int64_label = net.NextScopedBlob(
'int64_label')
84 net.Cast([label], [int64_label], to=core.DataType.INT64)
87 one_hot_label = net.NextScopedBlob(
'one_hot_label')
88 smoothed_label = net.NextScopedBlob(
'smoothed_label')
89 net.OneHot([int64_label, self.
label_dim], [one_hot_label])
93 def add_ops(self, net):
94 label = self.input_record.label.field_blobs()
98 if self.input_record.label.field_types()[0].base != np.int32:
101 net.NextScopedBlob(
'int32_label'),
102 to=core.DataType.INT32)
105 softmax_input = self.input_record.prediction.field_blobs() + label
108 weight_blob = self.input_record.weight()
109 if self.input_record.weight.field_type().base != np.float32:
110 weight_blob = net.Cast(
112 weight_blob +
'_float32',
113 to=core.DataType.FLOAT
116 softmax_input += [weight_blob]
120 self.output_schema.field_blobs(),
def get_next_blob_reference(self, name)
def compute_smoothed_label(self, net)
def initialize_label_smoothing_constants(self)