1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
7 from caffe2.proto
import caffe2_pb2
9 from collections
import namedtuple
10 from six
import string_types
# Alias for ``workspace.C.OpSchema``; operator schemas are fetched from it by
# operator type name via ``OpSchema.get``.
OpSchema = workspace.C.OpSchema
def namedtupledict(typename, field_names, *args, **kwargs):
    """Create a namedtuple class whose instances can also be indexed by name.

    The class is built with ``collections.namedtuple``. ``rename=True`` is
    set by default so that names which are not valid Python identifiers
    (e.g. "0") are renamed instead of raising.

    Indexing an instance with a string looks the key up among the
    *original* ``field_names`` and converts it to that field's position.
    Any other key (int, slice) is passed straight to ``tuple.__getitem__``.

    Args:
        typename: Name of the generated class.
        field_names: Iterable of field names. It is iterated twice (once
            for the by-name lookup map, once by ``namedtuple``), so pass a
            list or tuple rather than a one-shot iterator.
        *args, **kwargs: Extra arguments forwarded to
            ``collections.namedtuple``.

    Returns:
        The generated namedtuple class.

    Raises:
        KeyError: When an instance is indexed with a string that is not one
            of the original ``field_names``.
    """
    field_names_map = {n: i for i, n in enumerate(field_names)}
    # Some names (e.g. "0") are not valid identifiers; let namedtuple rename
    # them unless the caller explicitly asked otherwise.
    kwargs.setdefault('rename', True)
    data = namedtuple(typename, field_names, *args, **kwargs)

    def getitem(self, key):
        if isinstance(key, string_types):
            key = field_names_map[key]
        # Call tuple's implementation directly: super(type(self), self)
        # would recurse forever if the generated class were subclassed.
        return tuple.__getitem__(self, key)

    data.__getitem__ = getitem
    # Callers instantiate the result (namedtupledict(...)(*values)), so the
    # class must be returned rather than falling off the end with None.
    return data
# NOTE(review): this region looks like a damaged copy and the code below is
# left exactly as found. Restore it from a clean copy before changing logic.
# The damage visible here:
#   - every statement carries a leading line-number token (e.g. "31", "32");
#   - indentation is lost;
#   - several calls opened with "(" have no visible closing ")";
#   - the enclosing class header is not visible.
#
# __getattribute__: attribute access yields a function that runs one operator.
# The accessed attribute name is used as the operator type (op_type). It is
# passed to OpSchema.get and to core.CreateOperator.
# NOTE(review): no `return op_func` is visible in this view. Confirm the
# method returns op_func; otherwise attribute access would yield None.
31 def __getattribute__(self, op_type):
# op_func(*inputs, **args) does the following:
#   - feeds each positional input as a blob;
#   - runs a single operator of type op_type;
#   - returns namedtupledict('output', output_names) built from the fetched
#     output blobs.
# `args` is forwarded to core.CreateOperator. Two keys are handled specially:
#   - 'num_output', if present, sets the number of outputs and is popped
#     before the operator is created;
#   - 'device_option' is only read (not popped) and selects the DeviceScope,
#     defaulting to core.DeviceOption(caffe2_pb2.CPU).
# get_name_list(prefix, num, max_num) returns prefix + "0", prefix + "1", ...
# with min(num, max_num) entries. Inputs are named "input_<k>" and outputs
# "output_<k>".
32 def op_func(*inputs, **args):
34 schema = OpSchema.get(op_type)
35 input_prefix =
'input_' 36 output_prefix =
'output_' 38 def get_name_list(prefix, num, max_num):
39 return [prefix + str(x)
for x
in range(min(num, max_num))]
# Input names are capped at schema.max_input entries by get_name_list.
41 input_names, output_names = [], []
42 input_names = get_name_list(
43 input_prefix, len(inputs), schema.max_input
# Validate the input count against the schema bounds and
# schema.num_inputs_allowed.
# NOTE(review): input_names was clipped to schema.max_input above, so
# `num_input > schema.max_input` can never be true. Extra inputs would
# instead surface as an IndexError at input_names[i] in the feed loop below.
47 num_input = len(input_names)
48 if num_input > schema.max_input
or num_input < \
49 schema.min_input
or not schema.num_inputs_allowed(num_input):
# NOTE(review): no `raise` is visible here. As shown, this message is
# formatted and discarded, so the check has no effect. Confirm a raise
# statement was not lost.
51 "Functional C2: Number of inputs not in \ 52 range: {} - {} or not allowed." 53 .format(schema.min_input, schema.max_input)
# An explicit 'num_output' is validated against the schema's output bounds
# and input/output combination rules. It then names the outputs and is
# removed from args.
56 if 'num_output' in args:
57 num_output = args[
'num_output']
58 if num_output > schema.max_output
or \
59 num_output < schema.min_output
or \
60 not schema.num_outputs_allowed(num_output)
or \
61 not schema.num_inputs_outputs_allowed(num_input,
# NOTE(review): as with the input check, no `raise` is visible for this
# message. Confirm.
64 "Functional C2: Number of output \ 65 not in range: {} - {} or not allowed" 66 .format(schema.min_output, schema.max_output)
68 output_names = get_name_list(
69 output_prefix, num_output, schema.max_output
71 args.pop(
'num_output')
# Without an explicit num_output, use the schema's CalculateOutput result
# for this input count, unless it returns -1.
72 calculated = schema.CalculateOutput(num_input)
73 if not output_names
and calculated != -1:
74 output_names = get_name_list(
75 output_prefix, calculated, schema.max_output
# Last-resort fallback: name schema.max_output outputs.
# NOTE(review): as shown, this fallback is not guarded by an
# `if not output_names` check, so it would overwrite names computed above.
# Presumably it should only run when output_names is still empty. Verify.
79 max_output = schema.max_output
83 if schema.inf == max_output:
# NOTE(review): no `raise` is visible for this message either. Confirm.
85 "For operators with max_output == inf,\ 86 user should pass num_output explicity." 88 output_names = get_name_list(
89 output_prefix, max_output, max_output
# Wherever schema.inplace_enforced(i, j) holds for input i and output j,
# output j reuses input i's blob name.
94 for i
in range(len(input_names)):
95 for j
in range(len(output_names)):
96 if schema.inplace_enforced(i, j):
97 output_names[j] = input_names[i]
99 op = core.CreateOperator(
100 op_type, input_names, output_names, **args
# Inputs are fed, the operator is run once, and outputs are fetched inside
# the chosen device scope.
102 device_option = args.get(
'device_option', core.DeviceOption(caffe2_pb2.CPU))
103 with core.DeviceScope(device_option):
104 for i, input_blob
in enumerate(inputs):
# NOTE(review): `ws` is not defined anywhere in this view. Confirm where the
# workspace object used for FeedBlob/RunOperatorOnce/FetchBlob comes from.
105 ws.FeedBlob(input_names[i], input_blob)
107 ws.RunOperatorOnce(op)
108 output_values = [ws.FetchBlob(x)
for x
in output_names]
# Outputs come back as a tuple whose fields are the output blob names. It can
# also be indexed by those names as strings.
109 return namedtupledict(
'output', output_names)(*output_values)