Caffe2 - Python API
A deep-learning, cross-platform ML framework
sampling_train.py
1 ## @package sampling_train
2 # Module caffe2.python.layers.sampling_train
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import ModelLayer, get_layer_class
10 from caffe2.python.layers.sampling_trainable_mixin import SamplingTrainableMixin
11 
12 
14  def __init__(
15  self,
16  model,
17  input_record,
18  prediction_layer,
19  output_dims,
20  subtract_log_odd=True,
21  name='sampling_train',
22  **kwargs
23  ):
24  super(SamplingTrain, self).__init__(
25  model, name, input_record, **kwargs
26  )
27 
28  layer_class = get_layer_class(prediction_layer)
29  assert issubclass(layer_class, SamplingTrainableMixin)
30 
31  assert 'indices' in input_record
32  assert isinstance(input_record.indices, schema.Scalar),\
33  "input_record.indices is expected to be a schema.Scalar"
34  assert 'input' in input_record
35 
36  self.subtract_log_odd = subtract_log_odd
37  if self.subtract_log_odd:
38  assert 'sampling_prob' in input_record
39 
40  self._prediction_layer = layer_class(
41  model,
42  input_record.input,
43  output_dims=output_dims,
44  **kwargs
45  )
46 
47  self._prediction_layer.train_param_blobs = [
48  model.net.NextBlob(str(blob) + '_sampled')
49  for blob in self._prediction_layer.param_blobs
50  ]
51 
52  self.params = self._prediction_layer.params
53 
54  self.output_schema = self._prediction_layer.output_schema
55 
56  def add_ops(self, net):
57  self._prediction_layer.add_ops(net)
58 
59  def add_train_ops(self, net):
60  for full_blob, sampled_blob in zip(
61  self._prediction_layer.param_blobs,
62  self._prediction_layer.train_param_blobs
63  ):
64  net.Gather([full_blob, self.input_record.indices()], sampled_blob)
65  self._prediction_layer.add_train_ops(net)
66  if not self.subtract_log_odd:
67  return
68  log_q = net.Log(self.input_record.sampling_prob(),
69  net.NextScopedBlob("log_q"))
70  net.Sub([self.output_schema(), log_q], self.output_schema(),
71  broadcast=1, use_grad_hack=1)