Caffe2 - C++ API
A deep learning, cross-platform ML framework
context_gpu.h
#ifndef CAFFE2_CORE_CONTEXT_GPU_H_
#define CAFFE2_CORE_CONTEXT_GPU_H_

#include <ctime>
#include <mutex>

#include "caffe2/core/common.h"
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context.h"
#include "caffe2/core/context_base.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/numa.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"

// Since we use the macro CAFFE2_USE_CUDNN, this header must be included
// after common.h.
#ifdef CAFFE2_USE_CUDNN
#include "caffe2/core/common_cudnn.h"
#endif // CAFFE2_USE_CUDNN

#include <c10/core/Device.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

namespace caffe2 {

enum class CudaMemoryPoolType {
  NONE = 0,
  CUB = 1,
  THC = 2,
};

/**
 * Gets the current memory pool type used by Caffe2.
 */
CAFFE2_CUDA_API CudaMemoryPoolType GetCudaMemoryPoolType();

/**
 * A struct to host thread-local cuda objects.
 */
class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
  friend class CUDAContext;

 private:
  ThreadLocalCUDAObjects() {
    for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) {
      cuda_streams_[i] = vector<c10::cuda::CUDAStream>();
    }
  }

  // Record the current stream id for the current thread.
  // This is the new API we are migrating use cases to so that we can get rid
  // of explicit stream id passing. For now it is invoked in
  // CUDAContext::SwitchToDevice.
  void SetCurrentStreamId(DeviceIndex gpu, StreamId stream_id) {
    // TODO: use the current device id from the thread local instead of
    // passing gpu in.
    c10::cuda::setCurrentCUDAStream(GetCUDAStream(gpu, stream_id));
  }

  // Retrieves the CUDAStream corresponding to a logical stream ID, allocating
  // it in cuda_streams_ if it has not been allocated yet.
  c10::cuda::CUDAStream GetCUDAStream(DeviceIndex gpu, StreamId stream_id) {
    vector<c10::cuda::CUDAStream>& gpu_streams = cuda_streams_[gpu];
    while (gpu_streams.size() <= static_cast<size_t>(stream_id)) {
      // NB: These streams are not guaranteed to be unique; we'll
      // wrap around once we run out of streams in the pool.
      gpu_streams.emplace_back(
          c10::cuda::getStreamFromPool(/* high priority */ false, gpu));
    }
    return gpu_streams[stream_id];
  }

  // Uses the logical stream id from the thread local to pick the stream.
  // We are going to migrate all usages to this API instead of passing the
  // stream id directly.
  cudaStream_t GetStream(DeviceIndex gpu) {
    return c10::cuda::getCurrentCUDAStream(gpu).stream();
  }

  cudaStream_t GetStream(DeviceIndex gpu, StreamId stream_id) {
    return GetCUDAStream(gpu, stream_id).stream();
  }

  // Uses the logical stream id from the thread local to pick the stream.
  // We are going to migrate all usages to this API instead of passing the
  // stream id directly.
  cublasHandle_t GetHandle(DeviceIndex gpu) {
    return GetHandle(c10::cuda::getCurrentCUDAStream(gpu));
  }

  cublasHandle_t GetHandle(c10::cuda::CUDAStream cuda_stream) {
    CUDAGuard guard(cuda_stream.device_index());
    // Default construct in the map if it doesn't exist, and return a mutable
    // reference to it.
    auto& r = cublas_handles_[cuda_stream];
    if (r == nullptr) {
      CUBLAS_ENFORCE(cublasCreate(&r));
      // The default is CUBLAS_POINTER_MODE_HOST. You can override
      // it after obtaining the cublas handle, but do that with
      // caution.
      CUBLAS_ENFORCE(cublasSetPointerMode(r, CUBLAS_POINTER_MODE_HOST));
      CUBLAS_ENFORCE(cublasSetStream(r, cuda_stream));
    }
    return r;
  }

#ifdef CAFFE2_USE_CUDNN
  // Uses the logical stream id from the thread local to pick the stream.
  // We are going to migrate all usages to this API instead of passing the
  // stream id directly.
  cudnnHandle_t GetCudnnHandle(DeviceIndex gpu) {
    return GetCudnnHandle(c10::cuda::getCurrentCUDAStream(gpu));
  }

  cudnnHandle_t GetCudnnHandle(c10::cuda::CUDAStream cuda_stream) {
    CUDAGuard guard(cuda_stream.device_index());
    auto& r = cudnn_handles_[cuda_stream];
    if (r == nullptr) {
      CUDNN_ENFORCE(cudnnCreate(&r));
      CUDNN_ENFORCE(cudnnSetStream(r, cuda_stream));
    }
    return r;
  }
#endif // CAFFE2_USE_CUDNN

  ~ThreadLocalCUDAObjects() noexcept {
    for (auto element : cublas_handles_) {
      if (element.second) {
        CUBLAS_CHECK(cublasDestroy(element.second));
      }
    }
#ifdef CAFFE2_USE_CUDNN
    for (auto element : cudnn_handles_) {
      if (element.second) {
        CUDNN_CHECK(cudnnDestroy(element.second));
      }
    }
#endif // CAFFE2_USE_CUDNN
  }

  // WARNING: the mapping from logical stream IDs to c10::cuda::CUDAStream
  // objects is NOT bijective; multiple logical stream IDs may map to the
  // same underlying stream.
  vector<c10::cuda::CUDAStream> cuda_streams_[C10_COMPILE_TIME_MAX_GPUS];
  std::unordered_map<c10::cuda::CUDAStream, cublasHandle_t> cublas_handles_;
#ifdef CAFFE2_USE_CUDNN
  std::unordered_map<c10::cuda::CUDAStream, cudnnHandle_t> cudnn_handles_;
#endif // CAFFE2_USE_CUDNN
};
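
// Illustrative note, not part of the original header: because GetCUDAStream()
// draws streams from c10's fixed-size stream pool, distinct logical stream ids
// can end up aliasing the same underlying cudaStream_t once the pool wraps
// around. A minimal sketch, assuming a thread-local instance `objs`:
//
//   cudaStream_t a = objs.GetStream(/*gpu=*/0, /*stream_id=*/0);
//   cudaStream_t b = objs.GetStream(/*gpu=*/0, /*stream_id=*/64);
//   // a and b may be the same stream, so distinct stream ids are not a
//   // guarantee of independent execution order (see the WARNING above).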

class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
 public:
  // The default cuda context constructor.
  explicit CUDAContext(DeviceIndex gpu_id = -1);
  explicit CUDAContext(const DeviceOption& option);
  explicit CUDAContext(Device device)
      : CUDAContext(DeviceToOption(device)) {}

  ~CUDAContext() override {
    if (curand_generator_) {
      CURAND_CHECK(curandDestroyGenerator(curand_generator_));
    }
    // CUDAContext is currently used in two ways:
    // - as a long-lived instance inside OperatorBase, in which case what
    //   happens in the destructor doesn't really matter;
    // - as short-lived on-the-fly instances used like a CUDAGuard, in which
    //   case there is only one stream id (passed to SwitchToDevice) and it
    //   is preferable to synchronize in the destructor.
    FinishDeviceComputation();
  }

  inline void SwitchToDevice(StreamId stream_id) override {
    getCudaObjects().SetCurrentStreamId(gpu_id_, stream_id);
    CaffeCudaSetDevice(gpu_id_);
  }

  // void SwitchToDevice()
  using BaseContext::SwitchToDevice;

  inline void WaitEvent(const Event& ev) override {
    ev.Wait(CUDA, this);
  }

  inline void Record(Event* ev, const char* err_msg = nullptr) const override {
    CAFFE_ENFORCE(ev, "Event must not be null.");
    ev->Record(CUDA, this, err_msg);
  }

  // Note on current use cases:
  // FinishDeviceComputation must be called on the same CPU thread as
  // SwitchToDevice().
  void FinishDeviceComputation() override {
    CUDA_ENFORCE(cudaStreamSynchronize(getCudaObjects().GetStream(gpu_id_)));
    cudaError_t error = cudaGetLastError();
    if (error != cudaSuccess) {
      CAFFE_THROW("Encountered CUDA error: ", cudaGetErrorString(error));
    }
  }

  inline int device_id() const {
    return gpu_id_;
  }

  inline cudaStream_t cuda_stream() const {
    return getCudaObjects().GetStream(gpu_id_);
  }

  static cudaStream_t cuda_stream(DeviceIndex gpu_id, StreamId stream_id) {
    return getCudaObjects().GetStream(gpu_id, stream_id);
  }

  cublasHandle_t cublas_handle() {
    return getCudaObjects().GetHandle(gpu_id_);
  }

#ifdef CAFFE2_USE_CUDNN
  cudnnHandle_t cudnn_handle() {
    return getCudaObjects().GetCudnnHandle(gpu_id_);
  }
#endif // CAFFE2_USE_CUDNN

  curandGenerator_t& curand_generator() {
    if (!curand_generator_) {
      CUDAGuard guard(gpu_id_);
      CURAND_ENFORCE(
          curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
      CURAND_ENFORCE(
          curandSetPseudoRandomGeneratorSeed(curand_generator_, random_seed_));
      CHECK_NOTNULL(curand_generator_);
    }
    CURAND_ENFORCE(curandSetStream(curand_generator_, cuda_stream()));
    return curand_generator_;
  }

  inline static at::DataPtr New(size_t nbytes) {
    return GetAllocator(CUDA)->allocate(nbytes);
  }

  // Get a mutex to lock out cudaMalloc / cudaFree calls while NCCL kernels
  // are being launched. This should remove the risk of deadlocks.
  static std::mutex& mutex();

  // Functions to query memory stats. Only available if the flag
  // --caffe2_gpu_memory_tracking is enabled.
  static std::vector<long> TotalMemoryByGpu();
  static std::vector<long> MaxMemoryByGpu();

  template <class SrcContext, class DstContext>
  inline void CopyBytes(size_t nbytes, const void* src, void* dst) {
    CUDA_ENFORCE(cudaMemcpyAsync(
        dst,
        src,
        nbytes,
        cudaMemcpyDefault,
        getCudaObjects().GetStream(gpu_id_)));
  }

  void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override {
    CopyBytes<CUDAContext, CUDAContext>(nbytes, src, dst);
  }

  void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) override {
    CopyBytes<CUDAContext, CPUContext>(nbytes, src, dst);
  }

  void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) override {
    CopyBytes<CPUContext, CUDAContext>(nbytes, src, dst);
  }

  template <typename T, class SrcContext, class DstContext>
  inline void Copy(int n, const T* src, T* dst) {
    CopyBytes<SrcContext, DstContext>(
        n * sizeof(T),
        static_cast<const void*>(src),
        static_cast<void*>(dst));
  }

  template <class SrcContext, class DstContext>
  inline void
  CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) {
    CAFFE_ENFORCE(!meta.copy(), "CUDAContext requires fundamental types.");
    CopyBytes<SrcContext, DstContext>(n * meta.itemsize(), src, dst);
  }

  static void CopyBytesAsync(
      size_t nbytes,
      const void* src,
      Device src_device,
      void* dst,
      Device dst_device);
  static void CopyBytesSync(
      size_t nbytes,
      const void* src,
      Device src_device,
      void* dst,
      Device dst_device);

  // By default, CUDA operators have async device parts.
  static bool HasAsyncPartDefault() {
    return true;
  }

  static bool SupportsAsyncScheduling() {
    return true;
  }

  static bool IsStreamFree(const DeviceOption& option, StreamId stream_id) {
    auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id);
    return cudaStreamQuery(stream) == cudaSuccess;
  }

  at::Device device() const override {
    return at::Device(CUDA, gpu_id_);
  }

  DeviceType device_type() const override {
    return CUDA;
  }

  static constexpr DeviceType GetDeviceType() {
    return CUDA;
  }

 protected:
  int gpu_id_;
  int random_seed_;
  curandGenerator_t curand_generator_{nullptr};
  static ThreadLocalCUDAObjects& getCudaObjects();
};
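
// Illustrative sketch, not part of the original header: the handle returned by
// cublas_handle() is cached per stream and already bound to that stream via
// cublasSetStream in GetHandle(), so cuBLAS calls issued through it are
// ordered with other work on cuda_stream(). Assuming GPU 0 exists:
//
//   caffe2::CUDAContext ctx(0);
//   ctx.SwitchToDevice(0);
//   cublasHandle_t handle = ctx.cublas_handle();
//   // e.g. a cublasSgemm(handle, ...) call would now enqueue onto the same
//   // CUDA stream as other work issued through this context.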

using TensorCUDA = Tensor;

} // namespace caffe2

#endif // CAFFE2_CORE_CONTEXT_GPU_H_
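
For orientation, here is a minimal usage sketch (not part of the header; the device index, buffer arguments, and function name are assumptions for illustration) showing the typical pattern of switching to a device and stream, issuing an asynchronous copy, and synchronizing:

#include "caffe2/core/context_gpu.h"

// Copy n floats from host memory to device memory and wait for the copy to
// finish. Illustrative only; real operators receive their CUDAContext from
// OperatorBase rather than constructing one ad hoc.
void ExampleHostToDeviceCopy(const float* host_src, float* device_dst, size_t n) {
  caffe2::CUDAContext context(/*gpu_id=*/0); // assumes GPU 0 is present
  context.SwitchToDevice(0);                 // logical stream id 0
  // Enqueues an asynchronous copy on the context's current CUDA stream.
  context.CopyBytesFromCPU(n * sizeof(float), host_src, device_dst);
  // Blocks until the stream drains and surfaces any pending CUDA error.
  context.FinishDeviceComputation();
}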