5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 10 template <
typename F,
typename T,
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
15 template <
class... Args>
18 col_ids_(this->
template GetRepeatedArgument<int>(
"col_ids")),
20 this->
template GetRepeatedArgument<int>(
"categorical_limits")),
21 vals_(this->
template GetRepeatedArgument<int>(
"vals")) {
22 col_num_ = col_ids_.size();
23 max_col_id_ = *std::max_element(col_ids_.begin(), col_ids_.end());
24 CAFFE_ENFORCE_EQ(col_num_, categorical_limits_.size());
25 int expected_vals_size = 0;
26 for (
auto& l : categorical_limits_) {
27 CAFFE_ENFORCE_GT(l, 0);
28 expected_vals_size += l;
30 CAFFE_ENFORCE_EQ(expected_vals_size, vals_.size());
32 for (
auto& j : col_ids_) {
33 CAFFE_ENFORCE_GE(j, 0);
34 ngram_maps_.push_back(std::map<int, int>());
38 for (
int k = 0; k < col_num_; k++) {
39 int l = categorical_limits_[k];
40 for (
int m = 0; m < l; m++) {
42 ngram_maps_[k][v] = m * base;
48 bool RunOnDevice()
override {
49 auto& floats =
Input(0);
50 auto N = floats.size(0);
51 auto D = floats.size_from_dim(1);
52 const F* floats_data = floats.template data<F>();
54 auto* output = Output(0, {N}, at::dtype<T>());
55 auto* output_data = output->template mutable_data<T>();
56 math::Set<T, Context>(output->numel(), 0, output_data, &context_);
58 CAFFE_ENFORCE_GT(
D, max_col_id_);
59 for (
int i = 0; i < N; i++) {
60 for (
int k = 0; k < col_num_; k++) {
62 int v = round(floats_data[i *
D + j]);
67 output_data[i] += ngram_maps_[k].find(v) == ngram_maps_[k].end()
76 std::vector<int> col_ids_;
77 std::vector<int> categorical_limits_;
78 std::vector<int> vals_;
79 std::vector<std::map<int, int>> ngram_maps_;
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...