Caffe2 - Python API
A deep learning, cross platform ML framework
functional.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from caffe2.python import core, workspace
7 from caffe2.proto import caffe2_pb2
8 from caffe2.python.onnx.workspace import Workspace
9 from collections import namedtuple
10 from six import string_types
11 
# Module-level shorthand for the OpSchema object reached through
# `workspace.C`; `_Functional` below calls `OpSchema.get(op_type)` to look up
# the schema of the operator it is about to run.
12 OpSchema = workspace.C.OpSchema
13 
14 
def namedtupledict(typename, field_names, *args, **kwargs):
    """Create a namedtuple class whose instances also accept string keys.

    Works like ``collections.namedtuple`` except that ``instance[key]``
    accepts an original field name (a string) in addition to the usual
    integer index or slice.  ``rename`` defaults to ``True`` here, so a
    field name that is not a valid Python identifier (e.g. ``"0"``) gets a
    positional attribute name from namedtuple but stays reachable under its
    original spelling through ``instance["0"]``.

    Args:
        typename: Name of the generated class.
        field_names: Field names, either as an iterable of strings or as one
            string with names separated by commas and/or whitespace (the
            same forms ``collections.namedtuple`` accepts).
        *args: Extra positional arguments forwarded to ``namedtuple``.
        **kwargs: Extra keyword arguments forwarded to ``namedtuple``;
            ``rename`` is set to ``True`` unless given explicitly.

    Returns:
        The generated namedtuple class with a string-aware ``__getitem__``.
        Indexing an instance with a string that is not one of
        ``field_names`` raises ``KeyError``.
    """
    # Normalise up front so that the name -> index map and namedtuple see
    # exactly the same sequence: a single string must be split into names
    # (not enumerated character by character) and a one-shot iterable must
    # not be exhausted before namedtuple gets to read it.
    if isinstance(field_names, string_types):
        field_names = field_names.replace(',', ' ').split()
    else:
        field_names = list(field_names)

    # If a name occurs more than once, the last position wins.
    field_names_map = {n: i for i, n in enumerate(field_names)}
    # Some output names are invalid python identifier, e.g. "0"
    kwargs.setdefault('rename', True)
    data = namedtuple(typename, field_names, *args, **kwargs)

    def getitem(self, key):
        if isinstance(key, string_types):
            key = field_names_map[key]
        # Call tuple.__getitem__ directly.  super(type(self), self) would,
        # for an instance of any subclass of `data`, resolve back to this
        # very function and recurse until the stack overflows.
        return tuple.__getitem__(self, key)

    data.__getitem__ = getitem
    return data
