Caffe2 - C++ API
A deep-learning, cross-platform ML framework
cudnn_wrappers.h
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #ifndef CAFFE2_CORE_CUDNN_WRAPPERS_H_
4 #define CAFFE2_CORE_CUDNN_WRAPPERS_H_
5 
6 #include "caffe2/core/common_cudnn.h"
7 #include "caffe2/core/context_gpu.h"
8 
9 namespace caffe2 {
10 
11 class CuDNNWrapper;
12 
22  ~CuDNNWorkspace() noexcept {}
23 
24  void* get(size_t nbytes) {
25  if (nbytes_ < nbytes) {
26  reset();
27  auto data_and_deleter = CUDAContext::New(nbytes);
28  data_ = {data_and_deleter.first, data_and_deleter.second};
29  nbytes_ = nbytes;
30  }
31  CAFFE_ENFORCE_GE(nbytes_, nbytes);
32  return data_.get();
33  }
34 
35  void reset() {
36  data_ = nullptr;
37  nbytes_ = 0;
38  }
39 
40  private:
41  std::unique_ptr<void, MemoryDeleter> data_{nullptr, NoDelete};
42  size_t nbytes_{0};
43 };
44 
45 // CuDNNState is the owner of the CuDNNWorkspace, and serializes all
46 // executions of operations that use the state onto it's own stream
47 // (so multiple Net workers can reuse the same workspace from
48 // different threads and CUDA streams).
49 class CuDNNState {
50  public:
51  explicit CuDNNState(size_t gpu_id) : gpu_id_(gpu_id) {
52  DeviceGuard g(gpu_id_);
53  CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_));
54  CUDA_ENFORCE(cudaEventCreate(&before_));
55  CUDA_ENFORCE(cudaEventCreate(&after_));
56  CUDA_ENFORCE(cudaStreamCreate(&stream_));
57  CUDNN_ENFORCE(cudnnSetStream(cudnn_handle_, stream_));
58  }
59 
60  ~CuDNNState() noexcept {
61  DeviceGuard g(gpu_id_);
62  CUDNN_CHECK(cudnnDestroy(cudnn_handle_));
63  CUDA_CHECK(cudaStreamDestroy(stream_));
64  CUDA_CHECK(cudaEventDestroy(after_));
65  CUDA_CHECK(cudaEventDestroy(before_));
66  }
67 
68  cudnnHandle_t& cudnn_handle() {
69  return cudnn_handle_;
70  }
71 
72  CuDNNWorkspace& workspace() {
73  return workspace_;
74  }
75 
76  template <typename F>
77  void execute(cudaStream_t stream, F&& f) {
78  CUDA_ENFORCE(cudaEventRecord(before_, stream));
79  CUDA_ENFORCE(cudaStreamWaitEvent(stream_, before_, 0));
80  f(this);
81  CUDA_ENFORCE(cudaEventRecord(after_, stream_));
82  CUDA_ENFORCE(cudaStreamWaitEvent(stream, after_, 0));
83  }
84 
85  private:
86  cudnnHandle_t cudnn_handle_{nullptr};
87  cudaEvent_t before_{nullptr};
88  cudaEvent_t after_{nullptr};
89  cudaStream_t stream_{nullptr};
90  CuDNNWorkspace workspace_;
91  size_t gpu_id_{0};
92  DISABLE_COPY_AND_ASSIGN(CuDNNState);
93 };
94 
105  public:
110  explicit CuDNNWrapper(CUDAContext* context) : context_(context) {}
111 
116  cudnnHandle_t inline_cudnn_handle() {
117  return context_->cudnn_handle();
118  }
119 
120  // Executes the closure F on the CuDNNState associated with state_idx
121  template <typename F>
122  void with_cudnn_state(size_t state_idx, F&& f) {
123  CAFFE_ENFORCE(
124  state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES, "Invalid state_idx");
125  auto& sync_state = cudnn_states()[context_->cuda_gpu_id()][state_idx];
126 
127  DeviceGuard dg(context_->cuda_gpu_id());
128 
129  // We need to serialize execution on the CuDNNState as we can't
130  // allow multiple threads to race through the cudaEventRecord
131  // calls (so a worker thread might wait on another worker thread's
132  // execution)
133  std::lock_guard<std::mutex> g(sync_state.mutex);
134  if (!sync_state.state.get()) {
135  sync_state.state.reset(new CuDNNState(context_->cuda_gpu_id()));
136  }
137  CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
138  }
139 
140  protected:
141  // Pointer to an external cuda context that the cudnn wrapper will use.
142  CUDAContext* context_;
143 
144  static constexpr size_t CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES = 4;
145 
147  std::mutex mutex;
148  std::unique_ptr<CuDNNState> state;
149  };
150 
151  using PerGPUCuDNNStates = std::array<
152  std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
153  CAFFE2_COMPILE_TIME_MAX_GPUS>;
154  static PerGPUCuDNNStates& cudnn_states();
155 
156  DISABLE_COPY_AND_ASSIGN(CuDNNWrapper);
157 };
158 
159 }; // namespace caffe2
160 
161 #endif
CuDNNWrapper(CUDAContext *context)
Creates a cudnn wrapper associated with a CUDAContext object.
CuDNNWorkspace is a wrapper around a raw CUDA pointer that holds the cuDNN scratch space.
Copyright (c) 2016-present, Facebook, Inc.
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cuDNN handle that executes on the current thread's cuda_stream.