5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 10 template <
typename T,
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
15 template <
class... Args>
19 this->
template GetSingleArgument<int>(
"categorical_limit", 0)) {
20 CAFFE_ENFORCE_GT(categorical_limit_, 0);
23 bool RunOnDevice()
override {
24 auto& keys =
Input(0);
26 const T* keys_data = keys.template data<T>();
27 std::vector<int> counts(categorical_limit_);
28 std::vector<int*> eids(categorical_limit_);
29 for (
int k = 0; k < categorical_limit_; k++) {
32 for (
int i = 0; i < N; i++) {
34 CAFFE_ENFORCE_GT(categorical_limit_, k);
35 CAFFE_ENFORCE_GE(k, 0);
38 for (
int k = 0; k < categorical_limit_; k++) {
39 auto* eid = Output(k, {counts[k]}, at::dtype<int>());
40 eids[k] = eid->template mutable_data<int>();
43 for (
int i = 0; i < N; i++) {
45 eids[k][counts[k]++] = i;
51 int categorical_limit_;
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...