Caffe2 - C++ API
A deep learning, cross-platform ML framework
mkl_context.h
1 
17 #ifndef CAFFE2_UTILS_MKL_CONTEXT_H_
18 #define CAFFE2_UTILS_MKL_CONTEXT_H_
19 
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <memory>
#include <random>
#include <type_traits>
#include <utility>

#include "caffe2/core/context.h"
25 
26 namespace caffe2 {
27 
36 class MKLContext final {
37  public:
38  MKLContext() : random_seed_(RandomNumberSeed()) {}
39  explicit MKLContext(const DeviceOption& option)
40  : random_seed_(
41  option.has_random_seed() ? option.random_seed()
42  : RandomNumberSeed()) {
43  CAFFE_ENFORCE_EQ(option.device_type(), MKLDNN);
44  }
45 
46  ~MKLContext() {}
47 
48  inline void SwitchToDevice(int /*stream_id*/ = 0) {}
49 
50  inline void WaitEvent(const Event& ev) {
51  ev.Wait(MKLDNN, this);
52  }
53 
54  inline void Record(Event* ev, const char* err_msg = nullptr) const {
55  CAFFE_ENFORCE(ev, "Event must not be null.");
56  ev->Record(MKLDNN, this, err_msg);
57  }
58 
59  inline void FinishDeviceComputation() {}
60 
61  inline std::mt19937& RandGenerator() {
62  if (!random_generator_.get()) {
63  random_generator_.reset(new std::mt19937(random_seed_));
64  }
65  return *random_generator_.get();
66  }
67 
68  inline static std::pair<void*, MemoryDeleter> New(size_t nbytes) {
69  return GetCPUAllocator()->New(nbytes);
70  }
71 
72  // Two copy functions that deals with cross-device copies.
73  template <class SrcContext, class DstContext>
74  inline void CopyBytes(size_t nbytes, const void* src, void* dst);
75 
76  template <typename T, class SrcContext, class DstContext>
77  inline void Copy(size_t n, const T* src, T* dst) {
78  if (std::is_fundamental<T>::value) {
79  CopyBytes<SrcContext, DstContext>(
80  n * sizeof(T),
81  static_cast<const void*>(src),
82  static_cast<void*>(dst));
83  } else {
84  for (int i = 0; i < n; ++i) {
85  dst[i] = src[i];
86  }
87  }
88  }
89 
90  template <class SrcContext, class DstContext>
91  inline void
92  CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) {
93  if (meta.copy()) {
94  meta.copy()(src, dst, n);
95  } else {
96  CopyBytes<SrcContext, DstContext>(n * meta.itemsize(), src, dst);
97  }
98  }
99 
100  // By default MKL operators don't have async device parts
101  static bool HasAsyncPartDefault() {
102  return false;
103  }
104 
105  static bool SupportsAsyncScheduling() {
106  return false;
107  }
108 
109  static bool IsStreamFree(const DeviceOption& /* unused */, int /* unused */) {
110  return true;
111  }
112 
113  protected:
114  // TODO(jiayq): instead of hard-coding a generator, make it more flexible.
115  int random_seed_{1701};
116  std::unique_ptr<std::mt19937> random_generator_;
117 };
118 
119 template <>
120 inline void MKLContext::CopyBytes<MKLContext, MKLContext>(
121  size_t nbytes,
122  const void* src,
123  void* dst) {
124  memcpy(dst, src, nbytes);
125 }
126 
127 template <>
128 inline void MKLContext::CopyBytes<CPUContext, MKLContext>(
129  size_t nbytes,
130  const void* src,
131  void* dst) {
132  memcpy(dst, src, nbytes);
133 }
134 
135 template <>
136 inline void MKLContext::CopyBytes<MKLContext, CPUContext>(
137  size_t nbytes,
138  const void* src,
139  void* dst) {
140  memcpy(dst, src, nbytes);
141 }
142 } // namespace caffe2
143 
144 #endif // CAFFE2_UTILS_MKL_CONTEXT_H_
The MKL Context, which is largely the same as the CPUContext.
Definition: mkl_context.h:36
Copyright (c) 2016-present, Facebook, Inc.
TypedCopy copy() const
Returns the typed copy function pointer for individual items.
Definition: typeid.h:171
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:104
const size_t & itemsize() const
Returns the size of the item.
Definition: typeid.h:159
uint32_t RandomNumberSeed()
A function to generate a random number seed that is unique on a best-effort basis, using an ever-incrementing seed and the current time.
Definition: context.cc:26