3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
16 Uniform sampling `num_samples - len(input_record)` unique elements from the 17 range [0, num_elements). `samples` is the concatenation of input_record and 18 the samples. input_record is expected to be unique. 27 name=
'uniform_sampling',
30 super(UniformSampling, self).__init__(
31 model, name, input_record, **kwargs
34 assert num_elements > num_samples > 0
39 num_examples_init = (
'GivenTensorInt64Fill',
40 {
'values': [num_samples]})
43 initializer=num_examples_init,
44 optimizer=model.NoOptim)
46 sampling_blob_init = (
'ConstantFill',
47 {
'value': float(num_samples) / num_elements,
48 'dtype': core.DataType.FLOAT})
51 initializer=sampling_blob_init,
52 optimizer=model.NoOptim)
63 def add_ops(self, net):
66 shape = net.Shape([self.
input_record()], net.NextScopedBlob(
"shape"))
68 samples = net.UniqueUniformFill(
70 net.NextScopedBlob(
"samples_before_concat"),
78 [self.output_schema.samples(), net.NextScopedBlob(
"split_info")],
82 self.output_schema.samples(), self.output_schema.samples()
def get_next_blob_reference(self, name)
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)