Caffe2 - C++ API
A deep-learning, cross-platform ML framework
one_hot_ops.h
1 
17 #ifndef CAFFE_OPERATORS_ONE_HOT_OPS_H_
18 #define CAFFE_OPERATORS_ONE_HOT_OPS_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <class Context>
28 class OneHotOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31 
32  OneHotOp(const OperatorDef& operator_def, Workspace* ws)
33  : Operator<Context>(operator_def, ws) {}
34 
35  bool RunOnDevice() override {
36  auto& indices = Input(0);
37  CAFFE_ENFORCE_EQ(
38  indices.ndim(),
39  1,
40  "indices input must be 1D tensor of data type TIndex");
41 
42  // Index size input must be in CPU context
43  auto& index_size_tensor = OperatorBase::Input<Tensor<CPUContext>>(1);
44  CAFFE_ENFORCE_EQ(
45  index_size_tensor.size(),
46  1,
47  "index_size_tensor input must be scalar of data type TIndex");
48 
49  auto batch_size = indices.size();
50  auto index_size = *index_size_tensor.template data<TIndex>();
51  auto one_hots = Output(0);
52  one_hots->Resize(batch_size, index_size);
53  auto output_size = one_hots->size();
54  if (output_size == 0) {
55  return true;
56  }
57 
58  DoOneHotOp(batch_size, index_size, indices, one_hots);
59  return true;
60  }
61 
62  protected:
63  void DoOneHotOp(
64  TIndex batch_size,
65  TIndex index_size,
66  const Tensor<Context>& indices,
67  Tensor<Context>* output);
68 };
69 
70 template <class Context>
71 class BatchOneHotOp final : public Operator<Context> {
72  public:
73  USE_OPERATOR_CONTEXT_FUNCTIONS;
74  BatchOneHotOp(const OperatorDef& operator_def, Workspace* ws)
75  : Operator<Context>(operator_def, ws) {}
76 
77  bool RunOnDevice() override {
78  return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(X));
79  }
80 
81  template <typename T>
82  bool DoRunWithType();
83 
84  protected:
85  INPUT_TAGS(X, LENS, VALS);
86  OUTPUT_TAGS(ONE_HOT);
87 };
88 
89 template <class Context>
90 class BatchBucketOneHotOp final : public Operator<Context> {
91  public:
92  USE_OPERATOR_CONTEXT_FUNCTIONS;
93  BatchBucketOneHotOp(const OperatorDef& operator_def, Workspace* ws)
94  : Operator<Context>(operator_def, ws) {}
95 
96  bool RunOnDevice() override;
97 
98  protected:
99  INPUT_TAGS(X, LENS, BOUNDARIES);
100  OUTPUT_TAGS(ONE_HOT);
101 };
102 
103 } // namespace caffe2
104 
105 #endif // CAFFE_OPERATORS_ONE_HOT_OPS_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.