Caffe2 - Python API
A deep learning, cross platform ML framework
batch_lr_loss.py
## @package batch_lr_loss
# Module caffe2.python.layers.batch_lr_loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    ModelLayer,
)
from caffe2.python.layers.tags import (
    Tags
)
import numpy as np


class BatchLRLoss(ModelLayer):

    def __init__(
        self,
        model,
        input_record,
        name='batch_lr_loss',
        average_loss=True,
        jsd_weight=0.0,
        pos_label_target=1.0,
        neg_label_target=0.0,
        homotopy_weighting=False,
        log_D_trick=False,
        unjoined_lr_loss=False,
        **kwargs
    ):
        super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)

        self.average_loss = average_loss

        assert (schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('logit', schema.Scalar())
            ),
            input_record
        ))

        self.jsd_fuse = False
        assert jsd_weight >= 0 and jsd_weight <= 1
        if jsd_weight > 0 or homotopy_weighting:
            # JSD fusion needs model predictions in addition to raw logits.
            assert 'prediction' in input_record
            self.init_weight(jsd_weight, homotopy_weighting)
            self.jsd_fuse = True
        self.homotopy_weighting = homotopy_weighting

        assert pos_label_target <= 1 and pos_label_target >= 0
        assert neg_label_target <= 1 and neg_label_target >= 0
        assert pos_label_target >= neg_label_target
        self.pos_label_target = pos_label_target
        self.neg_label_target = neg_label_target

        # log_D_trick and unjoined_lr_loss are mutually exclusive variants
        # of the cross-entropy computation.
        assert not (log_D_trick and unjoined_lr_loss)
        self.log_D_trick = log_D_trick
        self.unjoined_lr_loss = unjoined_lr_loss

        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )

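    # init_weight sets up the cross-entropy/JSD mixing weights: under
    # homotopy weighting they are created as unoptimized parameters that
    # update_weight adjusts during training; otherwise they are fixed
    # global constants derived from jsd_weight.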
    def init_weight(self, jsd_weight, homotopy_weighting):
        if homotopy_weighting:
            self.mutex = self.create_param(
                param_name=('%s_mutex' % self.name),
                shape=None,
                initializer=('CreateMutex', ),
                optimizer=self.model.NoOptim,
            )
            self.counter = self.create_param(
                param_name=('%s_counter' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 0,
                        'dtype': core.DataType.INT64
                    }
                ),
                optimizer=self.model.NoOptim,
            )
            self.xent_weight = self.create_param(
                param_name=('%s_xent_weight' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 1.,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )
            self.jsd_weight = self.create_param(
                param_name=('%s_jsd_weight' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 0.,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )
        else:
            self.jsd_weight = self.model.add_global_constant(
                '%s_jsd_weight' % self.name, jsd_weight
            )
            self.xent_weight = self.model.add_global_constant(
                '%s_xent_weight' % self.name, 1. - jsd_weight
            )

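    # update_weight advances an atomic iteration counter and recomputes the
    # mixing weights from it via the 'inv' learning-rate policy,
    # lr = base_lr * (1 + gamma * iter) ** (-power), so training drifts from
    # pure cross-entropy toward the JSD term.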
    def update_weight(self, net):
        net.AtomicIter([self.mutex, self.counter], [self.counter])
        # iter = 0:   lr = 1
        # iter = 1e6: lr = 0.5^0.1 ~= 0.93
        # iter = 1e9: lr = 1e-3^0.1 ~= 0.50
        net.LearningRate([self.counter], [self.xent_weight], base_lr=1.0,
                         policy='inv', gamma=1e-6, power=0.1,)
        net.Sub(
            [self.model.global_constants['ONE'], self.xent_weight],
            [self.jsd_weight]
        )
        return self.xent_weight, self.jsd_weight

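    # add_ops assembles the loss subgraph: cast and expand the label,
    # optionally remap it to [neg_label_target, pos_label_target], take the
    # sigmoid cross-entropy with logits, optionally fuse in the Bernoulli JSD
    # term, apply per-example weights when present, and reduce to a scalar.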
    def add_ops(self, net):
        # numerically stable sigmoid cross-entropy with logits
        label = self.input_record.label()
        # mandatory cast to float32
        # self.input_record.label.field_type().base is np.float32 but
        # label type is actually int
        label = net.Cast(
            label,
            net.NextScopedBlob('label_float32'),
            to=core.DataType.FLOAT)
        label = net.ExpandDims(label, net.NextScopedBlob('expanded_label'),
                               dims=[1])
        if self.pos_label_target != 1.0 or self.neg_label_target != 0.0:
            label = net.StumpFunc(
                label,
                net.NextScopedBlob('smoothed_label'),
                threshold=0.5,
                low_value=self.neg_label_target,
                high_value=self.pos_label_target,
            )
        xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy'),
            log_D_trick=self.log_D_trick,
            unjoined_lr_loss=self.unjoined_lr_loss
        )
        # fuse with JSD
        if self.jsd_fuse:
            jsd = net.BernoulliJSD(
                [self.input_record.prediction(), label],
                net.NextScopedBlob('jsd'),
            )
            if self.homotopy_weighting:
                self.update_weight(net)
            loss = net.WeightedSum(
                [xent, self.xent_weight, jsd, self.jsd_weight],
                net.NextScopedBlob('loss'),
            )
        else:
            loss = xent
        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            loss = net.Mul(
                [loss, weight_blob],
                net.NextScopedBlob('weighted_cross_entropy'),
            )

        if self.average_loss:
            net.AveragedLoss(loss, self.output_schema.field_blobs())
        else:
            net.ReduceFrontSum(loss, self.output_schema.field_blobs())
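
A layer like this is normally instantiated through a LayerModelHelper rather than constructed directly. The following is a minimal sketch under that assumption; the model name, field shapes, and schema wiring are illustrative placeholders, not part of this module:

import numpy as np
from caffe2.python import layer_model_helper, schema

# Illustrative schemas: 'logit' arrives with the input features and 'label'
# with the trainer-side extras; the shapes here are placeholders.
input_feature_schema = schema.Struct(
    ('logit', schema.Scalar((np.float32, (1, )))),
)
trainer_extra_schema = schema.Struct(
    ('label', schema.Scalar((np.float32, (1, )))),
)

model = layer_model_helper.LayerModelHelper(
    'lr_loss_example', input_feature_schema, trainer_extra_schema)

# Registered ModelLayer subclasses are callable as model methods; BatchLRLoss
# consumes a record containing at least 'label' and 'logit'.
loss = model.BatchLRLoss(
    schema.Struct(
        ('label', model.trainer_extra_schema.label),
        ('logit', model.input_feature_schema.logit),
    )
)
model.add_loss(loss)

The layer's output_schema is a float32 scalar holding the reduced loss, averaged over the batch when average_loss=True and summed otherwise.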
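
The homotopy schedule leans on the 'inv' learning-rate policy, lr = base_lr * (1 + gamma * iter)^(-power). This standalone check (plain Python, independent of Caffe2) reproduces the values quoted in the update_weight comments:

def inv_policy(num_iter, base_lr=1.0, gamma=1e-6, power=0.1):
    # The 'inv' policy used by the LearningRate op in update_weight.
    return base_lr * (1.0 + gamma * num_iter) ** (-power)

for num_iter in (0, 1e6, 1e9):
    xent_w = inv_policy(num_iter)
    # update_weight derives jsd_weight as 1 - xent_weight via net.Sub.
    print('iter=%g xent_weight=%.2f jsd_weight=%.2f'
          % (num_iter, xent_w, 1.0 - xent_w))

# Output:
# iter=0 xent_weight=1.00 jsd_weight=0.00
# iter=1e+06 xent_weight=0.93 jsd_weight=0.07
# iter=1e+09 xent_weight=0.50 jsd_weight=0.50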