1 #ifndef CAFFE_OPERATORS_ONE_HOT_OPS_H_ 2 #define CAFFE_OPERATORS_ONE_HOT_OPS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 template <
class... Args>
20 bool RunOnDevice()
override {
21 auto& indices =
Input(0);
25 "indices input must be 1D tensor of data type int64_t");
28 auto& index_size_tensor = this->
template Input<Tensor>(1, CPU);
30 index_size_tensor.numel(),
32 "index_size_tensor input must be scalar of data type int64_t");
34 auto batch_size = indices.numel();
35 auto index_size = *index_size_tensor.template data<int64_t>();
36 auto one_hots = Output(0);
37 one_hots->Resize(batch_size, index_size);
38 auto output_size = one_hots->numel();
39 if (output_size == 0) {
43 DoOneHotOp(batch_size, index_size, indices, one_hots);
55 template <
class Context>
58 USE_OPERATOR_CONTEXT_FUNCTIONS;
59 template <
class... Args>
63 bool RunOnDevice()
override {
70 INPUT_TAGS(X, LENS, VALS);
77 std::vector<int64_t> valsOffsets_;
80 template <
class Context>
83 USE_OPERATOR_CONTEXT_FUNCTIONS;
84 template <
class... Args>
88 bool RunOnDevice()
override;
91 INPUT_TAGS(X, LENS, BOUNDARIES);
97 #endif // CAFFE_OPERATORS_ONE_HOT_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...