Caffe2 - C++ API
A deep-learning, cross-platform ML framework
one_hot_ops.h
1 #ifndef CAFFE_OPERATORS_ONE_HOT_OPS_H_
2 #define CAFFE_OPERATORS_ONE_HOT_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class OneHotOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15 
16  template <class... Args>
17  explicit OneHotOp(Args&&... args)
18  : Operator<Context>(std::forward<Args>(args)...) {}
19 
20  bool RunOnDevice() override {
21  auto& indices = Input(0);
22  CAFFE_ENFORCE_EQ(
23  indices.dim(),
24  1,
25  "indices input must be 1D tensor of data type int64_t");
26 
27  // Index size input must be in CPU context
28  auto& index_size_tensor = this->template Input<Tensor>(1, CPU);
29  CAFFE_ENFORCE_EQ(
30  index_size_tensor.numel(),
31  1,
32  "index_size_tensor input must be scalar of data type int64_t");
33 
34  auto batch_size = indices.numel();
35  auto index_size = *index_size_tensor.template data<int64_t>();
36  auto one_hots = Output(0);
37  one_hots->Resize(batch_size, index_size);
38  auto output_size = one_hots->numel();
39  if (output_size == 0) {
40  return true;
41  }
42 
43  DoOneHotOp(batch_size, index_size, indices, one_hots);
44  return true;
45  }
46 
47  protected:
48  void DoOneHotOp(
49  int64_t batch_size,
50  int64_t index_size,
51  const Tensor& indices,
52  Tensor* output);
53 };
54 
55 template <class Context>
56 class BatchOneHotOp final : public Operator<Context> {
57  public:
58  USE_OPERATOR_CONTEXT_FUNCTIONS;
59  template <class... Args>
60  explicit BatchOneHotOp(Args&&... args)
61  : Operator<Context>(std::forward<Args>(args)...) {}
62 
63  bool RunOnDevice() override {
65  }
66 
67  template <typename T>
68  bool DoRunWithType();
69 
70  INPUT_TAGS(X, LENS, VALS);
71 
72  protected:
73  OUTPUT_TAGS(ONE_HOT);
74 
75  private:
76  // allows for fast random access to a given dict and is re-used across runs
77  std::vector<int64_t> valsOffsets_;
78 };
79 
80 template <class Context>
81 class BatchBucketOneHotOp final : public Operator<Context> {
82  public:
83  USE_OPERATOR_CONTEXT_FUNCTIONS;
84  template <class... Args>
85  explicit BatchBucketOneHotOp(Args&&... args)
86  : Operator<Context>(std::forward<Args>(args)...) {}
87 
88  bool RunOnDevice() override;
89 
90  protected:
91  INPUT_TAGS(X, LENS, BOUNDARIES);
92  OUTPUT_TAGS(ONE_HOT);
93 };
94 
95 } // namespace caffe2
96 
97 #endif // CAFFE_OPERATORS_ONE_HOT_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13