1 #include <torch/csrc/jit/fuser/kernel_cache.h> 2 #include <torch/csrc/jit/passes/canonicalize.h> 3 #include <torch/csrc/jit/passes/shape_analysis.h> 7 #include <unordered_map> 17 int64_t kernel_counter{0};
20 std::unordered_map<int64_t, KernelSpec> specMap_;
24 std::unordered_map<std::string, int64_t> graphToKey_;
// Reports how many KernelSpecs are currently held in the global kernel cache
// (the size of specMap_). Presumably intended for debugging/tests, per its
// name — confirm against callers.
// The cache mutex is held while reading the size, so the count is consistent
// with store(), which inserts into specMap_ under the same mutex.
// Note: specMap_.size() is an unsigned size type that is implicitly converted
// to the int64_t return type here.
32 int64_t debugNumCachedKernelSpecs() {
33 auto& cache = getKernelCache();
34 std::lock_guard<std::mutex> guard{cache.mutex_};
35 return cache.specMap_.size();
// Builds the normalized form of `graph` used by the kernel cache: the graph is
// passed through Canonicalize, and shape information is then erased from the
// canonicalized result. Presumably this is done so that graphs differing only
// in naming/shape details normalize to the same representation (and thus the
// same cache entry) — confirm against callers of store()/lookup.
38 std::shared_ptr<Graph> normalizeGraphForCache(
39 const std::shared_ptr<Graph>& graph) {
40 auto result = Canonicalize(graph,
// NOTE(review): the meaning of this boolean argument is not visible here —
// confirm against Canonicalize's declaration and annotate it inline.
false);
// Strip shape information so it does not participate in the cached form.
41 EraseShapeInformation(result);
// Stores `graph` in the global kernel cache under a freshly issued key.
// - The graph's string representation is computed before the mutex is taken,
//   which keeps the toString() call out of the critical section.
// - Under the cache mutex, the next key is taken from kernel_counter (post-
//   increment), a KernelSpec is constructed in place in specMap_ from
//   (key, graph), and the string representation is mapped to that key in
//   graphToKey_.
// NOTE(review): presumably expects `graph` to have been normalized already
// (e.g. via normalizeGraphForCache) so equivalent graphs share one
// representation — confirm against callers.
// NOTE(review): unordered_map::emplace does not overwrite an existing key. If
// the same representation was stored before, graphToKey_ keeps pointing at
// the earlier key while specMap_ still gains a new entry. Callers presumably
// check for an existing entry first — confirm.
48 int64_t store(std::shared_ptr<Graph> graph) {
49 auto& cache = getKernelCache();
50 std::string repr = graph->toString();
// Everything below mutates shared cache state, so it runs under the mutex.
52 std::lock_guard<std::mutex> guard{cache.mutex_};
53 const auto key = cache.kernel_counter++;
// piecewise_construct builds the map key from (key) and constructs the
// KernelSpec value in place from (key, graph), avoiding a temporary KernelSpec.
54 cache.specMap_.emplace(
55 std::piecewise_construct,
56 std::forward_as_tuple(key),
57 std::forward_as_tuple(key, graph));
// repr is no longer needed locally, so it is moved into graphToKey_.
58 cache.graphToKey_.emplace(std::make_pair(std::move(repr), key));
66 auto it = cache.specMap_.find(key);
67 if (it == cache.specMap_.end())
73 auto& cache = getKernelCache();
74 std::lock_guard<std::mutex> guard{cache.mutex_};
75 return nolock_retrieve(cache, key);
80 auto& cache = getKernelCache();
81 std::string repr = graph->toString();
83 std::lock_guard<std::mutex> guard{cache.mutex_};
84 auto it = cache.graphToKey_.find(repr);
85 if (it == cache.graphToKey_.end())
87 return nolock_retrieve(cache, it->second);