4 """Caffe2 Protobuf to ONNX converter 6 To run this, you will need to have Caffe2 installed as well. 9 from __future__
import absolute_import
10 from __future__
import division
11 from __future__
import print_function
12 from __future__
import unicode_literals
20 from caffe2.proto
import caffe2_legacy_pb2
22 from onnx
import (defs, checker, helper, numpy_helper, mapping,
23 ModelProto, GraphProto, NodeProto, AttributeProto, TensorProto, OperatorSetIdProto)
24 from onnx.helper
import make_tensor, make_tensor_value_info, make_attribute, make_model
32 logging.basicConfig(level=logging.INFO)
33 logger = logging.getLogger(__name__)


class Caffe2Frontend(object):
    # The ONNX operator set version we target; it controls the semantics of
    # the exported operators and should be bumped as newer opsets are supported.
    target_opset_version = 9

    # Caffe2 operators whose ONNX counterpart only differs by name.
    _renamed_operators = {
        'SpatialBN': 'BatchNormalization',
        'Conv1D': 'Conv',
        'Conv2D': 'Conv',
        'Conv3D': 'Conv',
        'ConvTranspose1D': 'ConvTranspose',
        'ConvTranspose2D': 'ConvTranspose',
        'ConvTranspose3D': 'ConvTranspose',
        'MaxPool1D': 'MaxPool',
        'MaxPool2D': 'MaxPool',
        'MaxPool3D': 'MaxPool',
        'AveragePool1D': 'AveragePool',
        'AveragePool2D': 'AveragePool',
        'AveragePool3D': 'AveragePool',
    }

    # Caffe2 arguments that are dropped when converting to ONNX; the set lists
    # the argument values that are safe to drop silently.
    _blacklist_caffe2_args = {
        'order': {b'NCHW'},
        'cudnn_exhaustive_search': {0, 1},
        'exhaustive_search': {0, 1},
    }

    # Argument renames that apply to every operator.
    _global_renamed_args = {
        'kernels': 'kernel_shape',
    }

    # Argument renames that only apply to a particular operator.
    _per_op_renamed_args = {
        'Squeeze': {'dims': 'axes'},
        'Transpose': {'axes': 'perm'},
    }

    # Operators that need a hand-written translator, keyed by Caffe2 op type.
    _special_operators = {}

    # Dummy name generator for blobs created during export.
    _dummy_name = C.DummyName()

    @classmethod
    def dummy_name(cls):
        return cls._dummy_name.new_dummy_name()

    @classmethod
    def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg):
        # Apply the global rename first; per-operator renames override it.
        name = cls._global_renamed_args.get(arg.name, arg.name)
        if op_def.type in cls._per_op_renamed_args:
            name = cls._per_op_renamed_args[op_def.type].get(arg.name, name)

        # Take the value from whichever protobuf field is actually populated.
        if arg.HasField('f'):
            value = arg.f
        elif arg.HasField('i'):
            value = arg.i
        elif arg.HasField('s'):
            value = arg.s
        elif arg.floats:
            value = arg.floats
        elif arg.ints:
            value = arg.ints
        elif arg.strings:
            value = arg.strings
        else:
            raise ValueError('Could not find data field in arg: {}'.format(arg))

        # Blacklisted arguments carry no information in ONNX and are dropped.
        if name in cls._blacklist_caffe2_args:
            assert value in cls._blacklist_caffe2_args[arg.name]
            return None

        return helper.make_attribute(name, value)

    @classmethod
    def caffe2_arg_to_onnx_attr(cls, op_def, arg):
        return cls._common_caffe2_arg_to_onnx_attr(op_def, arg)

    @classmethod
    def _common_caffe2_op_to_onnx_node(cls, op_def, shapes):
        node_def = NodeProto()
        node_def.name = op_def.name
        node_def.op_type = cls._renamed_operators.get(op_def.type, op_def.type)
        node_def.input.extend(op_def.input)
        node_def.output.extend(op_def.output)

        # Convert every Caffe2 argument, dropping the ones that map to nothing.
        attrs = filter(None, [cls.caffe2_arg_to_onnx_attr(op_def, arg)
                              for arg in op_def.arg])
        node_def.attribute.extend(attrs)
        return node_def

    @classmethod
    def caffe2_op_to_onnx_node(cls, op_def, shapes):
        if C.support_onnx_export(op_def.type):
            # The C++ exporter knows how to convert this op directly.
            node_strs, tensor_strs = C.export_to_onnx(
                cls._dummy_name, op_def.SerializeToString(), shapes)
            nodes = []
            for s in node_strs:
                node = NodeProto()
                node.ParseFromString(s)
                nodes.append(node)
            const_tensors = []
            for s in tensor_strs:
                tensor = TensorProto()
                tensor.ParseFromString(s)
                const_tensors.append(tensor)
            return nodes, const_tensors
        elif op_def.type in cls._special_operators:
            translator = getattr(cls, cls._special_operators[op_def.type])
        else:
            translator = cls._common_caffe2_op_to_onnx_node
        nodes = translator(op_def, shapes)
        const_tensors = []
        if isinstance(nodes, tuple):
            # Special translators may return (nodes, const_tensors).
            nodes, const_tensors = nodes
        if not isinstance(nodes, container_abcs.Iterable):
            nodes = [nodes]
        return nodes, const_tensors

    @staticmethod
    def _all_names_in_net(net):
        if net is None:
            return set()
        names = set()
        names.update(net.external_input)
        names.update(net.external_output)
        for op in net.op:
            names.update(op.input)
            names.update(op.output)
        return names

    @staticmethod
    def _extract_value_info(tensor):
        return make_tensor_value_info(
            name=tensor.name,
            elem_type=tensor.data_type,
            shape=tensor.dims)

    @classmethod
    def caffe2_net_to_onnx_graph(cls,
                                 predict_net,
                                 init_net=None,
                                 value_info=None):
        if value_info is None:
            value_info = {}
        if not isinstance(value_info, dict):
            raise ValueError('Please pass value_info as a '
                             'name -> (type, shape) dictionary')

        cls._filter_fake_init(init_net, value_info)
        cls._ssa_rewrite(predict_net, init_net, value_info)

        if init_net:
            initializer = cls.caffe2_init_net_to_initializer(init_net)
            value_info.update({init.name: (init.data_type, init.dims)
                               for init in initializer})
        else:
            initializer = []

        # If any blob lacks type/shape info, run the net natively once on
        # random inputs to recover the missing shapes.
        run_native_net = False
        for op in predict_net.op:
            for name in itertools.chain(op.input, op.output):
                if name not in value_info:
                    run_native_net = True

        # Check that we have type and shape info for every external input.
        missing = (set(list(predict_net.external_input)) -
                   set(value_info.keys()))
        if missing:
            raise RuntimeError('Could not find value info of inputs: {}'.format(
                ', '.join(missing)))

        ws = None
        outputs = None
        if run_native_net:
            inputs = {}
            for name in predict_net.external_input:
                elem_type, shape = value_info[name]
                inputs[name] = np.random.randn(*shape).astype(
                    mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])

            ws, outputs = c2_native_run_net(
                init_net,
                predict_net,
                inputs)

            for name in predict_net.external_output:
                output = outputs[name]
                elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output.dtype]
                shape = output.shape
                value_info[name] = (elem_type, shape)

        graph_def = GraphProto()
        graph_def.name = predict_net.name
        graph_def.initializer.extend(initializer)

        graph_def.input.extend(
            make_tensor_value_info(
                name=name,
                elem_type=value_info[name][0],
                shape=value_info[name][1])
            for name in predict_net.external_input)

        # Reserve every existing blob name so generated dummy names are fresh.
        cls._dummy_name.reset(cls._all_names_in_net(predict_net) |
                              cls._all_names_in_net(init_net))

        for op in predict_net.op:
            shapes = {}
            for name in itertools.chain(op.input, op.output):
                if ws:
                    blob = ws.FetchBlob(name)
                    if hasattr(blob, 'shape'):
                        shapes[name] = blob.shape
                else:
                    shapes[name] = value_info[name][1]
            nodes, const_tensors = cls.caffe2_op_to_onnx_node(op, shapes=shapes)
            graph_def.node.extend(nodes)
            graph_def.initializer.extend(const_tensors)
            graph_def.input.extend([cls._extract_value_info(tensor)
                                    for tensor in const_tensors])

        all_output = set(sum((list(node.output) for node in graph_def.node),
                             [init.name for init in graph_def.initializer]))
        redundant_output = set(vi.name for vi in graph_def.output) - all_output
        if redundant_output:
            logger.warning(
                'There are graph outputs not produced by any node or initializer: {}'
                '! Will drop them.'.format(', '.join(redundant_output)))
        graph_def.output.extend(
            make_tensor_value_info(
                name=name,
                elem_type=value_info[name][0],
                shape=value_info[name][1])
            for name in predict_net.external_output
            if name in all_output)

        return graph_def

    @classmethod
    def caffe2_init_net_to_initializer(cls, init_net):
        # Run the init net once so every parameter blob materializes in the
        # workspace, then snapshot each output blob as an ONNX initializer.
        ws, _ = c2_native_run_net(init_net=None, predict_net=init_net, inputs=[])
        output_names = []
        for op in init_net.op:
            output_names.extend(op.output)
        initializer = [numpy_helper.from_array(ws.FetchBlob(name), name=name)
                       for name in sorted(set(output_names))]
        return initializer

    @classmethod
    def _filter_fake_init(cls, init_net, value_info):
        # Drop init ops that merely fake-fill blobs whose real type and shape
        # are already described by value_info.
        if init_net:
            fake_inits = [op for op in init_net.op
                          if len(op.output) == 1 and op.output[0] in value_info and
                          re.match('GivenTensor.*Fill|ConstantFill', op.type)]
            for fake_init in fake_inits:
                init_net.op.remove(fake_init)

    @classmethod
    def ssa_rewrite(cls, net, init_net, value_info):
        return cls._ssa_rewrite(net, init_net, value_info)

    @classmethod
    def _ssa_rewrite(cls, net, init_net, value_info):
        def ssa_name(name, version, version_cnt=None):
            if version == 0:
                return name
            if version_cnt and len(version_cnt.get(name, {})) <= 1:
                return name
            return '{}_{}'.format(name, version)

        if init_net:
            for op in init_net.op:
                assert re.match('GivenTensor.*Fill', op.type), \
                    "type is {}, \n{}".format(op.type, op)
                assert len(op.output) == 1

        ssa, blob_versions = caffe2_core.get_ssa(net)

        # Count the SSA versions of every blob; blobs with a single version
        # keep their original name, others get a version suffix.
        version_cnt = {}
        versioned_blobs = []
        for versioned_input, versioned_output in ssa:
            versioned_blobs += versioned_input
            versioned_blobs += versioned_output

        for (name, version) in versioned_blobs:
            if name not in version_cnt:
                version_cnt[name] = {version}
            else:
                version_cnt[name].add(version)

        assert len(net.op) == len(ssa)
        for op, (versioned_inputs, versioned_outputs) in zip(net.op, ssa):
            op.input[:] = [ssa_name(name, version, version_cnt)
                           for name, version in versioned_inputs]
            op.output[:] = [ssa_name(name, version, version_cnt)
                            for name, version in versioned_outputs]
        net.external_output[:] = [ssa_name(name, blob_versions[name], version_cnt)
                                  for name in net.external_output]

    @classmethod
    def caffe2_net_to_onnx_model(cls, *args, **kwargs):
        opset_id = OperatorSetIdProto()
        opset_id.domain = ''  # default ONNX operator domain
        opset_id.version = cls.target_opset_version
        model = make_model(cls.caffe2_net_to_onnx_graph(*args, **kwargs),
                           opset_imports=[opset_id],
                           producer_name='onnx-caffe2')
        checker.check_model(model)
        return model


caffe2_net_to_onnx_graph = Caffe2Frontend.caffe2_net_to_onnx_graph
caffe2_net_to_onnx_model = Caffe2Frontend.caffe2_net_to_onnx_model
caffe2_init_net_to_initializer = Caffe2Frontend.caffe2_init_net_to_initializer
ssa_rewrite = Caffe2Frontend.ssa_rewrite
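

# Minimal usage sketch, not part of the module API: it assumes serialized
# Caffe2 nets at 'init_net.pb' and 'predict_net.pb' and a single external
# input blob named 'data' with shape (1, 3, 224, 224). These names, paths and
# shapes are placeholders; adjust them to your own model.
if __name__ == '__main__':
    from caffe2.proto import caffe2_pb2

    init_net_pb = caffe2_pb2.NetDef()
    with open('init_net.pb', 'rb') as f:        # hypothetical path
        init_net_pb.ParseFromString(f.read())

    predict_net_pb = caffe2_pb2.NetDef()
    with open('predict_net.pb', 'rb') as f:     # hypothetical path
        predict_net_pb.ParseFromString(f.read())

    # value_info maps each external input name to (elem_type, shape).
    value_info = {'data': (TensorProto.FLOAT, (1, 3, 224, 224))}

    onnx_model = caffe2_net_to_onnx_model(predict_net_pb, init_net_pb, value_info)
    with open('model.onnx', 'wb') as f:
        f.write(onnx_model.SerializeToString())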