Caffe2 - C++ API: a deep-learning, cross-platform ML framework.
Source file: fused_kernel.cpp
1 #include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
2 #include <torch/csrc/jit/fuser/compiler.h>
3 
4 #include <ATen/cuda/CUDAContext.h>
5 #include <THC/THC.h>
6 #include <c10/cuda/CUDAGuard.h>
7 #include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
8 #include <torch/csrc/jit/fuser/cuda/thnvrtc.h>
9 #include <torch/csrc/jit/resource_guard.h>
10 
11 // Note: unclear why this forward declaration is necessary
12 #include <THC/THCGenerator.hpp>
13 #include <THC/THCTensorRandom.h>
14 THCGenerator* THCRandom_getGenerator(THCState* state);
15 
16 #include <cuda_runtime.h>
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <sstream>
21 #include <stdexcept>
22 #include <tuple>
23 #include <vector>
24 
25 namespace torch {
26 namespace jit {
27 namespace fuser {
28 namespace cuda {
29 
30 // [USE OF NVRTC AND DRIVER API]
31 // libtorch does not directly link to either libnvrtc or libcuda because
32 // they require libcuda to be installed. Normal CUDA code in torch uses the cuda
33 // runtime libraries which can be installed even if the driver is not installed,
34 // but here we specifically need to use the driver API to load JIT compiled
35 // code. To accomplish this, we lazily link libthnvrtc which provides a struct
36 // THNVRTC that contains function pointers to all of the apis we need.
37 //
38 // IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY.
39 // INSTEAD USE, e.g. nvrtc().cuLoadModule(...)
40 // If a function is missing add it to the list in thnvrtc.
41 
42 void checkCUDAVersion(const cudaDeviceProp& prop) {
43  if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
44  (prop.major >= 7 && CUDA_VERSION < 9000)) {
45  std::stringstream err_string;
46  err_string
47  << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: "
48  << CUDA_VERSION << " for the current GPU device " << prop.name
49  << " with device capability " << prop.major << "." << prop.minor;
50  throw std::runtime_error(err_string.str());
51  }
52 }
53 
54 #ifdef USE_DIRECT_NVRTC
55 std::pair<std::unique_ptr<cpu::DynamicLibrary>, THNVRTC*> loadNVRTC() {
56  return std::make_pair(nullptr, torch_load_nvrtc());
57 }
58 #else
59 std::pair<std::unique_ptr<cpu::DynamicLibrary>, THNVRTC*> loadNVRTC() {
60  std::string path = cpu::DynamicLibrary::directoryOf((void*)checkCUDAVersion);
61 #ifdef __APPLE__
62  std::string libthnvrtc = path + "/libthnvrtc.dylib";
63 #else
64  std::string libthnvrtc = path + "/libthnvrtc.so";
65 #endif
66  std::unique_ptr<cpu::DynamicLibrary> libnvrtc_stub(
67  new cpu::DynamicLibrary(libthnvrtc.c_str()));
68  auto fn = (THNVRTC * (*)()) libnvrtc_stub->sym("torch_load_nvrtc");
69  return std::make_pair(std::move(libnvrtc_stub), fn());
70 }
71 #endif
72 
73 const THNVRTC& nvrtc() {
74  // must hold onto DynamicLibrary otherwise it will unload
75  static auto handle = loadNVRTC();
76  return *handle.second;
77 }
78 
79 // We're using three CUDA APIs, so define a few helpers for error handling
80 static inline void nvrtcCheck(nvrtcResult result, const char* file, int line) {
81  if (result != NVRTC_SUCCESS) {
82  std::stringstream ss;
83  ss << file << ":" << line << ": " << nvrtc().nvrtcGetErrorString(result);
84  throw std::runtime_error(ss.str());
85  }
86 }
87 #define TORCH_NVRTC_CHECK(result) nvrtcCheck(result, __FILE__, __LINE__);
88 
89 static inline void cuCheck(CUresult result, const char* file, int line) {
90  if (result != CUDA_SUCCESS) {
91  const char* str;
92  nvrtc().cuGetErrorString(result, &str);
93  std::stringstream ss;
94  ss << file << ":" << line << ": " << str;
95  throw std::runtime_error(ss.str());
96  }
97 }
98 #define TORCH_CU_CHECK(result) cuCheck(result, __FILE__, __LINE__);
99 
100 static void getMajorMinor(
101  const cudaDeviceProp* const prop,
102  int& major,
103  int& minor) {
104  int nvrtc_major, nvrtc_minor;
105  TORCH_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
106 
107  // Short-circuits if NVRTC version too low
108  AT_ASSERT(nvrtc_major >= 6);
109 
110  // Major and minor is determined by device properties and
111  // possibly "downcompiled" to a lower (compatible) compute architecture
112  // based on the NVRTC version
113  major = prop->major;
114  minor = prop->minor;
115  if (nvrtc_major <= 7 && prop->major > 5) { // 7 supports 2-5.x
116  major = 5;
117  minor = 0;
118  } else if (nvrtc_major <= 8 && prop->major > 6) { // 8 supports 2-6.x
119  major = 6;
120  minor = 0;
121  } else if (nvrtc_major <= 9 && prop->major >= 7) { // 9 supports 3-7.2
122  major = 7;
123  if (prop->major == 7 && prop->minor <= 2)
124  minor = prop->minor;
125  else
126  minor = 0;
127  } else if (nvrtc_major <= 10 && prop->major >= 7) { // 10 supports 3-7.5
128  major = 7;
129  if (prop->major == 7 && prop->minor <= 5)
130  minor = prop->minor;
131  else
132  minor = 0;
133  }
134 }
135 
136 // Compiles the specified kernel and stores the metadata required to run it
137 FusedKernelCUDA::FusedKernelCUDA(
138  int16_t device,
139  std::string name,
140  std::string code,
141  std::vector<TensorDesc> input_desc,
142  std::vector<TensorDesc> output_desc,
143  std::vector<PartitionDesc> chunk_desc,
144  std::vector<PartitionDesc> concat_desc,
145  bool has_random)
146  : FusedKernel(
147  std::move(name),
148  std::move(code),
149  std::move(input_desc),
150  std::move(output_desc),
151  std::move(chunk_desc),
152  std::move(concat_desc),
153  has_random),
154  device_(device) {
155  // Initializes driver's API context (if necessary)
156  CUcontext pctx = 0;
157  TORCH_CU_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
158  if (!pctx) {
159  std::unique_lock<std::mutex> cudaFreeMutexLock(
160  *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
161  cudaFree(0);
162  }
163 
164  // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
165  // properly in some scenarios
166  const auto prior_device = at::cuda::current_device();
167  at::cuda::set_device(device_);
168 
169  // Acquires device and NVRTC properties (for compile arch and occupancy
170  // calculations)
171  prop_ = at::cuda::getCurrentDeviceProperties();
172  int major, minor;
173  getMajorMinor(prop_, major, minor);
174 
175  // Creates the NVRTC program
176  nvrtcProgram program;
177  TORCH_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
178  &program, code_.c_str(), nullptr, 0, nullptr, nullptr));
179 
180  const std::string compute = "--gpu-architecture=compute_" +
181  std::to_string(major) + std::to_string(minor);
182  const std::vector<const char*> args = {
183  "--std=c++11", compute.c_str(), "-default-device"};
184  const auto result =
185  nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
186  if (result == NVRTC_ERROR_COMPILATION) {
187  size_t logsize;
188  nvrtcGetProgramLogSize(program, &logsize);
189  std::vector<char> log(logsize);
190  nvrtcGetProgramLog(program, log.data());
191  std::stringstream cu;
192  cu << log.data();
193  throw std::runtime_error(cu.str());
194  }
195  ResourceGuard holdProgram(
196  [&] { TORCH_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
197  TORCH_NVRTC_CHECK(result);
198  size_t ptx_size;
199  TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
200  ptx_.resize(ptx_size);
201  TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx_.data()));
202 
203  TORCH_CU_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
204  TORCH_CU_CHECK(
205  nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));
206 
207  // Computes max blocks
208  TORCH_CU_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
209  &maxBlocks_, function_, 128, 0));
210  maxBlocks_ *= prop_->multiProcessorCount;
211 
212  // Resets device (end of hacked at::DeviceGuard)
213  at::cuda::set_device(prior_device);
214 }
215 
// Integer ceiling division for non-negative a and positive b:
// returns the smallest integer >= a / b.
static int ceilDiv(const int a, const int b) {
  const int rounded_up = a + b - 1;
  return rounded_up / b;
}
219 
220 void FusedKernelCUDA::launch_raw(
221  const uint32_t numel,
222  std::vector<void*>& arguments) const {
223  at::cuda::CUDAGuard{device_};
224  // Hacked at::DeviceGuard (see note above)
225  const auto prior_device = at::cuda::current_device();
226  at::cuda::set_device(device_);
227 
228  const auto nBlocks = std::min(maxBlocks_, ceilDiv(numel, kBlockSize));
229 
230  // Adds random state to arguments if necessary
231  // Note: offset defined here so its lifetime extends to the launch
232  uint64_t offset;
233  if (has_random_) {
234  const auto rand_offset =
235  4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
236  auto gen = THCRandom_getGenerator(at::globalContext().getTHCState());
237  offset = gen->state.philox_seed_offset.fetch_add(rand_offset);
238  arguments.push_back(&gen->state.initial_seed);
239  arguments.push_back(&offset);
240  }
241 
242  // Launches kernel on current stream (device was set by executor)
243  auto stream = at::cuda::getCurrentCUDAStream();
244  TORCH_CU_CHECK(nvrtc().cuLaunchKernel(
245  function_,
246  nBlocks,
247  1,
248  1,
249  kBlockSize,
250  1,
251  1,
252  0,
253  stream,
254  arguments.data(),
255  nullptr));
256 
257  // Resets device (see at::DeviceGuard notes above)
258  at::cuda::set_device(prior_device);
259 }
260 
FusedKernelCUDA::~FusedKernelCUDA() {
  // Unloads the JIT-compiled module through the lazily-loaded driver-API
  // table, releasing its device resources.
  // NOTE(review): the CUresult is deliberately ignored (throwing from a
  // destructor would be worse), and there is no device guard here —
  // presumably the correct context is current when this runs; confirm, since
  // the module was loaded with device_ active in the constructor.
  nvrtc().cuModuleUnload(module_);
}
264 
265 static std::shared_ptr<FusedKernel> createFusionKernel(
266  int16_t device,
267  std::string name,
268  std::string code,
269  std::vector<TensorDesc> input_desc,
270  std::vector<TensorDesc> output_desc,
271  std::vector<PartitionDesc> chunk_desc,
272  std::vector<PartitionDesc> concat_desc,
273  bool has_random) {
274  return std::make_shared<FusedKernelCUDA>(
275  device,
276  std::move(name),
277  std::move(code),
278  std::move(input_desc),
279  std::move(output_desc),
280  std::move(chunk_desc),
281  std::move(concat_desc),
282  has_random);
283 }
284 
285 RegisterFusionBackend reg(at::DeviceType::CUDA, createFusionKernel);
286 
287 } // namespace cuda
288 } // namespace fuser
289 } // namespace jit
290 } // namespace torch
Cross-references: jit_type.h:17; at::cuda::CUDAGuard (CUDAGuard.h:20), a variant of DeviceGuard that is specialized for CUDA.