Caffe2 - C++ API
A deep learning, cross platform ML framework
common_rtc.h
1 
17 #ifndef CAFFE2_CUDA_RTC_COMMON_RTC_H_
18 #define CAFFE2_CUDA_RTC_COMMON_RTC_H_
19 
#include <atomic>
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

#include <cuda.h>
#include <nvrtc.h>
25 
// Evaluates an NVRTC API call and, if it does not return NVRTC_SUCCESS,
// aborts via LOG(FATAL) with the file, line and nvrtcGetErrorString() text.
// The local is deliberately given an unlikely name: with a plain `result`,
// a `condition` that mentions a caller variable called `result` would bind
// to the macro's own (still uninitialized) local instead.
#define NVRTC_CHECK(condition)                                          \
  do {                                                                  \
    nvrtcResult nvrtc_check_result_ = (condition);                      \
    if (nvrtc_check_result_ != NVRTC_SUCCESS) {                         \
      LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \
                 << nvrtcGetErrorString(nvrtc_check_result_);           \
    }                                                                   \
  } while (0)
34 
35 namespace caffe2 {
36 
37 template <typename Derived>
39  public:
40  CudaRTCFunction() : module_loaded_(false) {}
41  ~CudaRTCFunction() {
42  if (module_loaded_) {
43  CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_));
44  }
45  }
46 
47  // TODO: this function is nontrivial and since CudaRTCFunction uses CRTP, it
48  // may potentially increase the binary size. In that case, move common parts
49  // into a separate function.
50  template <typename... Args>
51  void Compile(Args... args) {
52  string src = static_cast<Derived*>(this)->GetSource(args...);
53  string name = static_cast<Derived*>(this)->KernelName(args...);
54  VLOG(1) << "function name: " << name;
55  VLOG(1) << "function src:\n" << src;
56  // Actually do the compiling.
57  nvrtcProgram prog;
58  NVRTC_CHECK(nvrtcCreateProgram(
59  &prog, src.c_str(), nullptr, 0, nullptr, nullptr));
60  // Compile the program.
61  // TODO(Yangqing): how to find the current gpu architecture instead of hard
62  // coding it?
63  const char *nvrtc_opts[] = {"--gpu-architecture=compute_35",
64  "--use_fast_math"};
65  nvrtcResult compile_result = nvrtcCompileProgram(
66  prog, 2, nvrtc_opts);
67  if (compile_result != NVRTC_SUCCESS) {
68  size_t log_size;
69  NVRTC_CHECK(nvrtcGetProgramLogSize(prog, &log_size));
70  vector<char> nvrtc_log(log_size);
71  NVRTC_CHECK(nvrtcGetProgramLog(prog, nvrtc_log.data()));
72  LOG(FATAL) << "Compilation failure for nvrtc("
73  << nvrtcGetErrorString(compile_result) << "): \n"
74  << nvrtc_log.data();
75  }
76  size_t ptx_size;
77  NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size));
78  vector<char> nvrtc_ptx(ptx_size);
79  NVRTC_CHECK(nvrtcGetPTX(prog, nvrtc_ptx.data()));
80  NVRTC_CHECK(nvrtcDestroyProgram(&prog));
81  // After compilation, load the module.
82  if (module_loaded_) {
83  CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_));
84  }
85  CUDA_DRIVERAPI_ENFORCE(
86  cuModuleLoadDataEx(&module_, nvrtc_ptx.data(), 0, 0, 0));
87  module_loaded_ = true;
88  CUDA_DRIVERAPI_ENFORCE(
89  cuModuleGetFunction(&kernel_, module_, name.c_str()));
90  }
91 
92  template <typename... Args>
93  void Launch(unsigned int gx, unsigned int gy, unsigned int gz,
94  unsigned int bx, unsigned int by, unsigned int bz,
95  unsigned int shared_mem, cudaStream_t stream,
96  Args... args) {
97  CAFFE_ENFORCE(
98  module_loaded_, "Cannot call Launch before a module is loaded.");
99  void * args_voidp[] = {&args...};
100  CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
101  kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, args_voidp, 0));
102  }
103 
104  void LaunchEx(unsigned int gx, unsigned int gy, unsigned int gz,
105  unsigned int bx, unsigned int by, unsigned int bz,
106  unsigned int shared_mem, cudaStream_t stream,
107  void** extra) {
108  CAFFE_ENFORCE(
109  module_loaded_, "Cannot call Launch before a module is loaded.");
110  CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
111  kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, nullptr, extra));
112  }
113 
114  private:
115  bool module_loaded_;
116  CUmodule module_;
117  CUfunction kernel_;
118 };
119 
// Returns a kernel name of the form "_cuda_kernel_<20 letters>_<n>".
//   - <n> comes from a process-wide atomic counter that is incremented on
//     every call, which makes every returned name unique within the process.
//   - The letters are drawn with std::rand(). This function does not seed
//     it, so that part alone is neither unique nor unpredictable.
inline std::string GetUniqueName() {
  static constexpr int len = 20;
  static const char alpha[] =
      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
  // A function-local static in an inline function is shared by every
  // translation unit, so the counter is process-wide.
  static std::atomic<unsigned long long> counter{0};

  std::stringstream ss;
  ss << "_cuda_kernel_";
  for (int i = 0; i < len; ++i) {
    // sizeof(alpha) - 1 excludes the string literal's terminating NUL.
    ss << alpha[std::rand() % (sizeof(alpha) - 1)];
  }
  ss << "_" << counter.fetch_add(1);
  return ss.str();
}
133 
} // namespace caffe2
135 
136 #endif // CAFFE2_CUDA_RTC_COMMON_RTC_H_
Copyright (c) 2016-present, Facebook, Inc.