7 from torch.version
import cuda
8 from contextlib
import contextmanager
9 from subprocess
import Popen, PIPE
# cuDNN runtime version; stays None until it is filled in from
# lib.cudnnGetVersion() once the library has been loaded.
__cudnn_version = None

# When True, the cudnn flags may be assigned directly. disable_global_flags()
# turns this off, after which assignments raise and the flags() context
# manager has to be used instead.
__allow_nonbracketed_mutation_flag = True


def find_cudnn_windows_lib():
    """Locate a cuDNN DLL with the Windows ``where`` command and load it.

    Returns:
        The ``ctypes`` handle for the first ``cudnn64*.dll`` that ``where``
        reports, or ``None`` when ``where`` prints nothing (callers compare
        the loaded library against ``None``).
    """
    proc = Popen(['where', 'cudnn64*.dll'], stdout=PIPE, stderr=PIPE, stdin=PIPE)
    out, _ = proc.communicate()
    matches = out.decode().strip().splitlines()
    if not matches:
        # Nothing was found; do not hand an empty name to LoadLibrary.
        return None
    # ``where`` may list several hits, one per line; only the first is used.
    cudnn_lib_name = os.path.basename(matches[0].strip())
    # LoadLibrary receives the bare library name without its extension.
    cudnn_lib = os.path.splitext(cudnn_lib_name)[0]
    return ctypes.cdll.LoadLibrary(str(cudnn_lib))
45 global lib, __cudnn_version
47 if sys.platform ==
"win32":
48 lib = find_cudnn_windows_lib()
50 lib = ctypes.cdll.LoadLibrary(
None)
51 if hasattr(lib,
'cudnnGetErrorString'):
52 lib.cudnnGetErrorString.restype = ctypes.c_char_p
53 __cudnn_version = lib.cudnnGetVersion()
54 compile_version = torch._C._cudnn_version()
56 runtime_major = __cudnn_version // 1000
57 runtime_minor = (__cudnn_version % 1000) // 100
58 compile_major = compile_version // 1000
59 compile_minor = (compile_version % 1000) // 100
62 if runtime_major != compile_major:
63 cudnn_compatible =
False 64 elif runtime_major < 7:
65 cudnn_compatible = runtime_minor == compile_minor
67 cudnn_compatible = runtime_minor >= compile_minor
68 if not cudnn_compatible:
70 'cuDNN version incompatibility: PyTorch was compiled against {} ' 71 'but linked against {}'.format(compile_version, __cudnn_version))
78 if _libcudnn()
is None:
80 return __cudnn_version
# Tensor type strings (as returned by ``tensor.type()``) that are accepted
# when a tensor is checked against this set.
CUDNN_TENSOR_TYPES = {
    'torch.cuda.HalfTensor',
    'torch.cuda.FloatTensor',
    'torch.cuda.DoubleTensor',
}
91 r"""Returns a bool indicating if CUDNN is currently available.""" 92 return torch._C.has_cudnn
95 def is_acceptable(tensor):
96 if not torch._C._get_cudnn_enabled():
98 if tensor.type()
not in CUDNN_TENSOR_TYPES:
100 if not is_available():
102 "PyTorch was compiled without cuDNN support. To use cuDNN, rebuild " 103 "PyTorch making sure the library is visible to the build system.")
105 if _libcudnn()
is None:
106 warnings.warn(
'cuDNN library not found. Check your {libpath}'.format(
108 'darwin':
'DYLD_LIBRARY_PATH',
110 }.get(sys.platform,
'LD_LIBRARY_PATH')))
# Integer codes passed to the cuDNN library through ctypes. They are grouped
# by name prefix; the values are reproduced unchanged.
# NOTE(review): presumably these mirror the enums in the cuDNN headers --
# confirm against the cuDNN API reference before changing any value.
# NOTE(review): other codes used elsewhere in this file (for example
# CUDNN_DATA_HALF and CUDNN_DATA_FLOAT) are not defined in this span.

# Data type code.
CUDNN_DATA_DOUBLE = 1

# Tensor layout codes.
CUDNN_TENSOR_NCHW = 0
CUDNN_TENSOR_NHWC = 1

# RNN input mode code.
CUDNN_LINEAR_INPUT = 0

# RNN algorithm codes.
CUDNN_RNN_ALGO_STANDARD = 0
CUDNN_RNN_ALGO_PERSIST_STATIC = 1
CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2

