Caffe2 - C++ API
A deep learning, cross platform ML framework
cudnn_wrappers.h
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #ifndef CAFFE2_CORE_CUDNN_WRAPPERS_H_
4 #define CAFFE2_CORE_CUDNN_WRAPPERS_H_
5 
6 #include "caffe2/core/common_cudnn.h"
7 #include "caffe2/core/context_gpu.h"
8 
9 // Note [What is CuDNNWrapper good for?]
10 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
11 // Suppose you are writing a kernel that calls into CuDNN, and
12 // you need a cudnnHandle_t to pass to the kernel call. How should
13 // you go about getting one of those handles? You'd prefer not
14 // to make a new cudnnHandle_t every call; this can be somewhat
15 // expensive (1-2%, according to some measurements in TensorFlow.)
16 // But cudnnHandle_t is not thread-safe, so we can't just have
17 // a single global cudnnHandle_t that everyone uses.
18 //
19 // Thus, the most common method in Caffe2 for getting a CuDNN handle
20 // is to get a per-thread, per-stream CuDNN handle from CUDAContext
21 // (which knows what the current thread and stream are). The idiomatic
22 // way to do this in Caffe2 today is to make a CuDNNWrapper and then call
23 // inline_cudnn_handle(), although you didn't really need the
24 // CuDNNWrapper at all (you could have gotten it directly from
25 // CUDAContext.)
26 //
27 // So, what's all this business about CuDNNWrapper? In theory, it was
28 // designed with a more specialized use-case in mind, where you need to
29 // make multiple calls to CuDNN in parallel; e.g., when manually
30 // computing group convolution. By using with_cudnn_state(), you can
31 // get separate cudnnHandle_t and CUDA stream per parallel thread of
32 // execution, and run all of the cuDNN calls in parallel. CuDNNWrapper
33 // handles the business of synchronizing with the stream prior to this
34 // call.
35 //
36 // (By the way, this is why no such CUBLASWrapper exists; there isn't
37 // ever any reason you need to call cublas in parallel, since most
38 // cublas operations have batched variants.)
39 //
40 // Now, that's the theory... in practice, this is only ever used when
41 // multiple operators are run in parallel, and not to actually
42 // parallelize multiple CuDNN calls (for example, group convolution is
43 // now supported natively in CuDNN.) So... while the kit provided here
44 // might be useful for someone else in the future, it's not really used
45 // now. So we might consider deleting it, or unifying this mechanism
// with PyTorch's own CuDNN handle pool (which is its own thing).
47 
48 namespace caffe2 {
49 
50 class CuDNNWrapper;
51 
 // Trivial destructor: the owned buffer is released by data_'s own
 // deleter when the at::DataPtr member is destroyed.
 ~CuDNNWorkspace() noexcept {}

 // Returns a scratch buffer of at least `nbytes` bytes. The cached
 // allocation is reused when large enough; otherwise it is dropped and a
 // larger one is made. NOTE(review): old contents are NOT preserved
 // across a re-allocation — callers must not rely on them.
 void* get(size_t nbytes) {
 if (nbytes_ < nbytes) {
 // Release the old, smaller buffer before allocating the bigger one,
 // so we never hold two allocations at once.
 reset();
 data_ = CUDAContext::New(nbytes);
 nbytes_ = nbytes;
 }
 CAFFE_ENFORCE_GE(nbytes_, nbytes);
 return data_.get();
 }

 // Drops the cached allocation and records a size of zero; the next
 // get() will allocate from scratch.
 void reset() {
 data_.clear();
 nbytes_ = 0;
 }

 private:
 // Owning pointer to the scratch space; starts empty (null with a no-op
 // deleter) until the first get() call allocates.
 at::DataPtr data_{nullptr, nullptr, &NoDelete, at::Device(CUDA)};
 // Size in bytes of the allocation currently held in data_.
 size_t nbytes_{0};
};
82 
83 // CuDNNState is the owner of the CuDNNWorkspace, and serializes all
// executions of operations that use the state onto its own stream
85 // (so multiple Net workers can reuse the same workspace from
86 // different threads and CUDA streams).
class CuDNNState {
 public:
 // Creates a cuDNN handle bound to a freshly created private CUDA
 // stream on `gpu_id`, plus the two events execute() uses to fence the
 // private stream against a caller's stream.
 explicit CuDNNState(size_t gpu_id) : gpu_id_(gpu_id) {
 CUDAGuard g(gpu_id_);
 CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_));
 CUDA_ENFORCE(cudaEventCreate(&before_));
 CUDA_ENFORCE(cudaEventCreate(&after_));
 CUDA_ENFORCE(cudaStreamCreate(&stream_));
 // All cuDNN calls made through this handle are enqueued on stream_.
 CUDNN_ENFORCE(cudnnSetStream(cudnn_handle_, stream_));
 }

 // Destroys the handle, stream, and events on the owning GPU.
 // NOTE(review): no explicit stream synchronization happens here —
 // presumably callers guarantee outstanding work has finished; confirm
 // before relying on destruction ordering.
 ~CuDNNState() noexcept {
 CUDAGuard g(gpu_id_);
 CUDNN_CHECK(cudnnDestroy(cudnn_handle_));
 CUDA_CHECK(cudaStreamDestroy(stream_));
 CUDA_CHECK(cudaEventDestroy(after_));
 CUDA_CHECK(cudaEventDestroy(before_));
 }

 // The cuDNN handle owned by this state (already bound to stream_).
 cudnnHandle_t& cudnn_handle() {
 return cudnn_handle_;
 }

 // Scratch-space cache shared by all work serialized onto this state.
 CuDNNWorkspace& workspace() {
 return workspace_;
 }

 // Runs f(this) with the private stream fenced against `stream`:
 // stream_ first waits (via the before_ event) for everything already
 // recorded on `stream`, and `stream` then waits (via after_) for the
 // work f enqueued on stream_. Both waits are device-side — this call
 // does not block the host.
 template <typename F>
 void execute(cudaStream_t stream, F&& f) {
 CUDA_ENFORCE(cudaEventRecord(before_, stream));
 CUDA_ENFORCE(cudaStreamWaitEvent(stream_, before_, 0));
 f(this);
 CUDA_ENFORCE(cudaEventRecord(after_, stream_));
 CUDA_ENFORCE(cudaStreamWaitEvent(stream, after_, 0));
 }

 private:
 cudnnHandle_t cudnn_handle_{nullptr};
 // Marks the caller's stream's progress at the start of execute().
 cudaEvent_t before_{nullptr};
 // Marks completion of f's work on stream_ at the end of execute().
 cudaEvent_t after_{nullptr};
 // Private stream on which all cuDNN work for this state runs.
 cudaStream_t stream_{nullptr};
 CuDNNWorkspace workspace_;
 size_t gpu_id_{0};
 C10_DISABLE_COPY_AND_ASSIGN(CuDNNState);
};
132 
 public:
 // Creates a cudnn wrapper associated with a CUDAContext object. The
 // wrapper only borrows `context`; the caller keeps ownership and must
 // keep it alive for the wrapper's lifetime.
 explicit CuDNNWrapper(CUDAContext* context) : context_(context) {}

 // Returns the per-thread, per-stream cuDNN handle held by the wrapped
 // CUDAContext (see Note [What is CuDNNWrapper good for?] above).
 cudnnHandle_t inline_cudnn_handle() {
 return context_->cudnn_handle();
 }
