3 #ifndef CAFFE2_CORE_CUDNN_WRAPPERS_H_ 4 #define CAFFE2_CORE_CUDNN_WRAPPERS_H_ 6 #include "caffe2/core/common_cudnn.h" 7 #include "caffe2/core/context_gpu.h" 63 void*
get(
size_t nbytes) {
64 if (nbytes_ < nbytes) {
66 data_ = CUDAContext::New(nbytes);
69 CAFFE_ENFORCE_GE(nbytes_, nbytes);
89 explicit CuDNNState(
size_t gpu_id) : gpu_id_(gpu_id) {
91 CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_));
92 CUDA_ENFORCE(cudaEventCreate(&before_));
93 CUDA_ENFORCE(cudaEventCreate(&after_));
94 CUDA_ENFORCE(cudaStreamCreate(&stream_));
95 CUDNN_ENFORCE(cudnnSetStream(cudnn_handle_, stream_));
100 CUDNN_CHECK(cudnnDestroy(cudnn_handle_));
101 CUDA_CHECK(cudaStreamDestroy(stream_));
102 CUDA_CHECK(cudaEventDestroy(after_));
103 CUDA_CHECK(cudaEventDestroy(before_));
106 cudnnHandle_t& cudnn_handle() {
107 return cudnn_handle_;
114 template <
typename F>
115 void execute(cudaStream_t stream, F&& f) {
116 CUDA_ENFORCE(cudaEventRecord(before_, stream));
117 CUDA_ENFORCE(cudaStreamWaitEvent(stream_, before_, 0));
119 CUDA_ENFORCE(cudaEventRecord(after_, stream_));
120 CUDA_ENFORCE(cudaStreamWaitEvent(stream, after_, 0));
124 cudnnHandle_t cudnn_handle_{
nullptr};
125 cudaEvent_t before_{
nullptr};
126 cudaEvent_t after_{
nullptr};
127 cudaStream_t stream_{
nullptr};
155 return context_->cudnn_handle();
159 template <
typename F>
160 void with_cudnn_state(
size_t state_idx, F&& f) {
162 state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES,
"Invalid state_idx");
163 auto& sync_state = cudnn_states()[context_->device_id()][state_idx];
171 std::lock_guard<std::mutex> g(sync_state.mutex);
172 if (!sync_state.state.get()) {
173 sync_state.state.reset(
new CuDNNState(context_->device_id()));
175 CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
182 static constexpr
size_t CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES = 4;
186 std::unique_ptr<CuDNNState> state;
189 using PerGPUCuDNNStates = std::array<
190 std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
191 C10_COMPILE_TIME_MAX_GPUS>;
192 static PerGPUCuDNNStates& cudnn_states();
CuDNNWrapper(CUDAContext *context)
Creates a cudnn wrapper associated with a CUDAContext object.
Represents a compute device on which a tensor is located.
CuDNNWorkspace is a wrapper around a raw cuda pointer that holds the cudnn scratch space...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
A variant of DeviceGuard that is specialized for CUDA.
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.