Caffe2 - C++ API
A deep learning, cross-platform ML framework
ngram_ops.h
1 
17 #pragma once
18 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <map>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 template <typename F, typename T, class Context>
27 class NGramFromCategoricalOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30 
31  NGramFromCategoricalOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  col_ids_(OperatorBase::GetRepeatedArgument<int>("col_ids")),
34  categorical_limits_(
35  OperatorBase::GetRepeatedArgument<int>("categorical_limits")),
36  vals_(OperatorBase::GetRepeatedArgument<int>("vals")) {
37  col_num_ = col_ids_.size();
38  max_col_id_ = *std::max_element(col_ids_.begin(), col_ids_.end());
39  CAFFE_ENFORCE_EQ(col_num_, categorical_limits_.size());
40  int expected_vals_size = 0;
41  for (auto& l : categorical_limits_) {
42  CAFFE_ENFORCE_GT(l, 0);
43  expected_vals_size += l;
44  }
45  CAFFE_ENFORCE_EQ(expected_vals_size, vals_.size());
46  // compute ngram maps with small end
47  for (auto& j : col_ids_) {
48  CAFFE_ENFORCE_GE(j, 0);
49  ngram_maps_.push_back(std::map<int, int>());
50  }
51  int base = 1;
52  int idx = 0;
53  for (int k = 0; k < col_num_; k++) {
54  int l = categorical_limits_[k];
55  for (int m = 0; m < l; m++) {
56  int v = vals_[idx++];
57  ngram_maps_[k][v] = m * base;
58  }
59  base *= l;
60  }
61  }
62 
63  bool RunOnDevice() override {
64  auto& floats = Input(0);
65  auto N = floats.dim(0);
66  auto D = floats.size_from_dim(1);
67  const F* floats_data = floats.template data<F>();
68  auto* output = Output(0);
69  output->Resize(N);
70  auto* output_data = output->template mutable_data<T>();
71  math::Set<T, Context>(output->size(), 0, output_data, &context_);
72 
73  CAFFE_ENFORCE_GT(D, max_col_id_);
74  for (int i = 0; i < N; i++) {
75  for (int k = 0; k < col_num_; k++) {
76  int j = col_ids_[k];
77  int v = round(floats_data[i * D + j]);
78  // for out-of-vocabulary values, we always treat them the same as the
79  // first value specified in vals; if we want to mimic the behavior as
80  // sigrid NGram transform, just push front a random/impossible value at
81  // each segments of vals
82  output_data[i] += ngram_maps_[k].find(v) == ngram_maps_[k].end()
83  ? 0
84  : ngram_maps_[k][v];
85  }
86  }
87  return true;
88  }
89 
90  private:
91  std::vector<int> col_ids_;
92  std::vector<int> categorical_limits_;
93  std::vector<int> vals_;
94  std::vector<std::map<int, int>> ngram_maps_;
95  int col_num_;
96  int max_col_id_;
97 };
98 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.