Caffe2 - C++ API
A deep learning, cross-platform ML framework
miopen_wrapper.h
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 #ifndef CAFFE2_CORE_MIOPEN_WRAPPERS_H_
3 #define CAFFE2_CORE_MIOPEN_WRAPPERS_H_
4 
5 #include "caffe2/core/hip/common_miopen.h"
6 #include "caffe2/core/hip/context_gpu.h"
7 
8 #include <c10/hip/HIPGuard.h>
9 
10 namespace caffe2 {
11 
12 class MIOPENWrapper;
13 
23 {
24  ~MIOPENWorkspace() noexcept {}
25 
26  void* get(size_t nbytes)
27  {
28  if(nbytes_ < nbytes)
29  {
30  reset();
31  data_ = HIPContext::New(nbytes);
32  nbytes_ = nbytes;
33  }
34  CAFFE_ENFORCE_GE(nbytes_, nbytes);
35  return data_.get();
36  }
37 
38  void reset()
39  {
40  data_.clear();
41  nbytes_ = 0;
42  }
43 
44  private:
45  at::DataPtr data_;
46  size_t nbytes_{0};
47 };
48 
49 // MIOPENState is the owner of the MIOPENWorkspace, and serializes all
50 // executions of operations that use the state onto its own stream
51 // (so multiple Net workers can reuse the same workspace from
52 // different threads and HIP streams).
54 {
55  public:
56  explicit MIOPENState(size_t gpu_id) : gpu_id_(gpu_id)
57  {
58  HIPGuard g(gpu_id_);
59  MIOPEN_ENFORCE(miopenCreate(&miopen_handle_));
60  HIP_ENFORCE(hipEventCreate(&before_));
61  HIP_ENFORCE(hipEventCreate(&after_));
62  HIP_ENFORCE(hipStreamCreate(&stream_));
63  MIOPEN_ENFORCE(miopenSetStream(miopen_handle_, stream_));
64  }
65 
66  ~MIOPENState() noexcept
67  {
68  HIPGuard g(gpu_id_);
69  MIOPEN_CHECK(miopenDestroy(miopen_handle_));
70  HIP_CHECK(hipStreamDestroy(stream_));
71  HIP_CHECK(hipEventDestroy(after_));
72  HIP_CHECK(hipEventDestroy(before_));
73  }
74 
75  miopenHandle_t& miopen_handle() { return miopen_handle_; }
76 
77  MIOPENWorkspace& workspace() { return workspace_; }
78 
79  template <typename F>
80  void execute(hipStream_t stream, F&& f)
81  {
82  HIP_ENFORCE(hipEventRecord(before_, stream));
83  HIP_ENFORCE(hipStreamWaitEvent(stream_, before_, 0));
84  f(this);
85  HIP_ENFORCE(hipEventRecord(after_, stream_));
86  HIP_ENFORCE(hipStreamWaitEvent(stream, after_, 0));
87  }
88 
89  private:
90  miopenHandle_t miopen_handle_{nullptr};
91  hipEvent_t before_{nullptr};
92  hipEvent_t after_{nullptr};
93  hipStream_t stream_{nullptr};
94  MIOPENWorkspace workspace_;
95  size_t gpu_id_{0};
96  C10_DISABLE_COPY_AND_ASSIGN(MIOPENState);
97 };
98 
109 {
110  public:
115  explicit MIOPENWrapper(HIPContext* context) : context_(context) {}
116 
121  miopenHandle_t inline_miopen_handle() { return context_->miopen_handle(); }
122 
123  // Executes the closure F on the MIOPENState associated with state_idx
124  template <typename F>
125  void with_miopen_state(size_t state_idx, F&& f)
126  {
127  CAFFE_ENFORCE(state_idx < CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES, "Invalid state_idx");
128  auto& sync_state = miopen_states()[context_->device_id()][state_idx];
129 
130  HIPGuard dg(context_->device_id());
131 
132  // We need to serialize execution on the MIOPENState as we can't
133  // allow multiple threads to race through the cudaEventRecord
134  // calls (so a worker thread might wait on another worker thread's
135  // execution)
136  std::lock_guard<std::mutex> g(sync_state.mutex);
137  if(!sync_state.state.get())
138  {
139  sync_state.state.reset(new MIOPENState(context_->device_id()));
140  }
141  CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f);
142  }
143 
144  protected:
145  // Pointer to an external cuda context that the miopen wrapper will use.
146  HIPContext* context_;
147 
148  static constexpr size_t CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES = 4;
149 
151  {
152  std::mutex mutex;
153  std::unique_ptr<MIOPENState> state;
154  };
155 
156  using PerGPUMIOPENStates = std::array<
157  std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
158  C10_COMPILE_TIME_MAX_GPUS>;
159  static PerGPUMIOPENStates& miopen_states();
160 
161  C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper);
162 };
163 
164 }; // namespace caffe2
165 
166 #endif
MIOPENWorkspace is a wrapper around a raw HIP pointer that holds the MIOpen scratch space...
miopenHandle_t inline_miopen_handle()
Returns the inline miopen handle that executes on the current thread's hip_stream.
MIOPENWrapper is a class that wraps the miopen handles and miopen workspaces.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
MIOPENWrapper(HIPContext *context)
Creates a miopen wrapper associated with a HIPContext object.