#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/function.h>

#include <cstring>
#include <fstream>
#include <list>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>

namespace torch { namespace autograd { namespace profiler {
CUDAStubs default_stubs;
constexpr CUDAStubs* default_stubs_addr = &default_stubs;
// Constant initialization, so cuda_stubs is guaranteed to be set before any
// static-initialization-time call that might invoke registerCUDAMethods.
static CUDAStubs* cuda_stubs = default_stubs_addr;
TORCH_API void registerCUDAMethods(CUDAStubs* stubs) {
  cuda_stubs = stubs;
}
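// A minimal sketch of how a CUDA build could hook in real stubs (the subclass
// name below is hypothetical; only registerCUDAMethods comes from this file):
//
//   struct RealCUDAStubs : CUDAStubs { /* override the virtual hooks */ };
//   static RealCUDAStubs real_stubs;
//   // run at static-initialization time in the CUDA library:
//   static int dummy = (registerCUDAMethods(&real_stubs), 0);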
ProfilerState state = ProfilerState::Disabled;
uint16_t next_thread_id = 0;
std::mutex all_event_lists_mutex;
std::list<std::shared_ptr<RangeEventList>> all_event_lists;
thread_local std::shared_ptr<RangeEventList> event_list;
thread_local uint16_t thread_id;
RangeEventList& getEventList() {
  if (!event_list) {
    // Lazily create the per-thread event list and register it in the global
    // list so disableProfiler() can consolidate events from every thread.
    std::lock_guard<std::mutex> guard(all_event_lists_mutex);
    event_list = std::make_shared<RangeEventList>();
    thread_id = next_thread_id++;
    all_event_lists.emplace_front(event_list);
  }
  return *event_list;
}
void mark(std::string name, bool include_cuda /* = true */) {
  if (state == ProfilerState::Disabled) {
    return;
  }
  if (state == ProfilerState::NVTX) {
    cuda_stubs->nvtxMarkA(name.c_str());
  } else {
    getEventList().record(
        EventKind::Mark,
        std::move(name),
        thread_id,
        include_cuda && state == ProfilerState::CUDA);
  }
}
// Overloads so pushRangeImpl can accept either a const char* or a
// std::string without an extra copy.
const char* c_str(const char* str) { return str; }
const char* c_str(std::string& str) { return str.c_str(); }
template <typename T>
void pushRangeImpl(T name, const char* msg = "", int64_t sequence_nr = -1) {
  if (state == ProfilerState::Disabled) {
    return;
  }
  if (state == ProfilerState::NVTX) {
    if (sequence_nr >= 0) {
      std::stringstream s;
      s << name << msg << sequence_nr;
      cuda_stubs->nvtxRangePushA(s.str().c_str());
    } else {
      cuda_stubs->nvtxRangePushA(c_str(name));
    }
  } else {
    getEventList().record(
        EventKind::PushRange,
        std::move(name),
        thread_id,
        state == ProfilerState::CUDA);
  }
}
void pushRange(std::string name) {
  pushRangeImpl(std::move(name));
}
void popRange() {
  if (state == ProfilerState::Disabled) {
    return;
  }
  if (state == ProfilerState::NVTX) {
    cuda_stubs->nvtxRangePop();
  } else {
    getEventList().record(
        EventKind::PopRange,
        // a pop matches the most recent unmatched push, so it needs no name
        "",
        thread_id,
        state == ProfilerState::CUDA);
  }
}
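// Illustrative only: pushRange/popRange must be balanced per thread, and
// ranges may nest:
//
//   pushRange("forward");
//   pushRange("conv2d");
//   // ... kernel work ...
//   popRange();  // closes "conv2d"
//   popRange();  // closes "forward"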
RecordFunction::RecordFunction(Function* fn) {
  // Name the range after the autograd node plus the sequence number it
  // stashed at forward time, so backward ops can be matched to forward ops.
  pushRangeImpl(fn->name(), ", stashed seq=", fn->sequence_nr());
}
RecordFunction::RecordFunction(std::string name) {
  pushRangeImpl(std::move(name));
}
RecordFunction::RecordFunction(const char* name) {
  pushRangeImpl<const char*>(name);
}
RecordFunction::RecordFunction(const char* name, int64_t current_sequence_nr) {
  pushRangeImpl<const char*>(name, ", seq=", current_sequence_nr);
}
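// RecordFunction is an RAII guard: each constructor pushes a range, and the
// matching popRange is issued by the destructor (see profiler.h). Sketch:
//
//   {
//     RecordFunction guard("my_op", /*current_sequence_nr=*/7);
//     // ... op body ...
//   }  // range popped here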
void enableProfiler(ProfilerState new_state) {
  AT_ASSERT(new_state != ProfilerState::Disabled);
  if (new_state == ProfilerState::NVTX && !cuda_stubs->enabled())
    throw std::runtime_error("Can't use NVTX profiler - PyTorch was compiled without CUDA");
  if (state != ProfilerState::Disabled && new_state != state) {
    throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running");
  }
  state = new_state;

  if (state == ProfilerState::CUDA) {
    // CUDA event recording has some startup overhead, so generate a few dummy
    // events before the synchronization events below are recorded.
    for (int i = 0; i < 5; i++) {
      cuda_stubs->onEachDevice([](int d) {
        mark("__cuda_startup");
        cuda_stubs->synchronize();
      });
    }

    // CUDA events can only be compared on the same device, so record a start
    // event on each GPU; it anchors that GPU's timestamps to the CPU clock.
    cuda_stubs->onEachDevice([](int d) {
      mark("__cuda_start_event");
    });
  }
  mark("__start_profile", false);
}
thread_event_lists disableProfiler() {
  if (state == ProfilerState::Disabled) {
    throw std::runtime_error("can't disable profiler when it's not running");
  }
  ProfilerState old_state = state;
  mark("__stop_profile");
  state = ProfilerState::Disabled;
  if (old_state == ProfilerState::NVTX) {
    return thread_event_lists();
  } else {
    thread_event_lists result;
    std::lock_guard<std::mutex> guard(all_event_lists_mutex);
    for (auto it = all_event_lists.begin(); it != all_event_lists.end();) {
      auto& list = *it;
      result.emplace_back(list->consolidate());
      // Garbage-collect lists whose owning thread has exited: this list holds
      // the only remaining reference.
      if (list.use_count() == 1) {
        auto current_it = it;
        ++it;
        all_event_lists.erase(current_it);
      } else {
        ++it;
      }
    }
    return result;
  }
}
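// Illustrative only: enableProfiler/disableProfiler bracket the profiled
// region, and the returned value holds one consolidated list per thread:
//
//   enableProfiler(ProfilerState::CPU);
//   // ... run the workload ...
//   thread_event_lists lists = disableProfiler();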
void Event::record(bool record_cuda) {
  if (record_cuda) {
    cuda_stubs->record(&device_, &event, &cpu_ns_);
    return;
  }
  cpu_ns_ = getTime();
}
double Event::cuda_elapsed_us(const Event& e) {
  if (!e.has_cuda() || !has_cuda()) {
    throw std::logic_error("Events were not recorded for CUDA");
  }
  if (e.device() != device()) {
    throw std::logic_error("Events are not on the same device");
  }
  return cuda_stubs->elapsed(event, e.event);
}
CUDAStubs::~CUDAStubs() = default;
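// Defining the destructor out-of-line anchors CUDAStubs' vtable in this
// translation unit. The default stubs report enabled() == false, which is
// what triggers the "compiled without CUDA" error in enableProfiler().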
// One entry in the Chrome trace event format (viewable in chrome://tracing);
// "ph": "X" marks a complete event with an explicit duration.
static jit::CodeTemplate event_template(R"(
{
  "name": "${name}",
  "ph": "X",
  "ts": ${ts},
  "dur": ${dur},
  "tid": ${tid},
  "pid": "CPU Functions",
  "args": {}
})");

RecordProfile::RecordProfile(std::ostream& out)
: out_(out) {
  init();
}

RecordProfile::RecordProfile(const std::string& filename)
: file_(new std::ofstream(filename)), out_(*file_) {
  init();
}
void RecordProfile::init() {
  enableProfiler(ProfilerState::CPU);
}
RecordProfile::~RecordProfile() {
  thread_event_lists event_lists = disableProfiler();
  std::vector<Event*> events;
  for (auto& l : event_lists) {
    for (auto& e : l) {
      events.push_back(&e);
    }
  }
  processEvents(events);
  if (file_) {
    file_->close();
  }
}
void RecordProfile::processEvents(const std::vector<Event*>& events) {
  AT_CHECK(out_, "could not open file");
  // Find the __start_profile mark; timestamps are reported relative to it.
  Event* start = nullptr;
  for (Event* e : events) {
    if (0 == strcmp(e->name(), "__start_profile")) {
      start = e;
      break;
    }
  }
  AT_CHECK(start, "could not find start?");
  std::vector<Event*> stack;
  out_ << "[\n";
  bool first = true;
  // Match push/pop pairs into complete events and emit one JSON entry each.
  for (Event* e : events) {
    if (e->kind() == "push") {
      stack.push_back(e);
    } else if (e->kind() == "pop") {
      if (!first) {
        out_ << ",\n";
      }
      first = false;
      Event* e_start = stack.back();
      stack.pop_back();
      jit::TemplateEnv env;
      env.s("name", e_start->name());
      env.d("ts", start->cpu_elapsed_us(*e_start));
      env.d("dur", e_start->cpu_elapsed_us(*e));
      env.d("tid", e_start->thread_id());
      out_ << event_template.format(env);
    }
  }
  out_ << "]\n";
}

}}} // namespace torch::autograd::profiler
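// Illustrative only: RecordProfile wraps the whole enable/dump cycle in RAII;
// the resulting JSON can be loaded in chrome://tracing:
//
//   {
//     RecordProfile guard("trace.json");
//     // ... code to profile ...
//   }  // profiler stopped and trace written on scope exit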