Caffe2 - C++ API
A deep learning, cross-platform ML framework
kernel_cache.cpp
1 #include <torch/csrc/jit/fuser/kernel_cache.h>
2 #include <torch/csrc/jit/passes/canonicalize.h>
3 #include <torch/csrc/jit/passes/shape_analysis.h>
4 
5 #include <cstdint>
6 #include <mutex>
7 #include <unordered_map>
8 
9 namespace torch {
10 namespace jit {
11 namespace fuser {
12 
14  // Note: std::unordered_map does not invalidate references even if rehashing
15  // occurs. This is a critical property for thread-safety.
16  std::mutex mutex_;
17  int64_t kernel_counter{0};
18 
19  // Map of fusion key to KernelSpec
20  std::unordered_map<int64_t, KernelSpec> specMap_;
21 
22  // Map of pretty-printed graph string to fusion key
23  // Used to check if a graph has already been cached in specMap_
24  std::unordered_map<std::string, int64_t> graphToKey_;
25 };
26 
27 static KernelCacheImpl& getKernelCache() {
28  static KernelCacheImpl cache;
29  return cache;
30 }
31 
32 int64_t debugNumCachedKernelSpecs() {
33  auto& cache = getKernelCache();
34  std::lock_guard<std::mutex> guard{cache.mutex_};
35  return cache.specMap_.size();
36 }
37 
38 std::shared_ptr<Graph> normalizeGraphForCache(
39  const std::shared_ptr<Graph>& graph) {
40  auto result = Canonicalize(graph, /*keep_unique_names=*/false);
41  EraseShapeInformation(result);
42  return result;
43 }
44 
45 // TODO: lookup by historic string key to start, then issue key
46 // as appropriate for faster lookup in the future
47 // precondition: graph has been normalized via normalizeGraphForCache
48 int64_t store(std::shared_ptr<Graph> graph) {
49  auto& cache = getKernelCache();
50  std::string repr = graph->toString();
51 
52  std::lock_guard<std::mutex> guard{cache.mutex_};
53  const auto key = cache.kernel_counter++;
54  cache.specMap_.emplace(
55  std::piecewise_construct,
56  std::forward_as_tuple(key),
57  std::forward_as_tuple(key, graph));
58  cache.graphToKey_.emplace(std::make_pair(std::move(repr), key));
59  return key;
60 }
61 
62 // XXX: Does not grab mutex
63 static at::optional<KernelSpec*> nolock_retrieve(
64  KernelCacheImpl& cache,
65  const int64_t key) {
66  auto it = cache.specMap_.find(key);
67  if (it == cache.specMap_.end())
68  return at::nullopt;
69  return &(it->second);
70 }
71 
72 at::optional<KernelSpec*> retrieve(const int64_t key) {
73  auto& cache = getKernelCache();
74  std::lock_guard<std::mutex> guard{cache.mutex_};
75  return nolock_retrieve(cache, key);
76 }
77 
78 // precondition: graph has been normalized via normalizeGraphForCache
79 at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
80  auto& cache = getKernelCache();
81  std::string repr = graph->toString();
82 
83  std::lock_guard<std::mutex> guard{cache.mutex_};
84  auto it = cache.graphToKey_.find(repr);
85  if (it == cache.graphToKey_.end())
86  return at::nullopt;
87  return nolock_retrieve(cache, it->second);
88 }
89 
90 } // namespace fuser
91 } // namespace jit
92 } // namespace torch
Definition: jit_type.h:17