Caffe2 - C++ API
A deep learning, cross platform ML framework
conv_op_cache_cudnn.h
1 #ifndef CAFFE2_OPERATORS_CONV_OP_CACHE_H_
2 #define CAFFE2_OPERATORS_CONV_OP_CACHE_H_
3 
4 #include <functional>
5 #include <unordered_map>
6 #include <vector>
7 
8 #include "caffe2/core/logging.h"
9 #include "caffe2/core/tensor.h"
10 
11 namespace caffe2 {
12 template <typename TAlgorithm>
14  public:
15  // Caches the best algorithm for a given
16  // combination of tensor dimensions & compute data type.
17  //
18  TAlgorithm getAlgorithm(
19  at::IntArrayRef tensorDimensions1,
20  at::IntArrayRef tensorDimensions2,
21  int algorithmFlags, // Differentiate between algorithms with different
22  // parameters in a generic way
23  std::function<TAlgorithm()> generatingFunc);
24 
25  private:
26  std::unordered_map<int64_t, TAlgorithm> hash_;
27 };
28 
29 template <typename TAlgorithm>
31  at::IntArrayRef tensorDimensions1,
32  at::IntArrayRef tensorDimensions2,
33  int algorithmFlags,
34  std::function<TAlgorithm()> generatingFunc) {
35  int64_t seed = 0;
36  // Hash all of the inputs, which we wiill then use to try and look up
37  // a previously discovered algorithm, or fall back to generating a new one.
38  std::hash<int64_t> hashFn;
39  for (const auto num : tensorDimensions1) {
40  // Copied from boost::hash_combine.
41  // Adding 1 to differentiate between first and second vector.
42  seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
43  }
44 
45  for (const auto num : tensorDimensions2) {
46  // Copied from boost::hash_combine.
47  seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
48  }
49 
50  // Adding 2 to differentiate from previous vectors
51  seed ^= hashFn(algorithmFlags) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 2;
52 
53  if (seed == 0) {
54  return generatingFunc();
55  }
56 
57  if (hash_.find(seed) == hash_.end()) {
58  TAlgorithm value = generatingFunc();
59  hash_[seed] = value;
60  }
61 
62  return hash_[seed];
63 }
64 } // namespace caffe2
65 
66 #endif
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13