4 """Backend for running ONNX on Caffe2 6 To run this, you will need to have Caffe2 installed as well. 8 from __future__
import absolute_import
9 from __future__
import division
10 from __future__
import print_function
11 from __future__
import unicode_literals
15 from subprocess
import Popen, PIPE
30 from caffe2.proto
import caffe2_pb2
34 from onnx
import checker, GraphProto, TensorProto, AttributeProto, ModelProto
35 import onnx.numpy_helper
38 import onnx.shape_inference
40 from onnx.backend.base
import Backend, Device, DeviceType, namedtupledict
52 return s.decode(
'utf-8')
53 except AttributeError:
def get_device_option(device):
    """Translate an ONNX backend ``Device`` into a Caffe2 ``DeviceOption``.

    Args:
        device: object exposing ``type`` (a ``DeviceType`` member) and
            ``device_id``.

    Returns:
        The result of ``core.DeviceOption`` for the matching Caffe2 device
        type and the same device id.

    Raises:
        KeyError: if ``device.type`` is neither ``DeviceType.CPU`` nor
            ``DeviceType.CUDA``.
    """
    caffe2_device_types = {
        DeviceType.CPU: caffe2_pb2.CPU,
        DeviceType.CUDA: workspace.GpuDeviceType,
    }
    caffe2_device_type = caffe2_device_types[device.type]
    return core.DeviceOption(caffe2_device_type, device.device_id)
64 This is a more convenient way to work with ONNX/Caffe2 attributes 65 that is not the protobuf representation. 71 d[arg.name] = convertAttributeProto(arg)
74 def caffe2(self, kmap=lambda k: k):
75 for k, v
in self.items():
80 def convertAttributeProto(onnx_arg):
82 Convert an ONNX AttributeProto into an appropriate Python object 85 NB: Tensor attribute gets returned as the straight proto. 87 if onnx_arg.HasField(
'f'):
89 elif onnx_arg.HasField(
'i'):
91 elif onnx_arg.HasField(
's'):
93 elif onnx_arg.HasField(
't'):
95 elif onnx_arg.HasField(
'g'):
96 return Caffe2Backend._graph_to_net(onnx_arg.g, Caffe2Backend._known_opset_version)
97 elif len(onnx_arg.floats):
98 return list(onnx_arg.floats)
99 elif len(onnx_arg.ints):
100 return list(onnx_arg.ints)
101 elif len(onnx_arg.strings):
102 return list(onnx_arg.strings)
103 elif len(onnx_arg.graphs):
106 for g
in onnx_arg.graphs:
107 retval.append(Caffe2Backend._graph_to_net(g, Caffe2Backend._known_opset_version))
110 raise ValueError(
"Unsupported ONNX attribute: {}".format(onnx_arg))
116 Reimplementation of NodeProto from ONNX, but in a form 117 more convenient to work with from Python. 119 We may temporarily edit these nodes to get them into Caffe2 form, 120 before actually translating into the Caffe2 protobuf, since this 121 is easier than decomposing everything, and putting it back together 124 def __init__(self, node):
125 self.
name = str(node.name)
126 self.
op_type = str(node.op_type)
127 self.
attrs = OnnxAttributes.from_onnx(node.attribute)
128 self.
inputs = list(node.input)
129 self.
outputs = list(node.output)
# Result record produced by the node translators in this module: the
# translated operators, the operators destined for the init net, and the
# blob names reported as interface blobs.
Caffe2Ops = collections.namedtuple(
    'Caffe2Ops',
    ('ops', 'init_ops', 'interface_blobs'),
)
144 _known_opset_version = 9
149 _broken_operators = {
157 _renamed_operators = {
158 'GlobalMaxPool':
'MaxPool',
159 'GlobalAveragePool':
'AveragePool',
162 'BatchNormalization':
'SpatialBN',
163 'InstanceNormalization':
'InstanceNorm',
164 'MatMul':
'BatchMatMul',
165 'Upsample':
'ResizeNearest',
167 'InstanceNormalization':
'InstanceNorm',
171 'Unsqueeze':
'ExpandDims',
174 'RandomNormal':
'GaussianFill',
177 _global_renamed_attrs = {
'kernel_shape':
'kernels'}
178 _per_op_renamed_attrs = {
179 'Squeeze': {
'axes':
'dims'},
180 'Unsqueeze': {
'axes':
'dims'},
181 'Transpose': {
'perm':
'axes'},
182 'Upsample': {
'mode':
'',
184 'ConvTranspose': {
'output_padding':
'adjs'},
185 'Selu': {
'gamma':
'scale'},
186 'If': {
'then_branch':
'then_net',
187 'else_branch':
'else_net'}
193 _special_operators = {
194 'LSTM':
'_create_rnn_variant',
195 'GRU':
'_create_rnn_variant',
196 'RNN':
'_create_rnn_variant',
197 'Loop':
'_create_loop',
199 'Upsample':
'_create_upsample',
200 'RandomNormal':
'_create_gaussian_fill' 204 _dummy_name = C.DummyName()
208 return cls._dummy_name.new_dummy_name()
214 def run_node(cls, node, inputs, device='CPU', opset_version=_known_opset_version, outputs_info=None):
215 super(Caffe2Backend, cls).run_node(node, inputs, device=device,
216 outputs_info=outputs_info, opset_version=opset_version)
219 device_option = get_device_option(Device(device))
221 with core.DeviceScope(device_option):
222 if isinstance(inputs, dict):
223 for key, value
in inputs.items():
224 ws.FeedBlob(key, value)
225 value_infos.append(onnx.helper.make_tensor_value_info(
227 elem_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype],
228 shape=value.shape).SerializeToString())
230 assert len(node.input) == len(inputs),
"{}: expected {} but got {}".format(
231 node.op_type, len(node.input), len(inputs))
232 for key, value
in zip(node.input, inputs):
233 ws.FeedBlob(key, value)
234 value_infos.append(onnx.helper.make_tensor_value_info(
236 elem_type=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype],
237 shape=value.shape).SerializeToString())
241 ops_str = cbackend.convert_node(node.SerializeToString(), value_infos, opset_version)
242 for s
in ops_str[0] + ops_str[1]:
243 op = caffe2_pb2.OperatorDef()
244 op.ParseFromString(s)
245 op.device_option.CopyFrom(device_option)
247 ws.RunOperatorsOnce(ops)
248 output_values = [ws.FetchBlob(name)
for name
in node.output]
249 return namedtupledict(
'Outputs', node.output)(*output_values)
252 def _create_tensor_filling_op(cls, onnx_tensor, name=None):
254 Given an Onnx TensorProto, translate it into a Caffe2 operator 255 which produces the given tensor filling op. 257 assert name
or onnx_tensor.name
258 name = name
or onnx_tensor.name
260 c2_op = caffe2_pb2.OperatorDef()
262 c2_values = c2_op.arg.add()
263 c2_values.name =
"values" 265 def tensor2list(onnx_tensor):
267 return onnx.numpy_helper.to_array(onnx_tensor).flatten().tolist()
269 if onnx_tensor.data_type
in [TensorProto.FLOAT]:
270 c2_op.type =
'GivenTensorFill' 271 c2_values.floats.extend(tensor2list(onnx_tensor))
272 elif onnx_tensor.data_type
in [TensorProto.DOUBLE]:
273 c2_op.type =
'GivenTensorDoubleFill' 274 c2_values.floats.extend(tensor2list(onnx_tensor))
275 elif onnx_tensor.data_type
in [TensorProto.INT64,
277 c2_op.type =
'GivenTensorInt64Fill' 278 c2_values.ints.extend(tensor2list(onnx_tensor))
279 elif onnx_tensor.data_type
in [TensorProto.UINT8,
284 c2_op.type =
'GivenTensorIntFill' 285 c2_values.ints.extend(tensor2list(onnx_tensor))
286 elif onnx_tensor.data_type == TensorProto.BOOL:
287 c2_op.type =
'GivenTensorBoolFill' 288 c2_values.ints.extend(tensor2list(onnx_tensor))
289 elif onnx_tensor.data_type == TensorProto.STRING:
290 c2_op.type =
'GivenTensorStringFill' 291 c2_values.strings.extend(onnx_tensor.string_data)
294 "unrecognized tensor type {}".format(onnx_tensor.data_type))
296 c2_shape = c2_op.arg.add()
297 c2_shape.name =
"shape" 298 c2_shape.ints.extend(onnx_tensor.dims)
300 c2_op.output.append(name)
305 def _rnn_reform_weights(cls, reforms, name, hidden_size, init_net, gates, reorder_indices):
306 for name_from, name_to, do_concat, extra_dims
in reforms:
307 gate_blobs = [
'%s/%s_%s' % (name, prefix, name_to)
for prefix
in gates]
308 for i, x
in enumerate(gate_blobs):
309 dim0 = i * hidden_size, (i+1) * hidden_size
310 starts, ends = zip(dim0, *extra_dims)
311 init_net.Slice(name_from, x, starts=starts, ends=ends)
313 reordered_gate_blobs = [gate_blobs[i]
for i
in reorder_indices]
314 init_net.Concat(reordered_gate_blobs, [
'%s/%s' % (name, name_to), cls.
dummy_name()], axis=0)
317 def _make_rnn_direction(cls, input_blob, B, W, R, initial_states_and_names, sequence_lens,
319 input_size, hidden_size, num_gates, direction_offset,
321 reform, make_cell, keep_outputs):
326 gates_hidden_size = num_gates * hidden_size
327 bias_offset = 2 * direction_offset * gates_hidden_size
328 weight_offset = direction_offset * gates_hidden_size
329 Bi = init_net.Slice(B, name + Bi,
330 starts=[bias_offset + 0 * gates_hidden_size],
331 ends =[bias_offset + 1 * gates_hidden_size])
332 Br = init_net.Slice(B, name + Br,
333 starts=[bias_offset + 1 * gates_hidden_size],
334 ends =[bias_offset + 2 * gates_hidden_size])
335 W_ = init_net.Slice(W, name + W_,
336 starts=[weight_offset + 0 * gates_hidden_size, 0],
337 ends =[weight_offset + 1 * gates_hidden_size,-1])
338 R_ = init_net.Slice(R, name + R_,
339 starts=[weight_offset + 0 * gates_hidden_size, 0],
340 ends =[weight_offset + 1 * gates_hidden_size,-1])
342 initial_states_sliced = []
343 for initial_state, name_suffix
in initial_states_and_names:
344 initial_states_sliced.append(
345 pred_mh.net.Slice(initial_state, name + name_suffix,
346 starts=[direction_offset + 0, 0, 0],
347 ends =[direction_offset + 1,-1,-1]))
349 if direction_offset == 1:
350 if sequence_lens
is not None:
351 seq_lens_for_reverse = sequence_lens
353 input_shape = pred_mh.net.Shape(input_blob, name +
'/input_shape')
354 batch_size = pred_mh.net.Slice(input_shape, name +
'/batch_size_slice', starts=[1], ends=[2])
355 seq_len = pred_mh.net.Slice(input_shape, name +
'/seq_len_slice', starts=[0], ends=[1])
356 dummy_sequence_lens = pred_mh.net.Tile([seq_len, batch_size], name +
'/dummy_sequence_lens', axis=0)
357 pred_mh.net.Reshape(dummy_sequence_lens, [dummy_sequence_lens, cls.
dummy_name()], shape=[-1])
358 seq_lens_for_reverse = pred_mh.net.Cast(dummy_sequence_lens, name +
'/seq_lens_for_reverse', to=core.DataType.INT32)
359 reform(Bi, Br, W_, R_, name, hidden_size, init_net)
361 if direction_offset == 1:
362 input = pred_mh.net.ReversePackedSegs(
363 [input_blob, seq_lens_for_reverse], name +
"/input-reversed")
367 outputs = keep_outputs(list(make_cell(
371 initial_states_sliced,
379 if direction_offset == 1:
380 outputs[0] = pred_mh.net.ReversePackedSegs(
381 [outputs[0], seq_lens_for_reverse], name +
"/output-reversed")
386 def _create_rnn_variant(cls, init_model, pred_model, n, opset_version):
387 assert init_model
is not None,
"cannot convert RNNs without access to the full model" 388 assert pred_model
is not None,
"cannot convert RNNs without access to the full model" 390 attrs = dict(n.attrs)
391 hidden_size = attrs.pop(
'hidden_size')
392 direction = force_unicode(attrs.pop(
'direction',
'forward'))
394 if n.op_type ==
'RNN':
395 activation = force_unicode(attrs.pop(
'activations', (
'tanh',))[0])
396 elif n.op_type ==
'GRU':
397 linear_before_reset = attrs.pop(
'linear_before_reset', 0)
399 assert not attrs,
"unsupported RNN attributes: " + str(attrs.keys())
400 assert direction
in [
'forward',
'bidirectional'],
"unsupported backwards RNN/GRU/LSTM" 402 if n.op_type
in [
'RNN',
'GRU']:
403 input_blob, W, R, B, sequence_lens, initial_h = n.inputs
404 elif n.op_type ==
'LSTM':
405 input_blob, W, R, B, sequence_lens, initial_h, initial_c = n.inputs
407 if sequence_lens ==
"":
410 for x
in itertools.chain(init_model.graph.input,
411 init_model.graph.value_info,
412 pred_model.graph.input,
413 pred_model.graph.value_info):
415 input_size = x.type.tensor_type.shape.dim[2].dim_value
418 raise RuntimeError(
"best-effort shape inference for RNN/GRU/LSTM failed")
423 init_net.Reshape(W, [W, cls.
dummy_name()], shape=[1,-1,0])
424 init_net.Squeeze(W, W, dims=[0])
425 init_net.Reshape(R, [R, cls.
dummy_name()], shape=[1,-1,0])
426 init_net.Squeeze(R, R, dims=[0])
427 init_net.Reshape(B, [B, cls.
dummy_name()], shape=[1,-1])
428 init_net.Squeeze(B, B, dims=[0])
430 if n.op_type ==
'RNN':
434 def make_cell(*args, **kwargs):
435 return rnn_cell.BasicRNN(*args, activation=activation, **kwargs)
437 def make_rnn(direction_offset):
438 return cls._make_rnn_direction(
439 input_blob, B, W, R, [(initial_h,
'/initial_h')], sequence_lens,
440 pred_mh, init_net, input_size, hidden_size, 1, direction_offset,
441 "/i2h_b",
"/gates_t_b",
"/i2h_w",
"/gates_t_w",
442 reform, make_cell,
lambda x: x)
444 elif n.op_type ==
'GRU':
445 def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
448 reforms = ((W_,
'i2h_w',
True, [(0,-1)]),
449 (R_,
'gate_t_w',
False, [(0,-1)]),
450 (Bi,
'i2h_b',
True, []),
451 (Br,
'gate_t_b',
False, []))
452 cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
453 [
'update',
'reset',
'output'], [1, 0, 2])
455 def make_cell(*args, **kwargs):
456 return gru_cell.GRU(*args, linear_before_reset=linear_before_reset, **kwargs)
458 def make_rnn(direction_offset):
459 return cls._make_rnn_direction(
460 input_blob, B, W, R, [(initial_h,
'/initial_h')], sequence_lens,
461 pred_mh, init_net, input_size, hidden_size, 3, direction_offset,
462 "_bias_i2h",
"_bias_gates",
"/i2h_w_pre",
"/gates_t_w_pre",
463 reform, make_cell,
lambda x: x)
465 elif n.op_type ==
'LSTM':
466 def reform(Bi, Br, W_, R_, name, hidden_size, init_net):
469 reforms = ((W_,
'i2h_w',
True, [(0, -1)]),
470 (R_,
'gates_t_w',
True, [(0, -1)]),
471 (Bi,
'i2h_b' ,
True, []),
472 (Br,
'gates_t_b',
True, []))
473 cls._rnn_reform_weights(reforms, name, hidden_size, init_net,
474 [
'input',
'output',
'forget',
'cell'], [0, 2, 1, 3])
476 def make_cell(*args, **kwargs):
477 return rnn_cell.LSTM(*args, **kwargs)
479 def make_rnn(direction_offset):
480 return cls._make_rnn_direction(
481 input_blob, B, W, R, [(initial_h,
'/initial_h'), (initial_c,
'/initial_c')], sequence_lens,
482 pred_mh, init_net, input_size, hidden_size, 4, direction_offset,
483 "/i2h_b",
"/gates_t_b",
"/i2h_w",
"/gates_t_w",
484 reform, make_cell,
lambda x: [x[0], x[1], x[3]])
486 if direction ==
'forward':
487 outputs = make_rnn(0)
493 for i
in range(1, len(outputs)):
494 pred_mh.net.Copy(outputs[i], n.outputs[i])
496 if sequence_lens
is not None:
497 pred_mh.net.VariableLengthSequencePadding(
498 [outputs[0], sequence_lens], [outputs[0]])
499 pred_mh.net.ExpandDims([outputs[0]], [n.outputs[0]], dims=[1])
500 elif direction ==
'bidirectional':
501 outputs_f = make_rnn(0)
502 outputs_b = make_rnn(1)
504 concatted_output, _ = pred_mh.net.Concat(
505 [outputs_f[0], outputs_b[0]], [cls.dummy_name(), cls.dummy_name()], axis=2)
506 if sequence_lens
is not None:
507 pred_mh.net.VariableLengthSequencePadding(
508 [concatted_output, sequence_lens], [concatted_output])
509 reshaped_output, _ = pred_mh.net.Reshape(concatted_output, [cls.dummy_name(), cls.dummy_name()], shape=[0,0,-1,2])
510 pred_mh.net.Transpose(reshaped_output, n.outputs[0], axes=[0,2,1,3])
511 for i
in range(1, len(n.outputs)):
512 pred_mh.net.Concat([outputs_f[i], outputs_b[i]],
513 [n.outputs[i], cls.dummy_name()], axis=0)
523 initializers = {i.name
for i
in init_model.graph.initializer}
524 outputs = {output
for node
in init_model.graph.node
for output
in node.output}
525 has_initializers = all(x
in initializers
or x
in outputs
for x
in (W, R, B))
529 (init_ops
if has_initializers
else pred_ops).extend(init_net.Proto().op)
530 pred_ops.extend(pred_mh.Proto().op)
532 return Caffe2Ops(pred_ops, init_ops, list(pred_mh.Proto().external_input))
def _create_control_op(cls, init_model, pred_model, n, opset_version):
    """Translate a control-flow node and attach its control inputs.

    The node is converted with ``cls._common_onnx_node_to_caffe2_op``; any
    names stored under the ``'__control_inputs'`` attribute of ``n`` are then
    appended to the resulting operator's ``control_input`` field.

    Returns:
        A ``Caffe2Ops`` holding the single translated operator, with no init
        operators and no interface blobs.
    """
    # Capture the control inputs before translating, matching the order in
    # which the attribute and the translator are consulted.
    if '__control_inputs' in n.attrs:
        control_inputs = list(n.attrs['__control_inputs'])
    else:
        control_inputs = []

    translated = cls._common_onnx_node_to_caffe2_op(
        init_model, pred_model, n, opset_version)
    translated.control_input.extend(control_inputs)
    return Caffe2Ops([translated], [], [])
def _remove_ssa(cls, net, remap_dict):
    """Rename output blobs of ``net`` in place according to ``remap_dict``.

    Every entry of each operator's ``output`` list, and every entry of
    ``net.external_output``, that appears as a key in ``remap_dict`` is
    replaced by the mapped value. Operator inputs are not touched.
    """
    def rename_in_place(blob_names):
        # Overwrite by index so the container owned by ``net`` is mutated.
        for position, blob in enumerate(blob_names):
            if blob in remap_dict:
                blob_names[position] = remap_dict[blob]

    for op in net.op:
        rename_in_place(op.output)
    rename_in_place(net.external_output)
554 def _create_if(cls, init_model, pred_model, n, opset_version):
555 ops = cls._create_control_op(init_model, pred_model, n, opset_version)
556 assert ops[0][0].type ==
'If' 558 then_net = else_net =
None 560 for arg
in if_op.arg:
561 if arg.name ==
'then_net':
563 if arg.name ==
'else_net':
565 if arg.name ==
'__control_inputs':
566 control_inputs = arg.strings
568 assert then_net
and else_net
569 then_net_outs = then_net.external_output
570 else_net_outs = else_net.external_output
571 op_outputs = if_op.output
572 assert len(then_net_outs) == len(else_net_outs)
573 assert len(else_net_outs) == len(op_outputs)
575 for arg
in if_op.arg:
576 if arg.name ==
'then_net':
577 arg.n.external_input.extend(control_inputs)
578 if arg.name ==
'else_net':
579 arg.n.external_input.extend(control_inputs)
584 def _create_loop(cls, init_model, pred_model, n, opset_version):
585 ops = cls._create_control_op(init_model, pred_model, n, opset_version)
586 assert ops[0][0].type ==
'ONNXWhile' 592 for arg
in while_op.arg:
593 if arg.name ==
'__control_inputs':
594 control_inputs = arg.strings
595 num_loop_carried_deps = 0
596 for arg
in while_op.arg:
597 if arg.name ==
'body':
598 num_loop_carried_deps = len(arg.n.external_input) - 2
599 arg.n.external_input.extend(control_inputs)
600 while_op.arg.extend([
602 num_loop_carried_deps)
def _substitute_raw_value(cls, tp, raw_values_dict):
    """Replace an ``__EXTERNAL`` raw-data placeholder with the real bytes.

    Args:
        tp: tensor proto whose ``raw_data`` may hold the placeholder.
        raw_values_dict: mapping from tensor name to the raw bytes to use.

    Raises:
        RuntimeError: if ``tp`` carries the placeholder but its name is not
            a key of ``raw_values_dict``.
    """
    uses_placeholder = (
        tp.HasField('raw_data') and tp.raw_data == bytes(b'__EXTERNAL'))
    if not uses_placeholder:
        # Tensor already carries its own data (or none); leave it alone.
        return

    if tp.name not in raw_values_dict:
        raise RuntimeError(
            'TensorProto for value {} referenced raw data but it was not found!'.format(tp.name))
    tp.raw_data = raw_values_dict[tp.name]
def _visit_and_substitute_raw_values(cls, nodes, raw_values_dict):
    """Resolve raw-data placeholders in every tensor attribute of ``nodes``.

    For each attribute of each node, the single tensor (``t``) and the
    repeated ``tensors`` are passed to ``cls._substitute_raw_value``; the
    single subgraph (``g``) and the repeated ``graphs`` are then visited
    recursively through their ``node`` lists.
    """
    for node in nodes:
        for attribute in node.attribute:
            # Tensors first: the singular field (when set), then the
            # repeated one, in that order.
            tensors = [attribute.t] if attribute.HasField('t') else []
            tensors.extend(attribute.tensors)
            for tensor in tensors:
                cls._substitute_raw_value(tensor, raw_values_dict)

            # Then nested graphs, again singular before repeated.
            subgraphs = [attribute.g] if attribute.HasField('g') else []
            subgraphs.extend(attribute.graphs)
            for subgraph in subgraphs:
                cls._visit_and_substitute_raw_values(
                    subgraph.node, raw_values_dict)
def _external_value_resolution_pass(cls, model, raw_values_dict):
    """Resolve raw-data placeholders throughout ``model``'s graph.

    Applies ``cls._substitute_raw_value`` to every graph initializer, then
    walks the graph's nodes with ``cls._visit_and_substitute_raw_values``.
    The model is modified in place; nothing is returned.
    """
    graph = model.graph
    for initializer in graph.initializer:
        cls._substitute_raw_value(initializer, raw_values_dict)
    cls._visit_and_substitute_raw_values(graph.node, raw_values_dict)
def _direct_initialize_parameters(cls, initializer, ws, device_option):
    """Feed every initializer tensor into the workspace ``ws``.

    Each tensor proto in ``initializer`` is converted with
    ``onnx.numpy_helper.to_array`` and fed under its own name using
    ``device_option``.
    """
    for tensor_proto in initializer:
        blob_value = onnx.numpy_helper.to_array(tensor_proto)
        ws.FeedBlob(tensor_proto.name, blob_value, device_option)
642 def _direct_initialize_inputs(cls, inputs, initialized, ws, device_option):
643 for value_info
in inputs:
644 if value_info.name
in initialized:
646 shape = list(d.dim_value
for d
in value_info.type.tensor_type.shape.dim)
649 np.ones(shape, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[value_info.type.tensor_type.elem_type]),
653 def optimize_onnx(input, init=False, predict=False):
654 passes = [
'fuse_consecutive_transposes',
655 'eliminate_nop_transpose',
656 'fuse_transpose_into_gemm',
657 'lift_lexical_references']
659 passes.append(
'split_init')
661 passes.append(
'split_predict')
662 out = onnx.optimizer.optimize(input, passes)
def prepare_zip_archive(cls, file, device='CPU', **kwargs):
    """Load a model from a zip archive and hand it to ``cls.prepare``.

    The archive member named ``__MODEL_PROTO`` is loaded with ``onnx.load``.
    Every other member is read in full and stored, keyed by member name, in
    the ``raw_values_dict`` passed on to ``cls.prepare``.

    Args:
        file: anything ``zipfile.ZipFile`` accepts for reading.
        device: device string forwarded to ``cls.prepare``.
        **kwargs: forwarded unchanged to ``cls.prepare``.

    Returns:
        Whatever ``cls.prepare`` returns.
    """
    model_entry = '__MODEL_PROTO'
    with zipfile.ZipFile(file, mode='r') as z:
        with z.open(model_entry, 'r') as f:
            model = onnx.load(f)

        # Subtract a one-element set. ``set('__MODEL_PROTO')`` would be a set
        # of single characters, which leaves the model entry in the blob list
        # and drops members that happen to be named with one of those
        # characters.
        blob_names = set(z.namelist()) - {model_entry}

        raw_values_dict = {}
        for name in blob_names:
            with z.open(name, 'r') as blob_file:
                raw_values_dict[name] = blob_file.read()

    return cls.prepare(model, device, raw_values_dict=raw_values_dict, **kwargs)
680 def prepare(cls, model, device='CPU', raw_values_dict=None, **kwargs):
682 For Onnx Caffe2Backend, we require that init_graph doesn't initialize the actual input of the predict_graph, 684 for example, if "img" is the input blob for the predict_net, we require that in init_graph and in 685 initializer of the predict_graph, "img" is not initialized. We don't have a check for this, since 686 there is no way we can know which blob is the input of the predict_graph. 688 if not kwargs.pop(
'no_check_UNSAFE',
False):
689 super(Caffe2Backend, cls).
prepare(model, device, **kwargs)
691 for imp
in model.opset_import:
692 if not imp.HasField(
"domain")
or imp.domain ==
"":
693 opset_version = imp.version
695 warnings.warn(
"This version of onnx-caffe2 targets ONNX operator set version {}, but the model we are trying to import uses version {}. We will try to import it anyway, but if the model uses operators which had BC-breaking changes in the intervening versions, import will fail.".format(cls.
_known_opset_version, imp.version))
697 warnings.warn(
"Unrecognized operator set {}".format(imp.domain))
698 if opset_version
is None:
699 if model.ir_version >= 0x00000003:
700 raise RuntimeError(
"Model with IR version >= 3 did not specify ONNX operator set version (onnx-caffe2 requires it)")
704 model = onnx.shape_inference.infer_shapes(model)
707 device_option = get_device_option(Device(device))
716 model.graph.initializer,
721 initialized = {init.name
for init
in model.graph.initializer}
730 uninitialized = [value_info.name
for value_info
in model.graph.input
if value_info.name
not in initialized]
732 retval =
Caffe2Rep(init_net, predict_net, ws, uninitialized)
738 def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version):
740 if cbackend.support_onnx_import(node_def.op_type):
746 for name
in node_def.input:
747 if pred_model
is not None:
748 for vi
in itertools.chain(pred_model.graph.input,
749 pred_model.graph.output,
750 pred_model.graph.value_info):
752 value_infos.append(vi.SerializeToString())
754 op_strs = cbackend.convert_node(node_def.SerializeToString(), value_infos, opset_version)
757 op = caffe2_pb2.OperatorDef()
758 op.ParseFromString(s)
762 op = caffe2_pb2.OperatorDef()
763 op.ParseFromString(s)
765 return Caffe2Ops(ops, init_ops, [])
771 ops = translator(init_model, pred_model,
OnnxNode(node_def), opset_version)
772 if isinstance(ops, Caffe2Ops):
774 if not isinstance(ops, container_abcs.Iterable):
776 return Caffe2Ops(ops, [], [])
778 _broadcast_operators = {
784 def _common_onnx_node_to_caffe2_op(cls, init_model, pred_model, onnx_node, opset_version):
786 This translator performs the basic translation of ONNX nodes into 787 Caffe2 operators. Besides doing a straightforward marshalling from 788 one format to another, it also does these extra things: 790 - Renames operators based on '_renamed_operators' 791 - Renames attributes based on '_global_renamed_attrs' and 792 '_per_op_renamed_attrs' 794 If you're writing a custom translator, consider calling this first, 795 and then fixing things up further. 797 c2_op = caffe2_pb2.OperatorDef()
799 c2_op.input.extend(onnx_node.inputs)
800 c2_op.output.extend(onnx_node.outputs)
801 c2_op.name = onnx_node.name
804 onnx_op_type = onnx_node.op_type
805 broken_version = cls._broken_operators.get(onnx_op_type, float(
'Inf'))
806 if broken_version <= opset_version:
808 "Don't know how to translate op {} in ONNX operator set v{} (I only support prior to v{})".format(onnx_op_type, opset_version, broken_version))
809 c2_op.type = cls._renamed_operators.get(onnx_op_type, onnx_op_type)
810 if not core.IsOperator(c2_op.type):
812 "Don't know how to translate op {}".format(onnx_op_type))
821 c2_op.arg.extend(onnx_node.attrs.caffe2(kmap=kmap))
823 if opset_version < 7:
827 already_broadcast =
False 828 for arg
in c2_op.arg:
829 if arg.name ==
'broadcast':
830 already_broadcast =
True 831 if not already_broadcast:
837 def _all_names_in_graph(graph):
842 names.update(value_info.name
for value_info
in graph.input)
843 names.update(value_info.name
for value_info
in graph.output)
844 for node
in graph.node:
845 names.update(node.input)
846 names.update(node.output)
850 def _graph_to_net(cls, onnx_graph, opset_version):
851 net = caffe2_pb2.NetDef()
852 for node
in onnx_graph.node:
855 None,
None, node, opset_version)
856 except Exception
as e:
857 print(
'ONNX FATAL:', e)
859 net.op.extend(c2ops.init_ops)
860 net.op.extend(c2ops.ops)
861 net.external_input.extend(c2ops.interface_blobs)
862 net.external_output.extend(
863 value_info.name
for value_info
in onnx_graph.output)
864 net.external_input.extend(
865 value_info.name
for value_info
in onnx_graph.input)
869 def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_initializers):
870 device_option = get_device_option(Device(device))
872 onnx_model = onnx.utils.polish_model(onnx_model)
876 init_net = caffe2_pb2.NetDef()
877 pred_net = caffe2_pb2.NetDef()
879 init_net.name = onnx_model.graph.name +
'_init' 880 pred_net.name = onnx_model.graph.name +
'_predict' 882 if include_initializers:
888 for net, model
in ( (init_net, init_model), (pred_net, pred_model) ):
889 net.device_option.CopyFrom(device_option)
890 for node
in model.graph.node:
893 init_model, pred_model, node, opset_version)
894 except Exception
as e:
896 print(
'ONNX FATAL:', e)
898 init_net.op.extend(c2ops.init_ops)
899 net.op.extend(c2ops.ops)
900 net.external_input.extend(c2ops.interface_blobs)
901 net.external_output.extend(
902 value_info.name
for value_info
in model.graph.output)
903 net.external_input.extend(
904 value_info.name
for value_info
in model.graph.input)
907 raise RuntimeError(
'ONNX conversion failed')
909 return init_net, pred_net
913 def onnx_graph_to_caffe2_net(cls, model, device="CPU", opset_version=_known_opset_version):
917 def supports_device(cls, device_str):
918 device = Device(device_str)
919 if device.type == DeviceType.CPU:
921 elif core.IsGPUDeviceType(device.type):
922 return workspace.has_gpu_support
926 def is_compatible(cls, model, device='CPU', **kwargs):
927 if hasattr(super(Caffe2Backend, cls),
'is_compatible') \
928 and callable(super(Caffe2Backend, cls).is_compatible):
929 if not super(Caffe2Backend, cls).is_compatible(model, device, **kwargs):
# Module-level aliases for the Caffe2Backend entry points, so they can be
# reached directly as attributes of this module.
prepare = Caffe2Backend.prepare

prepare_zip_archive = Caffe2Backend.prepare_zip_archive

run_node = Caffe2Backend.run_node

run_model = Caffe2Backend.run_model

supports_device = Caffe2Backend.supports_device

is_compatible = Caffe2Backend.is_compatible
def MakeArgument(key, value)
def _external_value_resolution_pass(cls, model, raw_values_dict)
def _all_names_in_graph(graph)
dictionary _per_op_renamed_attrs
dictionary _broadcast_operators
def _direct_initialize_inputs(cls, inputs, initialized, ws, device_option)
def optimize_onnx(input, init=False, predict=False)
def _create_tensor_filling_op(cls, onnx_tensor, name=None)
def _common_onnx_node_to_caffe2_op(cls, init_model, pred_model, onnx_node, opset_version)
def _direct_initialize_parameters(cls, initializer, ws, device_option)
def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_initializers)
dictionary _global_renamed_attrs
def prepare(cls, model, device='CPU', raw_values_dict=None, kwargs)
dictionary _special_operators
def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version)