Caffe2 - C++ API
A deep learning, cross platform ML framework
weighted_multi_sampling_op.cc
1 #include "caffe2/operators/weighted_multi_sampling_op.h"
2 #include "caffe2/utils/math.h"
3 
4 namespace caffe2 {
5 
6 template <class Context>
7 bool WeightedMultiSamplingOp<Context>::RunOnDevice() {
8  const auto& weight = Input(0);
9  CAFFE_ENFORCE_EQ(weight.dim(), 1, "Input should be 1-D vector");
10  auto dims = weight.sizes().vec();
11  size_t data_size = weight.dim32(0);
12 
13  std::vector<int64_t> indices_sizes;
14  auto num_samples = num_samples_;
15  if (InputSize() == 2) {
16  CAFFE_ENFORCE(
17  !OperatorBase::HasArgument("num_samples"),
18  "New shape is specified by the input blob, do not pass in "
19  "the argument `num_samples`.");
20  num_samples = Input(1).numel();
21  indices_sizes = Input(1).sizes().vec();
22  } else {
23  indices_sizes = {num_samples};
24  }
25 
26  auto* indices = Output(0, indices_sizes, at::dtype<int>());
27  int* indices_data = indices->template mutable_data<int>();
28  if (data_size == 0) {
29  indices->Resize(0);
30  return true;
31  }
32 
33  const float* weight_data = weight.template data<float>();
34 
35  for (int i = 0; i < num_samples; ++i) {
36  float r;
37  math::RandUniform<float, Context>(
38  1, 0.0f, weight_data[data_size - 1], &r, &context_);
39  auto lb = std::lower_bound(weight_data, weight_data + data_size, r);
40  CAFFE_ENFORCE(
41  lb != weight_data + data_size, "Cannot find ", r, " in input CDF.");
42  indices_data[i] = static_cast<int>(lb - weight_data);
43  }
44  return true;
45 }
46 
47 REGISTER_CPU_OPERATOR(
48  WeightedMultiSampling,
49  WeightedMultiSamplingOp<CPUContext>);
50 
51 OPERATOR_SCHEMA(WeightedMultiSampling)
52  .NumInputs(1, 2)
53  .NumOutputs(1)
54  .TensorInferenceFunction([](const OperatorDef& def,
55  const vector<TensorShape>& in) {
56  vector<TensorShape> out(1);
57  if (in[0].dims(0) == 0) {
58  out[0].set_data_type(TensorProto::INT32);
59  out[0].add_dims(0);
60  return out;
61  }
62 
63  const ArgumentHelper args(def);
64  if (args.HasArgument("num_samples")) {
65  CAFFE_ENFORCE_EQ(
66  in.size(),
67  1,
68  "New shape must not be specified by the input blob and the "
69  "argument `num_samples` at the same time.");
70  int num_samples = args.GetSingleArgument<int64_t>("num_samples", 0);
71  out[0] =
72  CreateTensorShape(vector<int64_t>{num_samples}, TensorProto::INT32);
73  return out;
74  } else {
75  CAFFE_ENFORCE_EQ(
76  in.size(),
77  2,
78  "New shape must be specified by either the input blob or the "
79  "argument `num_samples`.");
80  std::vector<int64_t> output_dims = GetDimsVector(in[1]);
81  out[0] = CreateTensorShape(output_dims, TensorProto::INT32);
82  return out;
83  }
84  })
85  .SetDoc(R"DOC(
86 The operator performs sampling based on the input sampling weights.
87 All weights are cummulative probability thus sorted. The output is
88 a 1-D tensor (Tensor). If two inputs are given, the second input
89 is used to provide shape of the output sample tensor. Otherwise, we use
90 argument `num_samples` to determine the number of samples to generate.
91 )DOC")
92  .Input(
93  0,
94  "sampling_cdf",
95  "An optional 1-D Tensor."
96  "Input cumulative sampling probability (such as [0.2, 0.5, 0.8, 1.5])."
97  " All weights must be non-negative numbers. Note that the last value of"
98  " CDF is not necessary 1. If the last value is not 1, all values in"
99  " sampling_cdf will be scaled by this number.")
100  .Input(
101  1,
102  "shape_tensor (optional)",
103  "Tensor whose shape will be applied to output.")
104  .Output(
105  0,
106  "sampled_indexes",
107  "The output tensor contains indices sampled from distribution given"
108  "by the weight vector in the input tensor"
109  "The output is a 1-D Tensor of size determined by argument"
110  "`num_samples` or the second input tensor.")
111  .Arg("num_samples", "number of samples to sample from the input data");
112 
113 SHOULD_NOT_DO_GRADIENT(WeightedMultiSample);
114 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70