1 #include <torch/csrc/jit/fuser/cuda/fused_kernel.h> 2 #include <torch/csrc/jit/fuser/compiler.h> 4 #include <ATen/cuda/CUDAContext.h> 6 #include <c10/cuda/CUDAGuard.h> 7 #include <torch/csrc/jit/fuser/cpu/dynamic_library.h> 8 #include <torch/csrc/jit/fuser/cuda/thnvrtc.h> 9 #include <torch/csrc/jit/resource_guard.h> 12 #include <THC/THCGenerator.hpp> 13 #include <THC/THCTensorRandom.h> 14 THCGenerator* THCRandom_getGenerator(THCState* state);
16 #include <cuda_runtime.h> 42 void checkCUDAVersion(
const cudaDeviceProp& prop) {
43 if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
44 (prop.major >= 7 && CUDA_VERSION < 9000)) {
45 std::stringstream err_string;
47 <<
"In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: " 48 << CUDA_VERSION <<
" for the current GPU device " << prop.name
49 <<
" with device capability " << prop.major <<
"." << prop.minor;
50 throw std::runtime_error(err_string.str());
54 #ifdef USE_DIRECT_NVRTC 55 std::pair<std::unique_ptr<cpu::DynamicLibrary>,
THNVRTC*> loadNVRTC() {
56 return std::make_pair(
nullptr, torch_load_nvrtc());
59 std::pair<std::unique_ptr<cpu::DynamicLibrary>,
THNVRTC*> loadNVRTC() {
60 std::string path = cpu::DynamicLibrary::directoryOf((
void*)checkCUDAVersion);
62 std::string libthnvrtc = path +
"/libthnvrtc.dylib";
64 std::string libthnvrtc = path +
"/libthnvrtc.so";
66 std::unique_ptr<cpu::DynamicLibrary> libnvrtc_stub(
67 new cpu::DynamicLibrary(libthnvrtc.c_str()));
68 auto fn = (
THNVRTC * (*)()) libnvrtc_stub->sym(
"torch_load_nvrtc");
69 return std::make_pair(std::move(libnvrtc_stub), fn());
75 static auto handle = loadNVRTC();
76 return *handle.second;
80 static inline void nvrtcCheck(nvrtcResult result,
const char* file,
int line) {
81 if (result != NVRTC_SUCCESS) {
83 ss << file <<
":" << line <<
": " << nvrtc().nvrtcGetErrorString(result);
84 throw std::runtime_error(ss.str());
87 #define TORCH_NVRTC_CHECK(result) nvrtcCheck(result, __FILE__, __LINE__); 89 static inline void cuCheck(CUresult result,
const char* file,
int line) {
90 if (result != CUDA_SUCCESS) {
92 nvrtc().cuGetErrorString(result, &str);
94 ss << file <<
":" << line <<
": " << str;
95 throw std::runtime_error(ss.str());
98 #define TORCH_CU_CHECK(result) cuCheck(result, __FILE__, __LINE__); 100 static void getMajorMinor(
101 const cudaDeviceProp*
const prop,
104 int nvrtc_major, nvrtc_minor;
105 TORCH_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
108 AT_ASSERT(nvrtc_major >= 6);
115 if (nvrtc_major <= 7 && prop->major > 5) {
118 }
else if (nvrtc_major <= 8 && prop->major > 6) {
121 }
else if (nvrtc_major <= 9 && prop->major >= 7) {
123 if (prop->major == 7 && prop->minor <= 2)
127 }
else if (nvrtc_major <= 10 && prop->major >= 7) {
129 if (prop->major == 7 && prop->minor <= 5)
137 FusedKernelCUDA::FusedKernelCUDA(
141 std::vector<TensorDesc> input_desc,
142 std::vector<TensorDesc> output_desc,
143 std::vector<PartitionDesc> chunk_desc,
144 std::vector<PartitionDesc> concat_desc,
149 std::move(input_desc),
150 std::move(output_desc),
151 std::move(chunk_desc),
152 std::move(concat_desc),
157 TORCH_CU_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
159 std::unique_lock<std::mutex> cudaFreeMutexLock(
160 *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
166 const auto prior_device = at::cuda::current_device();
167 at::cuda::set_device(device_);
171 prop_ = at::cuda::getCurrentDeviceProperties();
173 getMajorMinor(prop_, major, minor);
176 nvrtcProgram program;
177 TORCH_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
178 &program, code_.c_str(),
nullptr, 0,
nullptr,
nullptr));
180 const std::string compute =
"--gpu-architecture=compute_" +
181 std::to_string(major) + std::to_string(minor);
182 const std::vector<const char*> args = {
183 "--std=c++11", compute.c_str(),
"-default-device"};
185 nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
186 if (result == NVRTC_ERROR_COMPILATION) {
188 nvrtcGetProgramLogSize(program, &logsize);
189 std::vector<char> log(logsize);
190 nvrtcGetProgramLog(program, log.data());
191 std::stringstream cu;
193 throw std::runtime_error(cu.str());
195 ResourceGuard holdProgram(
196 [&] { TORCH_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
197 TORCH_NVRTC_CHECK(result);
199 TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
200 ptx_.resize(ptx_size);
201 TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx_.data()));
203 TORCH_CU_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
205 nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));
208 TORCH_CU_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
209 &maxBlocks_, function_, 128, 0));
210 maxBlocks_ *= prop_->multiProcessorCount;
213 at::cuda::set_device(prior_device);
// Integer ceiling of a / b: the smallest k such that k * b >= a (for
// positive b). Computed with the classic add-(b-1) round-up trick.
static int ceilDiv(const int a, const int b) {
  const int biased = a + b - 1;
  return biased / b;
}
220 void FusedKernelCUDA::launch_raw(
221 const uint32_t numel,
222 std::vector<void*>& arguments)
const {
225 const auto prior_device = at::cuda::current_device();
226 at::cuda::set_device(device_);
228 const auto nBlocks = std::min(maxBlocks_, ceilDiv(numel, kBlockSize));
234 const auto rand_offset =
235 4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
236 auto gen = THCRandom_getGenerator(at::globalContext().getTHCState());
237 offset = gen->state.philox_seed_offset.fetch_add(rand_offset);
238 arguments.push_back(&gen->state.initial_seed);
239 arguments.push_back(&offset);
243 auto stream = at::cuda::getCurrentCUDAStream();
244 TORCH_CU_CHECK(nvrtc().cuLaunchKernel(
258 at::cuda::set_device(prior_device);
261 FusedKernelCUDA::~FusedKernelCUDA() {
262 nvrtc().cuModuleUnload(module_);
265 static std::shared_ptr<FusedKernel> createFusionKernel(
269 std::vector<TensorDesc> input_desc,
270 std::vector<TensorDesc> output_desc,
271 std::vector<PartitionDesc> chunk_desc,
272 std::vector<PartitionDesc> concat_desc,
274 return std::make_shared<FusedKernelCUDA>(
278 std::move(input_desc),
279 std::move(output_desc),
280 std::move(chunk_desc),
281 std::move(concat_desc),
285 RegisterFusionBackend reg(at::DeviceType::CUDA, createFusionKernel);
// Note: at::cuda::CUDAGuard is a variant of DeviceGuard that is specialized
// for CUDA.