#include <torch/csrc/autograd/profiler.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvToolsExt.h>

#include <sstream>
#include <stdexcept>

namespace torch {
namespace autograd {
namespace profiler {
11 static inline void cudaCheck(cudaError_t result,
const char * file,
int line) {
12 if(result != cudaSuccess) {
14 ss << file <<
":" << line <<
": " << cudaGetErrorString(result);
15 throw std::runtime_error(ss.str());
// TORCH_CUDA_CHECK(result): forwards a cudaError_t, together with the call
// site's __FILE__ and __LINE__, to cudaCheck, which throws on failure.
// NOTE(review): the expansion already ends in ';', so the usual
// `TORCH_CUDA_CHECK(x);` produces an extra empty statement. A
// `do { ... } while (0)` body without the trailing ';' would be safer inside
// unbraced if/else.
//
// CUDAMethods: CUDA-backed implementation of the CUDAStubs interface. Every
// visible member is marked `override`. An instance is handed to
// registerCUDAMethods by RegisterCUDAMethods later in this file.
// NOTE(review): this excerpt carries stray numeric prefixes and is missing
// several lines (closing braces and some statements). The comments below
// describe only the statements that are visible.
18 #define TORCH_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__); 20 struct CUDAMethods :
public CUDAStubs {
// record: writes the current device into *device, creates a CUDA event
// (handle stored in *event), and records that event on the current CUDA
// stream (at::cuda::getCurrentCUDAStream()). Each CUDA call goes through
// TORCH_CUDA_CHECK, so a failure throws.
// NOTE(review): no visible statement writes *cpu_ns. A line between fetching
// the stream and recording the event is missing from this excerpt; confirm
// against the full file.
21 void record(
int* device, CUDAEventStub* event, int64_t* cpu_ns)
override {
22 TORCH_CUDA_CHECK(cudaGetDevice(device));
23 TORCH_CUDA_CHECK(cudaEventCreate(event));
24 auto stream = at::cuda::getCurrentCUDAStream();
26 TORCH_CUDA_CHECK(cudaEventRecord(*event, stream));
// elapsed: waits (cudaEventSynchronize) until both `event` and `event2` have
// completed, then asks cudaEventElapsedTime for the time between them,
// stored in `ms`. The CUDA runtime reports that value in milliseconds.
// The declaration of `ms` and the return statement are not visible here, so
// the unit of the returned float is unconfirmed (TODO: check the return
// expression).
28 float elapsed(CUDAEventStub event, CUDAEventStub event2)
override {
29 TORCH_CUDA_CHECK(cudaEventSynchronize(event));
30 TORCH_CUDA_CHECK(cudaEventSynchronize(event2));
32 TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event, event2));
// nvtxMarkA: override taking a marker name. Its body is not visible in this
// excerpt (presumably it emits an NVTX mark; confirm).
35 void nvtxMarkA(
const char* name)
override {
// nvtxRangePushA: forwards `name` to the global NVTX ::nvtxRangePushA.
38 void nvtxRangePushA(
const char* name)
override {
39 ::nvtxRangePushA(name);
// nvtxRangePop: body not visible in this excerpt (presumably forwards to
// ::nvtxRangePop; confirm).
41 void nvtxRangePop()
override {
// onEachDevice: asks cudaGetDeviceCount for the number of devices and loops
// i over [0, count). The declaration of `count` and the loop body are not
// visible here. Presumably each iteration selects device i and calls op(i);
// confirm against the full file.
44 void onEachDevice(std::function<
void(
int)> op)
override {
47 TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
48 for(
int i = 0; i < count; i++) {
// synchronize: calls cudaDeviceSynchronize(), which waits for the device to
// finish its previously queued work.
// NOTE(review): unlike every other CUDA call in this struct, the returned
// cudaError_t is discarded rather than passed to TORCH_CUDA_CHECK, so a
// failure here goes unreported.
53 void synchronize()
override {
54 cudaDeviceSynchronize();
// enabled: body not visible in this excerpt.
56 bool enabled()
override {
62 struct RegisterCUDAMethods {
63 RegisterCUDAMethods() {
64 static CUDAMethods methods;
65 registerCUDAMethods(&methods);
68 RegisterCUDAMethods reg;
A variant of OptionalDeviceGuard that is specialized for CUDA.
void set_index(DeviceIndex device_index)
Sets the CUDA device to the given device index, initializing the guard if it is not already initialized.