3 from .utils
import THNN_H_PATH, THCUNN_H_PATH, parse_header, load_backend
11 def __getattr__(self, name):
14 def __getitem__(self, name):
20 def __init__(self, lib_prefix, lib_name, functions, mixins=tuple()):
def library_state(self):
    """Return the current value of ``torch.cuda._state_cdata``.

    The attribute is read on every call rather than cached, so callers always
    see whatever ``torch.cuda`` holds at that moment.
    NOTE(review): presumably this is the handle the CUDA backend functions
    receive as their library state -- confirm against the Backend class.
    """
    cuda_state = torch.cuda._state_cdata
    return cuda_state
# Parse both headers once at import time. Each result is later passed to
# Backend(...) as its ``functions`` argument: THNN first, then THCUNN.
_thnn_headers, _thcunn_headers = (
    parse_header(THNN_H_PATH),
    parse_header(THCUNN_H_PATH),
)
# Build one '_THNN' backend per CPU floating-point type and register the same
# object under three lookup keys: the backend name, the dotted tensor type
# name, and the tensor class itself.
for t in ('Float', 'Double'):
    backend = Backend(t, '_THNN', _thnn_headers)
    for key in (
        'THNN{}Backend'.format(t),
        'torch.{}Tensor'.format(t),
        getattr(torch, '{}Tensor'.format(t)),
    ):
        type2backend.backends[key] = backend
# Build one '_THCUNN' backend per CUDA type suffix, mixing in
# THNNCudaBackendStateMixin, and register it under three lookup keys: the
# backend name, the dotted tensor type name, and the tensor class itself.
for t in ('Half', '', 'Double'):
    backend = Backend(
        'Cuda' + t,
        '_THCUNN',
        _thcunn_headers,
        (THNNCudaBackendStateMixin,),
    )
    # The empty suffix has no Python tensor name of its own; it maps to
    # 'Float' (i.e. torch.cuda.FloatTensor) for the Python-side keys.
    py_name = t or 'Float'
    for key in (
        'THNNCuda{}Backend'.format(t),
        'torch.cuda.{}Tensor'.format(py_name),
        getattr(torch.cuda, '{}Tensor'.format(py_name)),
    ):
        type2backend.backends[key] = backend