Caffe2 - Python API
A deep learning, cross platform ML framework
batch_lr_loss.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package batch_lr_loss
# Module caffe2.python.layers.batch_lr_loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    ModelLayer,
)
from caffe2.python.layers.tags import (
    Tags
)
import numpy as np


class BatchLRLoss(ModelLayer):

    def __init__(self, model, input_record, name='batch_lr_loss',
                 average_loss=True, **kwargs):
        super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)

        self.average_loss = average_loss

        # The input record must provide at least 'label' and 'logit' fields.
        assert (schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('logit', schema.Scalar())
            ),
            input_record
        ))

        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        # The layer outputs a single float32 scalar holding the loss.
        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        # numerically stable sigmoid cross-entropy with logits
        label = self.input_record.label()
        # Mandatory cast to float32: the schema declares
        # self.input_record.label.field_type().base as np.float32, but the
        # label blob may actually hold integers.
        label = net.Cast(
            label,
            net.NextScopedBlob('label_float32'),
            to=core.DataType.FLOAT)
        label = net.ExpandDims(label, net.NextScopedBlob('expanded_label'),
                               dims=[1])
        xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy'),
        )

        # Optionally scale the per-example loss by a 'weight' field; the
        # weight is kept out of gradient computation via StopGradient.
        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            xent = net.Mul(
                [xent, weight_blob],
                net.NextScopedBlob('weighted_cross_entropy'),
            )

        # Reduce the per-example losses to a single scalar.
        if self.average_loss:
            net.AveragedLoss(xent, self.output_schema.field_blobs())
        else:
            net.ReduceFrontSum(xent, self.output_schema.field_blobs())
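
For reference, the per-element loss that SigmoidCrossEntropyWithLogits evaluates can be sketched in plain NumPy using the standard numerically stable form. This is only an illustration of the math the layer builds into the net, not the Caffe2 operator itself; the reduction over the inner dimension and the sample data below are assumptions made for the example, and the exact behavior of the operator may differ.

import numpy as np

def sigmoid_cross_entropy_with_logits(logit, label):
    # Numerically stable form of
    # -label * log(sigmoid(logit)) - (1 - label) * log(1 - sigmoid(logit))
    return np.maximum(logit, 0) - logit * label + np.log1p(np.exp(-np.abs(logit)))

# Made-up example data: 4 examples with a single logit each.
logits = np.array([[2.0], [-1.0], [0.5], [3.0]], dtype=np.float32)
labels = np.array([[1.0], [0.0], [1.0], [0.0]], dtype=np.float32)
weights = np.array([1.0, 2.0, 1.0, 0.5], dtype=np.float32)

# Per-example cross entropy (reduced over the inner dimension here).
xent = sigmoid_cross_entropy_with_logits(logits, labels).mean(axis=1)
xent = xent * weights        # optional 'weight' field, as in add_ops
loss = xent.mean()           # average_loss=True  -> AveragedLoss
# loss = xent.sum()          # average_loss=False -> ReduceFrontSum
print(loss)

In a layer model this layer is typically instantiated through the layer registry, e.g. model.BatchLRLoss(input_record) on a LayerModelHelper, with input_record providing 'label', 'logit', and optionally 'weight'; the exact call site depends on the surrounding model code.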