Caffe2 - C++ API
A deep learning, cross platform ML framework
fbgemm_pack_matrix_cache.cc
1 #include "fbgemm_pack_matrix_cache.h"
2 
3 #include <map>
4 #include <memory>
5 #include <mutex>
6 
7 using namespace std;
8 
9 namespace caffe2 {
10 
11 template <typename ACC_T>
12 shared_ptr<fbgemm::PackBMatrix<int8_t, ACC_T>> GetOrCreateFbgemmPackBMatrix(
13  fbgemm::matrix_op_t trans,
14  int32_t m,
15  int32_t n,
16  const void* orig_data,
17  const int8_t* quantized_data,
18  int32_t ld) {
19  static std::map<
20  std::tuple<int, int, const void*>,
21  weak_ptr<fbgemm::PackBMatrix<int8_t, ACC_T>>>
22  cache;
23  static mutex cache_mutex;
24 
25  // Create a new packed matrix and compare with cached one if there's any.
26  // Note that a cache miss is as expensive as a cache hit here, the purpose of
27  // this cache is only to deduplicate the quantized tensors for improved
28  // memory bandwidth if different nets share copies of the same operator.
29  // TODO: make this cheaper by computing hash of fdata.
30  auto new_packed = make_shared<fbgemm::PackBMatrix<int8_t, ACC_T>>(
31  trans,
32  m,
33  n,
34  quantized_data,
35  ld,
36  nullptr, // pmat
37  1); // groups
38 
39  std::tuple<int, int, const void*> key(m, n, orig_data);
40  std::shared_ptr<fbgemm::PackBMatrix<int8_t, ACC_T>> cache_entry;
41  {
42  lock_guard<mutex> lock(cache_mutex);
43  auto itr = cache.find(key);
44  if (itr != cache.end()) {
45  cache_entry = itr->second.lock();
46  }
47  } // release lock here during expensive equals()
48 
49  if (!cache_entry || !cache_entry->metaEquals(*new_packed) ||
50  !cache_entry->equals(*new_packed)) {
51  // cache miss
52  lock_guard<mutex> lock(cache_mutex);
53  cache[key] = new_packed;
54  return new_packed;
55  } else {
56  return cache_entry;
57  }
58 }
59 
60 template shared_ptr<fbgemm::PackBMatrix<int8_t, int16_t>>
61 GetOrCreateFbgemmPackBMatrix<int16_t>(
62  fbgemm::matrix_op_t trans,
63  int32_t m,
64  int32_t n,
65  const void* orig_data,
66  const int8_t* quantized_data,
67  int32_t ld);
68 
69 template shared_ptr<fbgemm::PackBMatrix<int8_t, int32_t>>
70 GetOrCreateFbgemmPackBMatrix<int32_t>(
71  fbgemm::matrix_op_t trans,
72  int32_t m,
73  int32_t n,
74  const void* orig_data,
75  const int8_t* quantized_data,
76  int32_t ld);
77 
78 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13