# Math type codes (passed to cudnnSetRNNMatrixMathType).
CUDNN_DEFAULT_MATH = 0
CUDNN_TENSOR_OP_MATH = 1
142 def set_flags(_enabled, _benchmark, _deterministic, _verbose):
143 global benchmark, deterministic, verbose
144 orig_flags = (torch._C._get_cudnn_enabled(),
145 torch._C._get_cudnn_benchmark(),
146 torch._C._get_cudnn_deterministic(),
149 torch._C._set_cudnn_enabled(_enabled)
150 torch._C._set_cudnn_benchmark(_benchmark)
151 torch._C._set_cudnn_deterministic(_deterministic)
155 def disable_global_flags():
156 global __allow_nonbracketed_mutation_flag
157 __allow_nonbracketed_mutation_flag =
False 161 return not __allow_nonbracketed_mutation_flag
165 def __allow_nonbracketed_mutation():
166 global __allow_nonbracketed_mutation_flag
167 old = __allow_nonbracketed_mutation_flag
168 __allow_nonbracketed_mutation_flag =
True 172 __allow_nonbracketed_mutation_flag = old
176 def flags(enabled=False, benchmark=False, deterministic=False, verbose=False):
177 with __allow_nonbracketed_mutation():
178 orig_flags = set_flags(enabled, benchmark, deterministic, verbose)
183 with __allow_nonbracketed_mutation():
184 set_flags(orig_flags[0], orig_flags[1], orig_flags[2], orig_flags[3])
189 ptr = ctypes.c_void_p()
190 check_error(lib.cudnnCreate(ctypes.byref(ptr)))
194 check_error(lib.cudnnDestroy(self))
198 def __init__(self, status):
200 msg =
'{}: {}'.format(status, get_error_string(status))
201 super(CuDNNError, self).__init__(msg)
206 ptr = ctypes.c_void_p()
207 check_error(lib.cudnnCreateTensorDescriptor(ctypes.byref(ptr)))
211 check_error(lib.cudnnDestroyTensorDescriptor(self.
_as_parameter_))
214 def set(self, tensor):
215 self.
_type = tensor.type()
216 self.
_size = tensor.size()
218 check_error(lib.cudnnSetTensorNdDescriptor(
219 self, _typemap[tensor.type()], tensor.dim(),
220 int_array(tensor.size()), int_array(tensor.stride())))
227 def __init__(self, N):
228 self.
ptrs = (ctypes.c_void_p * N)()
230 ptr = ctypes.byref(self.
ptrs, i * ctypes.sizeof(ctypes.c_void_p))
231 check_error(lib.cudnnCreateTensorDescriptor(ptr))
235 for ptr
in self.
ptrs:
236 check_error(lib.cudnnDestroyTensorDescriptor(ctypes.c_void_p(ptr)))
def __getitem__(self, key):
    """Return the entry of ``self.ptrs`` at ``key`` wrapped in a ``ctypes.c_void_p``."""
    # Indexing the ctypes pointer array yields a plain value, so it is
    # re-wrapped before being handed back.
    raw_ptr = self.ptrs[key]
    return ctypes.c_void_p(raw_ptr)
241 def set_all(self, tensor):
242 _type = _typemap[tensor.type()]
244 _size = int_array(tensor.size())
245 _stride = int_array(tensor.stride())
246 for ptr
in self.
ptrs:
247 check_error(lib.cudnnSetTensorNdDescriptor(
248 ctypes.c_void_p(ptr), _type, _ndim, _size, _stride))
250 def set_raw(self, i, _type, _ndim, _size, _stride):
252 check_error(lib.cudnnSetTensorNdDescriptor(
253 ctypes.c_void_p(ptr), _type, _ndim, _size, _stride))
258 ptr = ctypes.c_void_p()
259 check_error(lib.cudnnCreateFilterDescriptor(ctypes.byref(ptr)))
263 check_error(lib.cudnnDestroyFilterDescriptor(self.
_as_parameter_))
def set(self, weight):
    """Configure this filter descriptor from ``weight``.

    Stores ``weight.size()`` on ``self._size``, then calls
    ``cudnnSetFilterNdDescriptor`` with the data type looked up in
    ``_typemap``, the ``CUDNN_TENSOR_NCHW`` code, the number of dimensions
    and the sizes. The returned status is passed to ``check_error``.

    Raises:
        KeyError: if ``weight.type()`` has no entry in ``_typemap``.
    """
    size = weight.size()
    self._size = size
    datatype = _typemap[weight.type()]
    check_error(lib.cudnnSetFilterNdDescriptor(
        self, datatype, CUDNN_TENSOR_NCHW, weight.ndimension(),
        int_array(size)))
274 return tuple(self.
_size)
278 def __init__(self, handle, dropout, seed):
279 ptr = ctypes.c_void_p()
280 check_error(lib.cudnnCreateDropoutDescriptor(ctypes.byref(ptr)))
287 self.
_set(dropout, seed)
289 def set_dropout(self, dropout, seed):
291 self.
_set(dropout, seed)
293 def _set(self, dropout, seed):
294 if self.
state is None and dropout > 0:
295 dropout_states_size = ctypes.c_long()
296 check_error(lib.cudnnDropoutGetStatesSize(
298 ctypes.byref(dropout_states_size)))
299 self.
state = torch.cuda.ByteTensor(dropout_states_size.value)
300 state_ptr = self.state.data_ptr()
301 state_size = self.state.size(0)
306 check_error(lib.cudnnSetDropoutDescriptor(
309 ctypes.c_float(dropout),
310 ctypes.c_void_p(state_ptr),
311 ctypes.c_size_t(state_size),
312 ctypes.c_ulonglong(seed),
318 check_error(lib.cudnnDestroyDropoutDescriptor(self))
322 def __init__(self, handle, hidden_size, num_layers, dropout_desc, input_mode,
323 bidirectional, mode, datatype):
324 ptr = ctypes.c_void_p()
325 check_error(lib.cudnnCreateRNNDescriptor(ctypes.byref(ptr)))
327 if version() >= 6000:
328 check_error(lib.cudnnSetRNNDescriptor_v6(
337 CUDNN_RNN_ALGO_STANDARD,
340 if version() >= 7000
and int(cuda[0]) >= 9
and (
342 lib.cudnnSetRNNMatrixMathType(self, CUDNN_DEFAULT_MATH)
343 if datatype == CUDNN_DATA_HALF:
344 lib.cudnnSetRNNMatrixMathType(self, CUDNN_TENSOR_OP_MATH)
346 check_error(lib.cudnnSetRNNDescriptor(
358 check_error(lib.cudnnDestroyRNNDescriptor(self))
361 def check_error(status):
def get_error_string(status):
    """Return what the library's ``cudnnGetErrorString`` reports for ``status``.

    NOTE(review): the loader sets this function's ``restype`` to
    ``ctypes.c_char_p`` when the symbol exists, so the result is presumably
    a byte string -- confirm before formatting it into messages.
    """
    return lib.cudnnGetErrorString(status)
371 if _libcudnn()
is None:
372 raise RuntimeError(
'cuDNN not available')
374 handle = _handles.get(current_device,
None)
377 _handles[current_device] = handle
382 'torch.cuda.HalfTensor': CUDNN_DATA_HALF,
383 'torch.cuda.FloatTensor': CUDNN_DATA_FLOAT,
384 'torch.cuda.DoubleTensor': CUDNN_DATA_DOUBLE,
390 CUDNN_DATA_DOUBLE: 8,
395 if isinstance(tensor, torch.cuda.HalfTensor):
396 return ctypes.c_float
397 elif isinstance(tensor, torch.cuda.FloatTensor):
398 return ctypes.c_float
399 elif isinstance(tensor, torch.cuda.DoubleTensor):
400 return ctypes.c_double
402 raise ValueError(
"unknown type '{}'".format(type(tensor)))
406 array_type = ctypes.c_int * len(itr)
407 return array_type(*itr)
410 def descriptor(tensor, N=None):
411 padded_size = tensor.size() + ((1,) * (5 - tensor.dim()))
412 tensor = tensor.view(padded_size)
415 descriptor.set_all(tensor)
418 descriptor.set(tensor)
422 def descriptor_sequence(tensor, batch_sizes):
424 _type = _typemap[tensor.type()]
426 dim_pad = (1,) * (5 - tensor.dim())
427 _size = int_array(tensor.size() + dim_pad)
428 _stride = int_array(tensor.stride() + dim_pad)
429 for i, batch_size
in enumerate(batch_sizes):
430 _size[0] = batch_size
431 descriptors.set_raw(i, _type, _ndim, _size, _stride)
def add_tensor(*args):
    """Call ``cudnnAddTensor`` with ``args`` and pass its status to ``check_error``."""
    status = lib.cudnnAddTensor(*args)
    check_error(status)
444 def __init__(self, getter, setter):
448 def __get__(self, obj, objtype):
451 def __set__(self, obj, val):
452 if not flags_frozen():
455 raise RuntimeError(
"not allowed to set torch.backends.cudnn flags " 456 "after disable_global_flags; please use flags() context manager instead")
460 def __init__(self, m, name):
461 super(CudnnModule, self).__init__(name)
def __getattr__(self, attr):
    """Resolve ``attr`` on ``self.m`` when normal lookup on this object fails.

    NOTE(review): ``self.m`` is presumably the object passed as ``m`` to
    ``__init__`` -- confirm against the constructor.
    """
    return self.m.__getattribute__(attr)
467 enabled =
ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
468 deterministic =
ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
469 benchmark =
ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
473 sys.modules[__name__] =
CudnnModule(sys.modules[__name__], __name__)
def _set(self, dropout, seed)
def get_device_capability(device=None)