// Header for a Caffe2 operator that maps each value of its INDICES input to
// a bucket in [0, modulo) using a seeded byte-wise hash (see hash()).
// NOTE(review): this copy of the file is damaged. The preprocessor directives
// below are fused onto a single line (each must be on its own line), stray
// documentation line numbers are embedded in the text, and the class
// declaration that should follow `template <class Context>` (plus any
// enclosing namespace) is missing. Restore from upstream before building.
1 #ifndef CAFFE2_OPERATORS_INDEX_HASH_OPS_H_ 2 #define CAFFE2_OPERATORS_INDEX_HASH_OPS_H_ 4 #include "caffe2/core/asan.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 10 template <
class Context>
// Operator boilerplate. NOTE(review): USE_OPERATOR_CONTEXT_FUNCTIONS is a
// macro defined outside this file; presumably it brings Input()/Output() and
// related helpers used below into scope -- confirm in caffe2/core/operator.h.
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor (variadic template): reads the int64 operator arguments
// "seed" and "modulo", both defaulting to 0, into seed_ and modulo_.
// NOTE(review): the constructor signature / base-class initializer that
// belong between `template <class... Args>` and `seed_(`, and the closing
// brace after the check below, are missing from this copy of the file.
14 template <
class... Args>
17 seed_(this->
template GetSingleArgument<int64_t>(
"seed", 0)),
18 modulo_(this->
template GetSingleArgument<int64_t>(
"modulo", 0)) {
// modulo must be strictly positive: hash() relies on modulo_ > 0 to shift
// negative remainders into [0, modulo_). Because "modulo" defaults to 0,
// this check makes the argument effectively required.
19 CAFFE_ENFORCE_GT(modulo_, 0,
"MODULO should be > 0");
// Operator entry point.
// NOTE(review): the first statement of this body is missing from this copy;
// only its trailing arguments `this, Input(INDICES));` survive. It is
// presumably a type-dispatch helper that invokes DoRunWithType<T>() for the
// element type of the INDICES tensor -- confirm against upstream. The
// closing brace of the function is also missing.
22 bool RunOnDevice()
override {
24 this,
Input(INDICES));
28 bool DoRunWithType() {
29 auto& indices =
Input(INDICES);
31 auto* hashed_indices =
32 Output(HASHED_INDICES, indices.sizes(), at::dtype<T>());
35 static_cast<int64_t>(std::numeric_limits<T>::max()),
37 "MODULO shouldn't be larger than the numeric limit of the indices");
39 auto N = indices.numel();
40 auto* indices_data = indices.template data<T>();
41 auto* hashed_indices_data = hashed_indices->template mutable_data<T>();
43 for (
auto i = 0; i < N; i++) {
44 hashed_indices_data[i] = hash(indices_data[i]);
52 CAFFE2_NO_SANITIZE(
"signed-integer-overflow")
T hash(
T id) {
53 int8_t* bytes = (int8_t*)&
id;
54 T hashed = seed_ * 0xDEADBEEF;
55 for (
int i = 0; i <
sizeof(
T) /
sizeof(int8_t); i++) {
56 hashed = hashed * 65537 + bytes[i];
60 auto modHashed = hashed % modulo_;
61 return modHashed >= 0 ? modHashed : modHashed + modulo_;
// Output tag: HASHED_INDICES, the only output, written by DoRunWithType().
// NOTE(review): the INDICES input-tag declaration (presumably
// INPUT_TAGS(INDICES)), the seed_ / modulo_ member declarations used above,
// and the closing `};` of the class (plus any enclosing namespace) are
// missing from this copy of the file.
66 OUTPUT_TAGS(HASHED_INDICES);
// Closes the include guard opened at the top of the file.
74 #endif // CAFFE2_OPERATORS_INDEX_HASH_OPS_H_
// Documentation fragments (truncated in the source):
// const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
// Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...