Caffe2 - C++ API
A deep learning, cross platform ML framework
conv_op_shared_gpu.cc
1 
17 #include "caffe2/core/context_gpu.h"
18 #include "conv_op_shared.h"
19 
20 namespace caffe2 {
21 
22 template <>
23 void createSharedBuffer<CUDAContext>(Workspace* ws) {
24  auto* mutexPtr = ws->CreateBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA_MUTEX__")
25  ->GetMutable<std::unique_ptr<std::mutex>>();
26  mutexPtr->reset(new std::mutex());
27  ws->CreateBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__");
28 }
29 
30 template <>
31 void runWithSharedBuffer(
32  Workspace* ws,
33  std::function<void(Tensor<CUDAContext>* buffer)> f) {
34  auto* mutexBlob = ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA_MUTEX__");
35  CAFFE_ENFORCE(mutexBlob, "Must call createSharedBuffer() first");
36 
37  auto* mutexPtr = mutexBlob->GetMutable<std::unique_ptr<std::mutex>>();
38  std::lock_guard<std::mutex> g(**mutexPtr);
39  auto* buffer = ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__")
40  ->GetMutable<TensorCUDA>();
41  f(buffer);
42 }
43 }
Copyright (c) 2016-present, Facebook, Inc.
Copyright (c) 2016-present, Facebook, Inc.