Caffe2 - Python API
A deep learning, cross platform ML framework
utils.py
1 import os
2 import itertools
3 import importlib
4 
try:
    # When compiling a cffi extension, this import works. When compiling
    # torch itself it does not, because the parent package cannot be
    # imported yet. That is fine: the header paths are not needed then.
    from .._utils_internal import get_file_path

    THNN_H_PATH = get_file_path('torch', 'include', 'THNN', 'generic', 'THNN.h')
    THCUNN_H_PATH = get_file_path('torch', 'include', 'THCUNN', 'generic', 'THCUNN.h')
except Exception:
    # Deliberately best-effort (see above). NOTE: on failure THNN_H_PATH and
    # THCUNN_H_PATH are left undefined at module level.
    pass
16 
17 
def _unpickle_backend(backend_name):
    """Resolve a backend by class name during unpickling.

    Used as the reconstructor returned by ``THNNBackendBase.__reduce__``:
    the backend is looked up in ``torch._thnn.type2backend`` rather than
    rebuilt from pickled state. The import is deferred to call time.
    """
    import torch._thnn as thnn_module
    registry = thnn_module.type2backend
    return registry[backend_name]
21 
22 
class THNNBackendBase(object):
    """Base class for dynamically created ``THNN<type>Backend`` objects.

    Library functions are registered by name with ``register_method`` and
    then retrieved as ordinary attributes (``backend.<name>``).
    """

    def __init__(self):
        # Registered function name -> callable passed to register_method.
        self.methods = {}

    def __getattr__(self, name):
        """Return the callable registered under ``name``.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: for dunder names (``__deepcopy__`` etc.) so that
                protocol probes like ``getattr(obj, '__deepcopy__', None)``
                (used by ``copy.deepcopy``) and ``hasattr`` work normally.
            NotImplementedError: if no method named ``name`` was registered.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Read via __dict__: if __init__ never ran (instance created with
        # __new__), ``self.methods`` would re-enter __getattr__ and recurse
        # until RecursionError.
        method = self.__dict__.get('methods', {}).get(name, None)
        if method is None:
            raise NotImplementedError(name)
        return method

    def register_method(self, name, ctypes_fn):
        """Expose ``ctypes_fn`` as attribute ``name``, replacing any earlier one."""
        self.methods[name] = ctypes_fn

    @property
    def library_state(self):
        """Always 0 for this base class."""
        return 0

    def __reduce__(self):
        # Pickle by class name only; _unpickle_backend looks the backend up
        # again instead of restoring the registered methods.
        return (_unpickle_backend, (type(self).__name__,))
43 
44 
class Function(object):
    """A parsed C function declaration: a name and its ordered arguments."""

    def __init__(self, name):
        self.name = name
        # Argument instances in declaration order.
        self.arguments = []

    def add_argument(self, arg):
        """Append ``arg``, which must be an ``Argument`` (checked by assert)."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
        # Render as "name(arg1, arg2, ...)".
        rendered_args = [arg.__repr__() for arg in self.arguments]
        return ''.join([self.name, '(', ', '.join(rendered_args), ')'])
57 
58 
class Argument(object):
    """One parameter of a parsed C function declaration."""

    def __init__(self, _type, name, is_optional):
        # C type string, parameter name, and whether it was marked optional.
        self.type, self.name, self.is_optional = _type, name, is_optional

    def __repr__(self):
        # Render as "<type> <name>".
        return ' '.join((self.type, self.name))
68 
69 
def parse_header(path):
    """Parse a THNN/THCUNN generic header into ``Function`` descriptions.

    Function declarations:
        A fragment starting with ``TH_API void THNN_`` or
        ``THC_API void THNN_`` opens a new ``Function``. Both the generic
        ``THNN_(Name)(`` and the plain ``THNN_Name(`` spellings are accepted.

    Arguments:
        Every other non-empty comma-separated fragment is parsed as a
        ``<type> <name>`` argument of the most recently opened function.
        Leading ``*`` on the name are moved onto the type.

    Comments and skipped lines:
        A ``// ...`` comment containing ``[OPTIONAL]`` marks every argument
        on its line as optional. Empty lines and lines beginning with ``#``
        are skipped.

    Args:
        path: path of the header file to read.

    Returns:
        list of ``Function`` objects in declaration order.

    Raises:
        IndexError: if an argument appears before any function declaration.
        ValueError: if an argument fragment consists of a single token.
    """
    decl_prefixes = ('TH_API void THNN_', 'THC_API void THNN_')

    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # Break the header into (code, comment) fragments: one per
    # comma-separated piece of code, each paired with its line's // comment.
    fragments = []
    for line in raw_lines:
        # Skip empty lines and preprocessor directives.
        if not line or line.startswith('#'):
            continue
        code, _, comment = line.partition('//')
        # Drop the ");" closing the last argument and any trailing comma.
        code = code.strip().rstrip(');').rstrip(',')
        comment = comment.strip()
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    generic_functions = []
    for code, comment in fragments:
        prefix = next((p for p in decl_prefixes if code.startswith(p)), None)
        if prefix is not None:
            # Slice the prefix off by length. str.lstrip() strips a *set of
            # characters*, which would also eat leading letters of the name
            # (e.g. 'THNN_Abs_updateOutput(' -> 'bs_updateOutput').
            fn_name = code[len(prefix):]
            if fn_name[:1] == '(' and fn_name[-2:-1] == ')':
                # Generic form "(Name)(": drop the surrounding parentheses.
                fn_name = fn_name[1:-2]
            else:
                # Plain form "Name(": drop the opening parenthesis.
                fn_name = fn_name[:-1]
            generic_functions.append(Function(fn_name))
        else:
            # The last whitespace-separated token is the name, so multi-word
            # types such as "const char *s" are accepted too.
            arg_type, arg_name = code.rsplit(None, 1)
            # Move all leading '*' from the name onto the type:
            # "THTensor *input" -> ("THTensor*", "input").
            bare_name = arg_name.lstrip('*')
            arg_type += '*' * (len(arg_name) - len(bare_name))
            generic_functions[-1].add_argument(
                Argument(arg_type, bare_name, '[OPTIONAL]' in comment))
    return generic_functions
118 
119 
def load_backend(t, lib, generic_functions, mixins=tuple()):
    """Build a backend object exposing ``lib``'s functions for type ``t``.

    What it builds:
        A class named ``THNN<t>Backend`` is created with bases
        ``mixins + (THNNBackendBase,)`` and instantiated.

    What it registers:
        For every entry of ``generic_functions``, the attribute
        ``<t><entry.name>`` is fetched from ``lib``. It is registered on the
        instance under ``entry.name``.

    Args:
        t: type prefix used in the class name and in the symbol names.
        lib: object whose attributes are the library functions.
        generic_functions: iterable of objects with a ``name`` attribute
            (``Function`` instances).
        mixins: tuple of extra base classes placed before THNNBackendBase.

    Returns:
        the populated backend instance.

    Raises:
        AttributeError: if ``lib`` lacks one of the expected symbols.
    """
    bases = mixins + (THNNBackendBase,)
    backend_cls = type('THNN{}Backend'.format(t), bases, {})
    backend = backend_cls()
    for func in generic_functions:
        symbol = getattr(lib, '{}{}'.format(t, func.name))
        backend.register_method(func.name, symbol)
    return backend
Module caffe2.python.layers.split.