Caffe2 - Python API
A deep learning, cross platform ML framework
sampling_train.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package sampling_train
17 # Module caffe2.python.layers.sampling_train
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import schema
24 from caffe2.python.layers.layers import ModelLayer, get_layer_class
25 from caffe2.python.layers.sampling_trainable_mixin import SamplingTrainableMixin
26 
27 
29  def __init__(
30  self,
31  model,
32  input_record,
33  prediction_layer,
34  output_dims,
35  subtract_log_odd=True,
36  name='sampling_train',
37  **kwargs
38  ):
39  super(SamplingTrain, self).__init__(
40  model, name, input_record, **kwargs
41  )
42 
43  layer_class = get_layer_class(prediction_layer)
44  assert issubclass(layer_class, SamplingTrainableMixin)
45 
46  assert 'indices' in input_record
47  assert isinstance(input_record.indices, schema.Scalar),\
48  "input_record.indices is expected to be a schema.Scalar"
49  assert 'input' in input_record
50 
51  self.subtract_log_odd = subtract_log_odd
52  if self.subtract_log_odd:
53  assert 'sampling_prob' in input_record
54 
55  self._prediction_layer = layer_class(
56  model,
57  input_record.input,
58  output_dims=output_dims,
59  **kwargs
60  )
61 
62  self._prediction_layer.train_param_blobs = [
63  model.net.NextBlob(str(blob) + '_sampled')
64  for blob in self._prediction_layer.param_blobs
65  ]
66 
67  self.params = self._prediction_layer.params
68 
69  self.output_schema = self._prediction_layer.output_schema
70 
71  def add_ops(self, net):
72  self._prediction_layer.add_ops(net)
73 
74  def add_train_ops(self, net):
75  for full_blob, sampled_blob in zip(
76  self._prediction_layer.param_blobs,
77  self._prediction_layer.train_param_blobs
78  ):
79  net.Gather([full_blob, self.input_record.indices()], sampled_blob)
80  self._prediction_layer.add_train_ops(net)
81  if not self.subtract_log_odd:
82  return
83  log_q = net.Log(self.input_record.sampling_prob(),
84  net.NextScopedBlob("log_q"))
85  net.Sub([self.output_schema(), log_q], self.output_schema(),
86  broadcast=1, use_grad_hack=1)