3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
20 subtract_log_odd=
True,
21 name=
'sampling_train',
24 super(SamplingTrain, self).__init__(
25 model, name, input_record, **kwargs
28 layer_class = get_layer_class(prediction_layer)
29 assert issubclass(layer_class, SamplingTrainableMixin)
31 assert 'indices' in input_record
33 "input_record.indices is expected to be a schema.Scalar" 34 assert 'input' in input_record
38 assert 'sampling_prob' in input_record
43 output_dims=output_dims,
47 self._prediction_layer.train_param_blobs = [
48 model.net.NextBlob(str(blob) +
'_sampled')
49 for blob
in self._prediction_layer.param_blobs
52 self.
params = self._prediction_layer.params
def add_ops(self, net):
    """Add the wrapped prediction layer's ops to ``net``.

    This adds no ops of its own; it simply forwards ``net`` to
    ``self._prediction_layer.add_ops``. The Gather over
    ``input_record.indices()`` that builds the sampled parameter blobs is
    only done on the training path (``add_train_ops``).

    NOTE(review): presumably the prediction layer reads its full
    ``param_blobs`` here rather than ``train_param_blobs`` — confirm
    against ``SamplingTrainableMixin``.

    Args:
        net: the net to which the prediction layer's ops are added.
    """
    self._prediction_layer.add_ops(net)
59 def add_train_ops(self, net):
60 for full_blob, sampled_blob
in zip(
61 self._prediction_layer.param_blobs,
62 self._prediction_layer.train_param_blobs
64 net.Gather([full_blob, self.input_record.indices()], sampled_blob)
65 self._prediction_layer.add_train_ops(net)
68 log_q = net.Log(self.input_record.sampling_prob(),
69 net.NextScopedBlob(
"log_q"))
71 broadcast=1, use_grad_hack=1)