28 
29 
class _Functional(object):
    """Run any Caffe2 operator as a plain function call.

    ``instance.<op_type>(*inputs, **args)`` feeds ``inputs`` into a fresh
    ``Workspace``, runs the operator once and returns the fetched outputs.
    """

    def __getattribute__(self, op_type):
        """Return a function that executes the operator named ``op_type``.

        NOTE: this overrides ``__getattribute__`` (not ``__getattr__``), so
        every attribute lookup on the instance -- dunder names included --
        returns an operator function.  The lookup itself never fails;
        problems with ``op_type`` surface when the returned function is
        called.
        """

        def op_func(*inputs, **args):
            """Run ``op_type`` once on ``inputs`` and return its outputs.

            Args:
                *inputs: Values fed to the operator in order; they are
                    stored in the workspace as ``input_0``, ``input_1``, ...
                **args: Operator arguments forwarded to
                    ``core.CreateOperator``.  ``num_output`` is consumed
                    here (explicit number of outputs) and not forwarded.
                    ``device_option`` is forwarded and also selects the
                    device scope used to feed, run and fetch; it defaults
                    to CPU.

            Returns:
                A ``namedtupledict`` instance holding one fetched blob per
                output name.

            Raises:
                AttributeError: no schema was found for ``op_type``.
                ValueError: the number of inputs or the requested
                    ``num_output`` is rejected by the schema, or the number
                    of outputs cannot be determined because the schema's
                    ``max_output`` is unbounded and ``num_output`` was not
                    given.
            """
            input_prefix = 'input_'
            output_prefix = 'output_'

            schema = OpSchema.get(op_type)
            # NOTE(review): OpSchema.get presumably returns None for an
            # unregistered operator -- confirm.  Without this guard that
            # case shows up as "'NoneType' object has no attribute
            # 'max_input'"; the exception type is kept, the message fixed.
            if schema is None:
                raise AttributeError(
                    "Functional C2: no operator schema found for '{}'"
                    .format(op_type)
                )

            def get_name_list(prefix, num, max_num):
                # Never produce more names than the schema permits.
                return [prefix + str(x) for x in range(min(num, max_num))]

            # Validate against the number of inputs actually passed.  The
            # name list must not be clamped to max_input first: that would
            # hide "too many inputs" from the check below and fail later
            # with an IndexError while feeding the blobs.
            num_input = len(inputs)
            if (num_input > schema.max_input
                    or num_input < schema.min_input
                    or not schema.num_inputs_allowed(num_input)):
                raise ValueError(
                    ("Functional C2: Number of inputs not in "
                     "range: {} - {} or not allowed.")
                    .format(schema.min_input, schema.max_input)
                )
            input_names = [input_prefix + str(x) for x in range(num_input)]

            # Output count, in order of precedence:
            #   1. explicit `num_output` argument,
            #   2. the value the schema calculates from the input count,
            #   3. the schema's max_output (only if it is finite).
            output_names = []
            if 'num_output' in args:
                # Popped so that it is not forwarded as an operator arg.
                num_output = args.pop('num_output')
                if (num_output > schema.max_output
                        or num_output < schema.min_output
                        or not schema.num_outputs_allowed(num_output)
                        or not schema.num_inputs_outputs_allowed(
                            num_input, num_output)):
                    raise ValueError(
                        ("Functional C2: Number of output "
                         "not in range: {} - {} or not allowed")
                        .format(schema.min_output, schema.max_output)
                    )
                output_names = get_name_list(
                    output_prefix, num_output, schema.max_output
                )

            # -1 means the schema cannot calculate the output count.
            calculated = schema.CalculateOutput(num_input)
            if not output_names and calculated != -1:
                output_names = get_name_list(
                    output_prefix, calculated, schema.max_output
                )

            if not output_names:
                max_output = schema.max_output
                # For an op with max_output == inf
                # and no Output defined in schema
                # user should pass output_size explicitly
                if schema.inf == max_output:
                    raise ValueError(
                        "For operators with max_output == inf, "
                        "user should pass num_output explicitly."
                    )
                output_names = get_name_list(
                    output_prefix, max_output, max_output
                )

            # There could be input-output inplace enforcement; replace the
            # output names with input ones if such enforcements exist
            for i, input_name in enumerate(input_names):
                for j in range(len(output_names)):
                    if schema.inplace_enforced(i, j):
                        output_names[j] = input_name

            op = core.CreateOperator(
                op_type, input_names, output_names, **args
            )
            device_option = args.get(
                'device_option', core.DeviceOption(caffe2_pb2.CPU)
            )
            # Created only now, so a call rejected by the validation above
            # never sets up a workspace.
            ws = Workspace()
            with core.DeviceScope(device_option):
                for input_name, input_blob in zip(input_names, inputs):
                    ws.FeedBlob(input_name, input_blob)
                # RunOperator
                ws.RunOperatorOnce(op)
                output_values = [ws.FetchBlob(x) for x in output_names]
                return namedtupledict('output', output_names)(*output_values)

        return op_func
112 
113 
# Module-level instance used as the entry point: any attribute access on it
# (e.g. `Functional.<op_type>`) goes through `_Functional.__getattribute__`,
# which returns a function that runs the operator of that name.
114 Functional = _Functional()