Caffe2 - C++ API
A deep learning, cross platform ML framework
ngram_ops.h
1 #pragma once
2 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 template <typename F, typename T, class Context>
11 class NGramFromCategoricalOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14 
15  template <class... Args>
16  explicit NGramFromCategoricalOp(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...),
18  col_ids_(this->template GetRepeatedArgument<int>("col_ids")),
19  categorical_limits_(
20  this->template GetRepeatedArgument<int>("categorical_limits")),
21  vals_(this->template GetRepeatedArgument<int>("vals")) {
22  col_num_ = col_ids_.size();
23  max_col_id_ = *std::max_element(col_ids_.begin(), col_ids_.end());
24  CAFFE_ENFORCE_EQ(col_num_, categorical_limits_.size());
25  int expected_vals_size = 0;
26  for (auto& l : categorical_limits_) {
27  CAFFE_ENFORCE_GT(l, 0);
28  expected_vals_size += l;
29  }
30  CAFFE_ENFORCE_EQ(expected_vals_size, vals_.size());
31  // compute ngram maps with small end
32  for (auto& j : col_ids_) {
33  CAFFE_ENFORCE_GE(j, 0);
34  ngram_maps_.push_back(std::map<int, int>());
35  }
36  int base = 1;
37  int idx = 0;
38  for (int k = 0; k < col_num_; k++) {
39  int l = categorical_limits_[k];
40  for (int m = 0; m < l; m++) {
41  int v = vals_[idx++];
42  ngram_maps_[k][v] = m * base;
43  }
44  base *= l;
45  }
46  }
47 
48  bool RunOnDevice() override {
49  auto& floats = Input(0);
50  auto N = floats.size(0);
51  auto D = floats.size_from_dim(1);
52  const F* floats_data = floats.template data<F>();
53 
54  auto* output = Output(0, {N}, at::dtype<T>());
55  auto* output_data = output->template mutable_data<T>();
56  math::Set<T, Context>(output->numel(), 0, output_data, &context_);
57 
58  CAFFE_ENFORCE_GT(D, max_col_id_);
59  for (int i = 0; i < N; i++) {
60  for (int k = 0; k < col_num_; k++) {
61  int j = col_ids_[k];
62  int v = round(floats_data[i * D + j]);
63  // for out-of-vocabulary values, we always treat them the same as the
64  // first value specified in vals; if we want to mimic the behavior as
65  // sigrid NGram transform, just push front a random/impossible value at
66  // each segments of vals
67  output_data[i] += ngram_maps_[k].find(v) == ngram_maps_[k].end()
68  ? 0
69  : ngram_maps_[k][v];
70  }
71  }
72  return true;
73  }
74 
75  private:
76  std::vector<int> col_ids_;
77  std::vector<int> categorical_limits_;
78  std::vector<int> vals_;
79  std::vector<std::map<int, int>> ngram_maps_;
80  int col_num_;
81  int max_col_id_;
82 };
83 } // namespace caffe2
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:70