3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
12 import caffe2.proto.caffe2_pb2
as caffe2_pb2
# Module-level logger named after this module. Its own threshold is set to
# INFO, so INFO-and-above records from this module are not filtered out at
# the logger level (attached handlers may still apply their own levels).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
23 def __init__(self, model, input_record, output_names_or_num, function,
24 name=
'functional', output_dtypes=
None, tags=
None, **kwargs):
"""Create a layer whose ops are emitted by a user-supplied ``function``.

The output schema is built from ``output_names_or_num``. Output types are
taken from ``output_dtypes`` when it is given. Otherwise they are inferred
by running ``function`` on a temporary net and calling
``workspace.InferShapesAndTypes`` on that net.

NOTE(review): this excerpt omits several lines of the method. Examples:
the start of the expression shown at "33 ...", the branch header before
"43 self.", the ``try:`` that pairs with "114 except", and several branch
bodies. The comments below describe only the lines that are shown.

Args:
    model: Forwarded to the base-class ``__init__``; ``model.net`` is the
        net on which new output records are created.
    input_record: Coerced with ``schema.as_record`` before use.
    output_names_or_num: Describes the outputs. An ``int`` creates a
        ``schema.RawTuple`` of that many fields. A list of names, or a
        single name (wrapped in a list), creates one ``np.void`` field per
        name. In a branch whose condition is not visible here, an object
        with ``clone(keep_blobs=...)`` is cloned and used directly as the
        output schema (presumably a ``schema.Field`` -- confirm).
    function: Callable invoked as
        ``function(net, input_record, output_schema, **kwargs)``.
    name: Forwarded to the base-class ``__init__``; defaults to
        ``'functional'``.
    output_dtypes: ``None`` to infer types, a single dtype applied to
        every output, or a list with exactly one dtype per output blob.
    tags: Forwarded to the base-class ``__init__``.
    **kwargs: Forwarded both to the base-class ``__init__`` and to
        ``function`` during type inference.

Raises:
    AssertionError: If ``output_dtypes`` is a list whose length differs
        from the number of output blobs.

A ``TypeError`` raised during type inference is caught and logged as a
warning rather than propagated (the matching ``try:`` is not shown in
this excerpt).
"""
# Normalize the input via schema.as_record before handing it to the base class.
27 input_record = schema.as_record(input_record)
29 super(Functional, self).__init__(model, name, input_record, tags=tags, **kwargs)
# True when output_names_or_num is a list, or an integer other than 1.
# NOTE(review): the name this expression is bound to is not shown here.
33 isinstance(output_names_or_num, list)
or 34 (isinstance(output_names_or_num, six.integer_types)
and 35 output_names_or_num != 1)
# Output records are created inside a name scope named after this layer.
38 with scope.NameScope(self.name, reset=
True):
# An int requests that many unnamed fields (schema.RawTuple) on model.net.
39 if isinstance(output_names_or_num, int):
40 struct_output_schema = schema.NewRecord(
41 model.net, schema.RawTuple(output_names_or_num))
# NOTE(review): the branch condition guarding this clone is not shown;
# presumably it handles an already-built schema field -- confirm.
43 self.
output_schema = output_names_or_num.clone(keep_blobs=
True)
# Names: a single name is wrapped in a list, then each name becomes a
# (name, np.void) field of a new record.
46 if not isinstance(output_names_or_num, list):
47 output_names_or_num = [output_names_or_num]
48 out_tuple = [(out, np.void)
for out
in output_names_or_num]
49 struct_output_schema = schema.NewRecord(
# Number of blobs in the newly created output record.
52 num_outputs = len(struct_output_schema.field_blobs())
# Explicit dtypes: a single dtype is repeated for every output, and a list
# must have exactly num_outputs entries. Each output scalar's type is then
# set directly from the corresponding dtype.
63 if output_dtypes
is not None:
64 if not isinstance(output_dtypes, list):
65 output_dtypes = [output_dtypes] * num_outputs
66 assert len(output_dtypes) == num_outputs
67 for dtype, scalar
in zip(output_dtypes,
68 self.output_schema.all_scalars()):
69 scalar.set_type(dtype)
# Type inference: build a temporary net and initialize the input record on it
# with schema.InitEmptyRecord (enforce_types=True). Then emit this layer's ops
# into it by calling `function`, and run workspace.InferShapesAndTypes on the
# net. The returned dicts are looked up by blob below.
75 type_net =
core.Net(
'_temp_type_and_shape_inference_net')
76 schema.InitEmptyRecord(type_net, input_record, enforce_types=
True)
78 function(type_net, self.input_record, self.
output_schema, **kwargs)
79 (shapes, types) = workspace.InferShapesAndTypes([type_net], {})
# Visit each output and look up its inferred shape and type.
# NOTE(review): the assignment of scalar_schema is not shown in this excerpt.
80 for i
in range(num_outputs):
83 blob = scalar_schema()
# An output missing from either inference result is handled separately
# (the branch body is not shown in this excerpt).
84 if blob
not in types
or blob
not in shapes:
# Two shape cases are recognized: an empty shape list, and a leading
# dimension of 0 (presumably the batch dimension). In the second case the
# leading dimension is dropped. Any other shape is presumably reported by
# the warning below in a branch whose header is not shown.
87 if shapes[blob] == []:
90 elif shapes[blob][0] == 0:
91 shape = tuple(shapes[blob][1:])
93 logger.warning(
"unexpeced shape: {}".format(shapes[blob]))
# Map the inferred TensorProto element type to a (numpy dtype, shape) pair.
101 if types[blob] == caffe2_pb2.TensorProto.DOUBLE:
102 dtype = (np.float64, shape)
103 elif types[blob] == caffe2_pb2.TensorProto.FLOAT:
104 dtype = (np.float32, shape)
105 elif types[blob] == caffe2_pb2.TensorProto.INT32:
106 dtype = (np.int32, shape)
107 elif types[blob] == caffe2_pb2.TensorProto.INT64:
108 dtype = (np.int64, shape)
109 elif types[blob] == caffe2_pb2.TensorProto.FLOAT16:
110 dtype = (np.float16, shape)
# Set the scalar's type only when a dtype was determined.
# NOTE(review): dtype's initial value before the chain above is not shown here.
112 if dtype
is not None:
113 scalar_schema.set_type(dtype)
# A TypeError during inference is logged as a warning instead of propagating.
114 except TypeError
as ex:
116 logger.warning(str(ex))
# Tail of the final warning call reporting that type inference had problems
# for this layer (the call's head is not shown in this excerpt).
120 "Type inference had problems for layer: {}".format(self.name))
122 def add_ops(self, net):