Caffe2 - C++ API
A deep learning, cross platform ML framework
weighted_sample_op.cc
1 #include "caffe2/operators/weighted_sample_op.h"
2 
3 namespace caffe2 {
4 
5 template <>
6 bool WeightedSampleOp<float, CPUContext>::RunOnDevice() {
7  CAFFE_ENFORCE_EQ(
8  InputSize(),
9  OutputSize(),
10  "The number of tensors of the input and the output must be the same.");
11  auto& weights = Input(0);
12  int batch_size = weights.size(0);
13  int weights_dim = weights.size(1);
14 
15  if (batch_size > 0 && weights_dim > 0) {
16  cum_mass_.resize(weights_dim);
17  const float* mat_weights = weights.template data<float>();
18  const float* mat_values = nullptr;
19  auto* out_idx = Output(0, {batch_size, 1}, at::dtype<int>());
20  int* output_indices = out_idx->template mutable_data<int>();
21  float* output_values = nullptr;
22 
23  if (InputSize() == 2) {
24  auto& values = Input(1);
25  CAFFE_ENFORCE_EQ(
26  weights.sizes(),
27  values.sizes(),
28  "The sampling weights tensor and the sampling values tensor must have the same dimensions.");
29  mat_values = values.template data<float>();
30 
31  auto* out_value = Output(1, {batch_size, 1}, at::dtype<float>());
32  output_values = out_value->template mutable_data<float>();
33  }
34 
35  for (int i = 0; i < batch_size; i++) {
36  float r;
37  int offset = i * weights_dim;
38 
39  cum_mass_[0] = mat_weights[offset];
40  for (int j = 1; j < weights_dim; j++) {
41  cum_mass_[j] = cum_mass_[j - 1] + mat_weights[offset + j];
42  }
43 
44  math::RandUniform<float, CPUContext>(
45  1, 0.0f, cum_mass_[cum_mass_.size() - 1], &r, &context_);
46  // Makes the element in cum_mass_ slightly bigger
47  // to compensate inaccuracy introduced due to rounding,
48  cum_mass_[cum_mass_.size() - 1] += 0.01f;
49  auto lb = lower_bound(cum_mass_.begin(), cum_mass_.end(), r);
50  CAFFE_ENFORCE(lb != cum_mass_.end(), "Cannot find ", r, " in cum_mass_.");
51  output_indices[i] = static_cast<int>(lb - cum_mass_.begin());
52 
53  if (output_values) {
54  output_values[i] =
55  static_cast<float>(mat_values[offset + (lb - cum_mass_.begin())]);
56  }
57  }
58  } else {
59  auto* out_idx = Output(0, {0}, at::dtype<int>());
60  if (OutputSize() == 2) {
61  auto* out_value = Output(1, {0}, at::dtype<float>());
62  out_value->template mutable_data<float>();
63  }
64  }
65 
66  return true;
67 }
68 
69 REGISTER_CPU_OPERATOR(WeightedSample, WeightedSampleOp<float, CPUContext>);
70 
71 OPERATOR_SCHEMA(WeightedSample)
72  .NumInputs(1, 2)
73  .NumOutputs(1, 2)
74  .TensorInferenceFunction([](const OperatorDef& def,
75  const vector<TensorShape>& in) {
76  vector<TensorShape> out(2);
77  int batch_size = in[0].dims(0);
78  out[0] = CreateTensorShape(vector<int>{batch_size}, TensorProto::INT32);
79  out[1] = CreateTensorShape(vector<int>{batch_size}, TensorProto::FLOAT);
80  return out;
81  })
82  .SetDoc(R"DOC(
83 The operator performs sampling based on the input sampling weights for
84 each batch. All weights must be non-negative numbers.
85 The input is a 2-D tensor (Tensor) of size (batch_size x weights_dim).
86 For each batch, an index is randomly sampled from the distribution given by
87 the weights of the corresponding batch.
88 The output is a 1-D tensor (Tensor) of size (batch_size x 1) and
89 contains the index(es) of the sampled output.
90 )DOC")
91  .Input(
92  0,
93  "sampling_weights",
94  "A 2-D Tensor of size (batch_size x weights_dim)."
95  "All weights must be non-negative numbers.")
96  .Input(
97  1,
98  "sampling_values",
99  "An optional 2-D Tensor of size (batch_size x weights_dim)."
100  "Its values correspond to the sampling weights.")
101  .Output(
102  0,
103  "sampled_indexes",
104  "The output tensor contains index(es) sampled from distribution given"
105  "by the weight vector(s) in the input tensor"
106  "The output is a 1-D Tensor of size (batch_size x 1)")
107  .Output(
108  1,
109  "sampled_values",
110  "The output tensor contains value(s) selected by the sampled index(es)"
111  "It is a 1-D Tensor of size (batch_size x 1)");
112 
113 SHOULD_NOT_DO_GRADIENT(WeightedSample);
114 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13