1 #include "caffe2/operators/weighted_sample_op.h" 6 bool WeightedSampleOp<float, CPUContext>::RunOnDevice() {
10 "The number of tensors of the input and the output must be the same.");
11 auto& weights = Input(0);
12 int batch_size = weights.size(0);
13 int weights_dim = weights.size(1);
15 if (batch_size > 0 && weights_dim > 0) {
16 cum_mass_.resize(weights_dim);
17 const float* mat_weights = weights.template data<float>();
18 const float* mat_values =
nullptr;
19 auto* out_idx = Output(0, {batch_size, 1}, at::dtype<int>());
20 int* output_indices = out_idx->template mutable_data<int>();
21 float* output_values =
nullptr;
23 if (InputSize() == 2) {
24 auto& values = Input(1);
28 "The sampling weights tensor and the sampling values tensor must have the same dimensions.");
29 mat_values = values.template data<float>();
31 auto* out_value = Output(1, {batch_size, 1}, at::dtype<float>());
32 output_values = out_value->template mutable_data<float>();
35 for (
int i = 0; i < batch_size; i++) {
37 int offset = i * weights_dim;
39 cum_mass_[0] = mat_weights[offset];
40 for (
int j = 1; j < weights_dim; j++) {
41 cum_mass_[j] = cum_mass_[j - 1] + mat_weights[offset + j];
44 math::RandUniform<float, CPUContext>(
45 1, 0.0f, cum_mass_[cum_mass_.size() - 1], &r, &context_);
48 cum_mass_[cum_mass_.size() - 1] += 0.01f;
49 auto lb = lower_bound(cum_mass_.begin(), cum_mass_.end(), r);
50 CAFFE_ENFORCE(lb != cum_mass_.end(),
"Cannot find ", r,
" in cum_mass_.");
51 output_indices[i] =
static_cast<int>(lb - cum_mass_.begin());
55 static_cast<float>(mat_values[offset + (lb - cum_mass_.begin())]);
59 auto* out_idx = Output(0, {0}, at::dtype<int>());
60 if (OutputSize() == 2) {
61 auto* out_value = Output(1, {0}, at::dtype<float>());
62 out_value->template mutable_data<float>();
// Registers the float / CPUContext specialization of WeightedSampleOp (the one
// whose RunOnDevice is defined above) under the operator name "WeightedSample".
// Per the macro name this is the CPU registration; the macro itself is defined
// outside this file.
// NOTE(review): the leading "69" on the next line looks like an original line
// number fused in during text extraction; it is not valid C++ and should be
// dropped when this file is restored from its upstream source.
69 REGISTER_CPU_OPERATOR(WeightedSample, WeightedSampleOp<float, CPUContext>);
71 OPERATOR_SCHEMA(WeightedSample)
74 .TensorInferenceFunction([](
const OperatorDef& def,
75 const vector<TensorShape>& in) {
76 vector<TensorShape> out(2);
77 int batch_size = in[0].dims(0);
78 out[0] = CreateTensorShape(vector<int>{batch_size}, TensorProto::INT32);
79 out[1] = CreateTensorShape(vector<int>{batch_size}, TensorProto::FLOAT);
83 The operator performs sampling based on the input sampling weights for 84 each batch. All weights must be non-negative numbers. 85 The input is a 2-D tensor (Tensor) of size (batch_size x weights_dim). 86 For each batch, an index is randomly sampled from the distribution given by 87 the weights of the corresponding batch. 88 The output is a 1-D tensor (Tensor) of size (batch_size x 1) and 89 contains the index(es) of the sampled output. 94 "A 2-D Tensor of size (batch_size x weights_dim)." 95 "All weights must be non-negative numbers.")
99 "An optional 2-D Tensor of size (batch_size x weights_dim)." 100 "Its values correspond to the sampling weights.")
104 "The output tensor contains index(es) sampled from distribution given" 105 "by the weight vector(s) in the input tensor" 106 "The output is a 1-D Tensor of size (batch_size x 1)")
110 "The output tensor contains value(s) selected by the sampled index(es)" 111 "It is a 1-D Tensor of size (batch_size x 1)");
// Marks "WeightedSample" as an operator for which no gradient should be
// generated (per the macro name; the macro is defined outside this file).
// The operator's primary output is a tensor of sampled int indices.
// NOTE(review): the leading "113" on the next line looks like an original line
// number fused in during text extraction; it is not valid C++ and should be
// dropped when this file is restored from its upstream source.
113 SHOULD_NOT_DO_GRADIENT(WeightedSample);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...