# Caffe2 - Python API
# A deep learning, cross platform ML framework
# __init__.py
import threading

import torch.cuda

from .utils import THNN_H_PATH, THCUNN_H_PATH, parse_header, load_backend


class Backends(object):
    """Lazy registry of neural-net backends, keyed by backend name,
    tensor type name, or tensor type object.

    Both attribute access and subscripting resolve a key to its backend,
    loading the backend's library bindings on first use.
    """

    def __init__(self):
        # Maps keys to Backend descriptors; populated at module import time.
        self.backends = {}

    def _resolve(self, name):
        # Look up the descriptor and force it to load its bindings.
        return self.backends[name].load()

    def __getattr__(self, name):
        return self._resolve(name)

    def __getitem__(self, name):
        return self._resolve(name)


class Backend(object):
    """Descriptor for a lazily initialized backend library.

    Holds everything needed to build the backend (symbol prefix, the name
    of the C extension library on ``torch._C``, parsed function headers,
    and optional mixin classes) and caches the loaded result.
    """

    def __init__(self, lib_prefix, lib_name, functions, mixins=tuple()):
        self.lib_prefix = lib_prefix          # prefix for C symbol names
        self.lib_name = lib_name              # attribute name on torch._C
        self.functions = functions            # parsed header declarations
        self.mixins = mixins                  # extra classes mixed into the backend
        self.backend = None                   # cached backend; None until load()
        self.loading_lock = threading.Lock()  # guards first-time loading

    def load(self):
        """Return the backend, loading it on first call (thread safe).

        Uses double-checked locking: loading can be slow, so several
        threads may observe ``backend is None`` concurrently; only the
        first to acquire the lock performs the load, and the others
        re-check under the lock and reuse the cached result.
        """
        if self.backend is not None:
            return self.backend
        with self.loading_lock:
            if self.backend is None:
                lib = getattr(torch._C, self.lib_name)
                self.backend = load_backend(self.lib_prefix, lib,
                                            self.functions, self.mixins)
        return self.backend


class THNNCudaBackendStateMixin(object):
    """Mixin giving CUDA backends access to the THC library state.

    NOTE(review): the ``class`` header was missing from this chunk (the
    property below appeared orphaned at top level); reconstructed from the
    reference to ``THNNCudaBackendStateMixin`` in the CUDA registration
    loop at the bottom of the file — confirm against upstream.
    """

    @property
    def library_state(self):
        # cdata handle for the global CUDA library state owned by torch.cuda.
        return torch.cuda._state_cdata


# Global registry; keys are backend names ('THNNFloatBackend'), tensor type
# names ('torch.FloatTensor'), and the tensor type objects themselves.
type2backend = Backends()

_thnn_headers = parse_header(THNN_H_PATH)
_thcunn_headers = parse_header(THCUNN_H_PATH)

# Register the CPU backend for each supported floating-point type under
# all three key forms.
for t in ['Float', 'Double']:
    backend = Backend(t, '_THNN', _thnn_headers)
    tensor_name = '{}Tensor'.format(t)
    type2backend.backends['THNN{}Backend'.format(t)] = backend
    type2backend.backends['torch.' + tensor_name] = backend
    type2backend.backends[getattr(torch, tensor_name)] = backend


# Register the CUDA backends; the empty string denotes the default (Float)
# type in the C naming scheme.
for t in ['Half', '', 'Double']:
    backend = Backend('Cuda' + t, '_THCUNN', _thcunn_headers,
                      (THNNCudaBackendStateMixin,))
    type2backend.backends['THNNCuda{}Backend'.format(t)] = backend
    py_name = t if t else 'Float'
    tensor_name = '{}Tensor'.format(py_name)
    type2backend.backends['torch.cuda.' + tensor_name] = backend
    type2backend.backends[getattr(torch.cuda, tensor_name)] = backend