Caffe2 - C++ API
A deep learning, cross platform ML framework
weighted_sample_op.cc
#include "caffe2/operators/weighted_sample_op.h"

#include <algorithm>
#include <cstdint>
#include <vector>
2 
3 namespace caffe2 {
4 
5 template <>
6 bool WeightedSampleOp<float, CPUContext>::RunOnDevice() {
7  CAFFE_ENFORCE_EQ(
8  InputSize(),
9  OutputSize(),
10  "The number of tensors of the input and the output must be the same.");
11  auto& weights = Input(0);
12  int batch_size = weights.dim(0);
13  int weights_dim = weights.dim(1);
14  auto* out_idx = Output(0);
15 
16  if (batch_size > 0 && weights_dim > 0) {
17  cum_mass_.resize(weights_dim);
18  const float* mat_weights = weights.template data<float>();
19  const float* mat_values = nullptr;
20  out_idx->Resize(batch_size, 1);
21  int* output_indices = out_idx->template mutable_data<int>();
22  float* output_values = nullptr;
23 
24  if (InputSize() == 2) {
25  auto& values = Input(1);
26  CAFFE_ENFORCE_EQ(
27  weights.dims(),
28  values.dims(),
29  "The sampling weights tensor and the sampling values tensor must have the same dimensions.");
30  mat_values = values.template data<float>();
31  auto* out_value = Output(1);
32  out_value->Resize(batch_size, 1);
33  output_values = out_value->template mutable_data<float>();
34  }
35 
36  for (int i = 0; i < batch_size; i++) {
37  float r;
38  int offset = i * weights_dim;
39 
40  cum_mass_[0] = mat_weights[offset];
41  for (int j = 1; j < weights_dim; j++) {
42  cum_mass_[j] = cum_mass_[j - 1] + mat_weights[offset + j];
43  }
44 
45  math::RandUniform<float, CPUContext>(
46  1, 0.0f, cum_mass_[cum_mass_.size() - 1], &r, &context_);
47  // Makes the element in cum_mass_ slightly bigger
48  // to compensate inaccuracy introduced due to rounding,
49  cum_mass_[cum_mass_.size() - 1] += 0.01f;
50  auto lb = lower_bound(cum_mass_.begin(), cum_mass_.end(), r);
51  CAFFE_ENFORCE(lb != cum_mass_.end(), "Cannot find ", r, " in cum_mass_.");
52  output_indices[i] = static_cast<int>(lb - cum_mass_.begin());
53 
54  if (output_values) {
55  output_values[i] =
56  static_cast<float>(mat_values[offset + (lb - cum_mass_.begin())]);
57  }
58  }
59  } else {
60  out_idx->Resize(0);
61  out_idx->template mutable_data<int>();
62  if (OutputSize() == 2) {
63  auto* out_value = Output(1);
64  out_value->Resize(0);
65  out_value->template mutable_data<float>();
66  }
67  }
68 
69  return true;
70 }
71 
72 REGISTER_CPU_OPERATOR(WeightedSample, WeightedSampleOp<float, CPUContext>);
73 
74 OPERATOR_SCHEMA(WeightedSample)
75  .NumInputs(1, 2)
76  .NumOutputs(1, 2)
77  .TensorInferenceFunction([](const OperatorDef& def,
78  const vector<TensorShape>& in) {
79  vector<TensorShape> out(2);
80  int batch_size = in[0].dims(0);
81  out[0] = CreateTensorShape(vector<int>{batch_size}, TensorProto::INT32);
82  out[1] = CreateTensorShape(vector<int>{batch_size}, TensorProto::FLOAT);
83  return out;
84  })
85  .SetDoc(R"DOC(
86 The operator performs sampling based on the input sampling weights for
87 each batch. All weights must be non-negative numbers.
88 The input is a 2-D tensor (Tensor<float>) of size (batch_size x weights_dim).
89 For each batch, an index is randomly sampled from the distribution given by
90 the weights of the corresponding batch.
91 The output is a 1-D tensor (Tensor<int>) of size (batch_size x 1) and
92 contains the index(es) of the sampled output.
93 )DOC")
94  .Input(
95  0,
96  "sampling_weights",
97  "A 2-D Tensor<float> of size (batch_size x weights_dim)."
98  "All weights must be non-negative numbers.")
99  .Input(
100  1,
101  "sampling_values",
102  "An optional 2-D Tensor<float> of size (batch_size x weights_dim)."
103  "Its values correspond to the sampling weights.")
104  .Output(
105  0,
106  "sampled_indexes",
107  "The output tensor contains index(es) sampled from distribution given"
108  "by the weight vector(s) in the input tensor"
109  "The output is a 1-D Tensor<int> of size (batch_size x 1)")
110  .Output(
111  1,
112  "sampled_values",
113  "The output tensor contains value(s) selected by the sampled index(es)"
114  "It is a 1-D Tensor<float> of size (batch_size x 1)");
115 
116 SHOULD_NOT_DO_GRADIENT(WeightedSample);
117 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.