157 
158  // Executes the closure F on the CuDNNState associated with state_idx
159  template <typename F>
160  void with_cudnn_state(size_t state_idx, F&& f) {
161  CAFFE_ENFORCE(
162  state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES, "Invalid state_idx");
163  auto& sync_state = cudnn_states()[context_->device_id()][state_idx];
164 
165  CUDAGuard dg(context_->device_id());
166 
167  // We need to serialize execution on the CuDNNState as we can't
168  // allow multiple threads to race through the cudaEventRecord
169  // calls (so a worker thread might wait on another worker thread's
170  // execution)
171  std::lock_guard<std::mutex> g(sync_state.mutex);
172  if (!sync_state.state.get()) {
173  sync_state.state.reset(new CuDNNState(context_->device_id()));
174  }
175  CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
176  }
177 
 protected:
 // Pointer to an external cuda context that the cudnn wrapper will use.
 // Borrowed, never owned.
 CUDAContext* context_;

 // Maximum number of parallel CuDNNState slots per GPU.
 static constexpr size_t CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES = 4;

 // Mutex that serializes all use of the lazily-created state.
 std::mutex mutex;
 std::unique_ptr<CuDNNState> state;
 };

 // One fixed-size bank of synchronized states per GPU, shared by all
 // CuDNNWrapper instances via the static cudnn_states() accessor.
 using PerGPUCuDNNStates = std::array<
 std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
 C10_COMPILE_TIME_MAX_GPUS>;
 static PerGPUCuDNNStates& cudnn_states();

 C10_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper);
};
196 
197 }; // namespace caffe2
198 
199 #endif
CuDNNWrapper(CUDAContext *context)
Creates a cudnn wrapper associated with a CUDAContext object.
Represents a compute device on which a tensor is located.
Definition: Device.h:30
CuDNNWorkspace is a wrapper around a raw cuda pointer that holds the cudnn scratch space...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
A variant of DeviceGuard that is specialized for CUDA.
Definition: CUDAGuard.h:20
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.