Caffe2 - C++ API
A deep learning, cross platform ML framework
common_rtc.h
1 #ifndef CAFFE2_CUDA_RTC_COMMON_RTC_H_
2 #define CAFFE2_CUDA_RTC_COMMON_RTC_H_
3 
#include <atomic>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

#include <cuda.h>
#include <nvrtc.h>
9 
// Evaluates an NVRTC call once and aborts via LOG(FATAL) with file/line and
// nvrtcGetErrorString() if it did not return NVRTC_SUCCESS.
// The temporary uses a reserved-looking name and the argument is
// parenthesized so that call sites such as `NVRTC_CHECK(result)` do not
// expand into a self-initializing, shadowing declaration.
#define NVRTC_CHECK(condition)                                          \
  do {                                                                  \
    nvrtcResult nvrtc_check_result_ = (condition);                      \
    if (nvrtc_check_result_ != NVRTC_SUCCESS) {                         \
      LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \
                 << nvrtcGetErrorString(nvrtc_check_result_);           \
    }                                                                   \
  } while (0)
18 
19 namespace caffe2 {
20 
21 template <typename Derived>
23  public:
24  CudaRTCFunction() : module_loaded_(false) {}
25  ~CudaRTCFunction() {
26  if (module_loaded_) {
27  CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_));
28  }
29  }
30 
31  // TODO: this function is nontrivial and since CudaRTCFunction uses CRTP, it
32  // may potentially increase the binary size. In that case, move common parts
33  // into a separate function.
34  template <typename... Args>
35  void Compile(Args... args) {
36  string src = static_cast<Derived*>(this)->GetSource(args...);
37  string name = static_cast<Derived*>(this)->KernelName(args...);
38  VLOG(1) << "function name: " << name;
39  VLOG(1) << "function src:\n" << src;
40  // Actually do the compiling.
41  nvrtcProgram prog;
42  NVRTC_CHECK(nvrtcCreateProgram(
43  &prog, src.c_str(), nullptr, 0, nullptr, nullptr));
44  // Compile the program.
45  // TODO(Yangqing): how to find the current gpu architecture instead of hard
46  // coding it?
47  const char *nvrtc_opts[] = {"--gpu-architecture=compute_35",
48  "--use_fast_math"};
49  nvrtcResult compile_result = nvrtcCompileProgram(
50  prog, 2, nvrtc_opts);
51  if (compile_result != NVRTC_SUCCESS) {
52  size_t log_size;
53  NVRTC_CHECK(nvrtcGetProgramLogSize(prog, &log_size));
54  vector<char> nvrtc_log(log_size);
55  NVRTC_CHECK(nvrtcGetProgramLog(prog, nvrtc_log.data()));
56  LOG(FATAL) << "Compilation failure for nvrtc("
57  << nvrtcGetErrorString(compile_result) << "): \n"
58  << nvrtc_log.data();
59  }
60  size_t ptx_size;
61  NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size));
62  vector<char> nvrtc_ptx(ptx_size);
63  NVRTC_CHECK(nvrtcGetPTX(prog, nvrtc_ptx.data()));
64  NVRTC_CHECK(nvrtcDestroyProgram(&prog));
65  // After compilation, load the module.
66  if (module_loaded_) {
67  CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_));
68  }
69  CUDA_DRIVERAPI_ENFORCE(
70  cuModuleLoadDataEx(&module_, nvrtc_ptx.data(), 0, 0, 0));
71  module_loaded_ = true;
72  CUDA_DRIVERAPI_ENFORCE(
73  cuModuleGetFunction(&kernel_, module_, name.c_str()));
74  }
75 
76  template <typename... Args>
77  void Launch(unsigned int gx, unsigned int gy, unsigned int gz,
78  unsigned int bx, unsigned int by, unsigned int bz,
79  unsigned int shared_mem, cudaStream_t stream,
80  Args... args) {
81  CAFFE_ENFORCE(
82  module_loaded_, "Cannot call Launch before a module is loaded.");
83  void * args_voidp[] = {&args...};
84  CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
85  kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, args_voidp, 0));
86  }
87 
88  void LaunchEx(unsigned int gx, unsigned int gy, unsigned int gz,
89  unsigned int bx, unsigned int by, unsigned int bz,
90  unsigned int shared_mem, cudaStream_t stream,
91  void** extra) {
92  CAFFE_ENFORCE(
93  module_loaded_, "Cannot call Launch before a module is loaded.");
94  CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
95  kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, nullptr, extra));
96  }
97 
98  private:
99  bool module_loaded_;
100  CUmodule module_;
101  CUfunction kernel_;
102 };
103 
// Returns a kernel name of the form "_cuda_kernel_" followed by 20 ASCII
// letters. Names are unique within the process: each call draws a distinct
// value from a thread-safe atomic counter and spells it in base 52 using the
// letters a-z, A-Z (20 base-52 digits cover every 64-bit counter value).
// Unlike the previous rand()-based version, it does not touch the global C
// RNG state and is safe to call concurrently.
inline std::string GetUniqueName() {
  static constexpr int len = 20;
  static const char alpha[] =
      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
  static constexpr std::uint64_t base = sizeof(alpha) - 1;  // 52 letters
  static std::atomic<std::uint64_t> counter{0};

  std::uint64_t id = counter.fetch_add(1, std::memory_order_relaxed);
  std::stringstream ss;
  ss << "_cuda_kernel_";
  for (int i = 0; i < len; ++i) {
    ss << alpha[id % base];
    id /= base;
  }
  return ss.str();
}
117 
118 } // namespace caffe2
119 
120 #endif // CAFFE2_CUDA_RTC_COMMON_RTC_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13