Caffe2 - Python API
A deep-learning, cross-platform ML framework
batch_softmax_loss.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package batch_softmax_loss
# Module caffe2.python.layers.batch_softmax_loss
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, schema
24 from caffe2.python.layers.layers import ModelLayer
25 import numpy as np
26 
27 
class BatchSoftmaxLoss(ModelLayer):
    """Batch softmax loss layer built on Caffe2's ``SoftmaxWithLoss`` op.

    ``input_record`` must contain ``label`` and ``prediction`` fields and
    may optionally contain a ``weight`` field. The layer exposes two
    outputs: ``softmax`` (same field type as ``prediction``) and a float32
    ``loss``.
    """

    def __init__(
        self,
        model,
        input_record,
        name='batch_softmax_loss',
        **kwargs
    ):
        """Validate ``input_record`` and declare the output schema.

        Args:
            model: model the layer belongs to; forwarded to ``ModelLayer``.
            input_record: schema record holding ``label`` and ``prediction``
                fields and, optionally, a ``weight`` field.
            name: layer name; forwarded to ``ModelLayer``.
            **kwargs: extra arguments forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``schema.is_schema_subset`` rejects
                ``input_record`` against the required label/prediction
                schema.
        """
        super(BatchSoftmaxLoss, self).__init__(
            model, name, input_record, **kwargs)

        assert schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('prediction', schema.Scalar()),
            ),
            input_record
        )

        # Field order matters: add_ops passes output_schema.field_blobs()
        # directly as the SoftmaxWithLoss outputs, i.e. (softmax, loss).
        self.output_schema = schema.Struct(
            (
                'softmax', schema.Scalar(
                    input_record.prediction.field_type(),
                    self.get_next_blob_reference('softmax')
                )
            ),
            (
                'loss', schema.Scalar(
                    np.float32, self.get_next_blob_reference('loss')
                )
            ),
        )

    def add_ops(self, net):
        """Append the softmax/loss computation to ``net``.

        The label is cast to int32 when its declared base type is not
        int32. The optional weight is cast to float32 when its declared
        base type is not float32. The inputs ``[prediction, label]``, plus
        the weight when present, are then fed to ``SoftmaxWithLoss``, which
        writes this layer's ``softmax`` and ``loss`` blobs.

        Args:
            net: the Caffe2 net the ops are added to.
        """
        label = self.input_record.label.field_blobs()
        # field_types()[0] (not field_type()) keeps working even if the
        # label record is not a plain Scalar.
        if self.input_record.label.field_types()[0].base != np.int32:
            label = [
                net.Cast(
                    label,
                    net.NextScopedBlob('int32_label'),
                    to=core.DataType.INT32,
                )
            ]

        # A fresh list: prediction blobs followed by the (possibly cast) label.
        softmax_input = self.input_record.prediction.field_blobs() + label

        if 'weight' in self.input_record:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                # Use a scoped, net-unique name (like the label cast above)
                # rather than '<weight>_float32', so two layers sharing one
                # weight blob do not both write the same cast output.
                weight_blob = net.Cast(
                    weight_blob,
                    net.NextScopedBlob('float32_weight'),
                    to=core.DataType.FLOAT,
                )
            softmax_input.append(weight_blob)

        net.SoftmaxWithLoss(
            softmax_input,
            self.output_schema.field_blobs()
        )
`get_next_blob_reference(self, name)` is defined in layers.py, line 352.