1 #include "caffe2/operators/weighted_multi_sampling_op.h" 2 #include "caffe2/utils/math.h" 6 template <
class Context>
7 bool WeightedMultiSamplingOp<Context>::RunOnDevice() {
8 const auto& weight = Input(0);
9 CAFFE_ENFORCE_EQ(weight.dim(), 1,
"Input should be 1-D vector");
10 auto dims = weight.sizes().vec();
11 size_t data_size = weight.dim32(0);
13 std::vector<int64_t> indices_sizes;
14 auto num_samples = num_samples_;
15 if (InputSize() == 2) {
18 "New shape is specified by the input blob, do not pass in " 19 "the argument `num_samples`.");
20 num_samples = Input(1).numel();
21 indices_sizes = Input(1).sizes().vec();
23 indices_sizes = {num_samples};
26 auto* indices = Output(0, indices_sizes, at::dtype<int>());
27 int* indices_data = indices->template mutable_data<int>();
33 const float* weight_data = weight.template data<float>();
35 for (
int i = 0; i < num_samples; ++i) {
37 math::RandUniform<float, Context>(
38 1, 0.0f, weight_data[data_size - 1], &r, &context_);
39 auto lb = std::lower_bound(weight_data, weight_data + data_size, r);
41 lb != weight_data + data_size,
"Cannot find ", r,
" in input CDF.");
42 indices_data[i] =
static_cast<int>(lb - weight_data);
47 REGISTER_CPU_OPERATOR(
48 WeightedMultiSampling,
49 WeightedMultiSamplingOp<CPUContext>);
51 OPERATOR_SCHEMA(WeightedMultiSampling)
54 .TensorInferenceFunction([](
const OperatorDef& def,
55 const vector<TensorShape>& in) {
56 vector<TensorShape> out(1);
57 if (in[0].dims(0) == 0) {
58 out[0].set_data_type(TensorProto::INT32);
63 const ArgumentHelper args(def);
64 if (args.HasArgument(
"num_samples")) {
68 "New shape must not be specified by the input blob and the " 69 "argument `num_samples` at the same time.");
70 int num_samples = args.GetSingleArgument<int64_t>(
"num_samples", 0);
72 CreateTensorShape(vector<int64_t>{num_samples}, TensorProto::INT32);
78 "New shape must be specified by either the input blob or the " 79 "argument `num_samples`.");
80 std::vector<int64_t> output_dims = GetDimsVector(in[1]);
81 out[0] = CreateTensorShape(output_dims, TensorProto::INT32);
86 The operator performs sampling based on the input sampling weights. 87 All weights are cummulative probability thus sorted. The output is 88 a 1-D tensor (Tensor). If two inputs are given, the second input 89 is used to provide shape of the output sample tensor. Otherwise, we use 90 argument `num_samples` to determine the number of samples to generate. 95 "An optional 1-D Tensor." 96 "Input cumulative sampling probability (such as [0.2, 0.5, 0.8, 1.5])." 97 " All weights must be non-negative numbers. Note that the last value of" 98 " CDF is not necessary 1. If the last value is not 1, all values in" 99 " sampling_cdf will be scaled by this number.")
102 "shape_tensor (optional)",
103 "Tensor whose shape will be applied to output.")
107 "The output tensor contains indices sampled from distribution given" 108 "by the weight vector in the input tensor" 109 "The output is a 1-D Tensor of size determined by argument" 110 "`num_samples` or the second input tensor.")
111 .Arg(
"num_samples",
"number of samples to sample from the input data");
113 SHOULD_NOT_DO_GRADIENT(WeightedMultiSample);
// Cross-reference notes carried over from the generated source listing:
//   A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
//   bool HasArgument(const string &name) const
//   Checks if the operator has an argument of the given name.