Caffe2 - Python API
A deep learning, cross platform ML framework
frontend.py
1 ## @package onnx
2 # Module caffe2.python.onnx.frontend
3 
4 """Caffe2 Protobuf to ONNX converter
5 
6 To run this, you will need to have Caffe2 installed as well.
7 """
8 
9 from __future__ import absolute_import
10 from __future__ import division
11 from __future__ import print_function
12 from __future__ import unicode_literals
13 
14 import itertools
15 import logging
16 import re
17 
18 from caffe2.python import core as caffe2_core
19 from caffe2.python.compatibility import container_abcs
20 from caffe2.proto import caffe2_legacy_pb2
21 from enum import Enum
22 from onnx import (defs, checker, helper, numpy_helper, mapping,
23  ModelProto, GraphProto, NodeProto, AttributeProto, TensorProto, OperatorSetIdProto)
24 from onnx.helper import make_tensor, make_tensor_value_info, make_attribute, make_model
25 import numpy as np
26 
27 from caffe2.python.onnx.helper import c2_native_run_net
28 from caffe2.python.onnx.error import Unsupported
29 
31 
# NOTE(review): basicConfig() runs as an import-time side effect of this
# module, so importing it may configure the root logger for the whole
# process. Kept as-is because callers may rely on it.
logging.basicConfig(level=logging.INFO)
# Module-level logger used for the warnings emitted during conversion.
logger = logging.getLogger(__name__)
34 
35 
class Caffe2Frontend(object):
    """Translate Caffe2 nets, operators and arguments into ONNX protobufs.

    All functionality is exposed through class/static methods; the
    class-level tables below drive the generic operator and argument
    translation.
    """

    # This number controls the semantics of the operators we target. Whenever
    # ONNX makes a BC breaking change to semantics of operators, having this set
    # to an accurate number will prevent our models from exporting. However,
    # we should strive to keep this up-to-date as much as possible.
    target_opset_version = 9

    # Caffe2 operator type -> ONNX op_type, consulted by the generic
    # operator translator; types not listed here keep their Caffe2 name.
    _renamed_operators = {
        'SpatialBN': 'BatchNormalization',
        'Conv1D': 'Conv',
        'Conv2D': 'Conv',
        'Conv3D': 'Conv',
        'ConvTranspose1D': 'ConvTranspose',
        'ConvTranspose2D': 'ConvTranspose',
        'ConvTranspose3D': 'ConvTranspose',
        'MaxPool1D': 'MaxPool',
        'MaxPool2D': 'MaxPool',
        'MaxPool3D': 'MaxPool',
        'AveragePool1D': 'AveragePool',
        'AveragePool2D': 'AveragePool',
        'AveragePool3D': 'AveragePool',
    }

    # caffe2 arguments that are completely removed in onnx
    # Maps argument name -> set of values that may be dropped silently; the
    # generic argument translator asserts that the actual value is in the set.
    _blacklist_caffe2_args = {
        'order': {b'NCHW'},
        'cudnn_exhaustive_search': {0, 1},
        'exhaustive_search': {0, 1},
        'use_cudnn': {0, 1},
    }

    # Argument renames applied to every operator type.
    _global_renamed_args = {
        'kernels': 'kernel_shape',
    }

    # Argument renames applied only to the given operator type; these take
    # precedence over the global renames.
    _per_op_renamed_args = {
        'Squeeze': {'dims': 'axes'},
        'Transpose': {'axes': 'perm'},
    }

    # Caffe2 operator type -> name of the class attribute (resolved with
    # getattr) that translates it. Empty in this class.
    _special_operators = {}

    # Dummy name generator
    # NOTE(review): `C` is not imported in the visible part of this file --
    # confirm the C-extension import exists in the real module.
    _dummy_name = C.DummyName()
80 
81  @classmethod
82  def dummy_name(cls):
83  return cls._dummy_name.new_dummy_name()
84 
85  @classmethod
86  def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg):
87  # name
88  op_type = op_def.type
89  name = cls._global_renamed_args.get(arg.name, arg.name)
90  if op_type in cls._per_op_renamed_args:
91  # Per-op attribute renames override the global attribute renames
92  name = cls._per_op_renamed_args[op_type].get(arg.name, name)
93 
94  # value
95  if arg.HasField('f'):
96  value = arg.f
97  elif arg.HasField('i'):
98  value = arg.i
99  elif arg.HasField('s'):
100  value = arg.s
101  elif arg.floats:
102  value = arg.floats
103  elif arg.ints:
104  value = arg.ints
105  elif arg.strings:
106  value = arg.strings
107  else:
108  raise ValueError('Could not find data field in arg: {}'.format(arg))
109 
110  if name in cls._blacklist_caffe2_args:
111  assert value in cls._blacklist_caffe2_args[arg.name]
112  return None
113 
114  return helper.make_attribute(name, value)
115 
116  @classmethod
117  def caffe2_arg_to_onnx_attr(cls, op_def, arg):
118  return cls._common_caffe2_arg_to_onnx_attr(op_def, arg)
119 
120  @classmethod
121  def _common_caffe2_op_to_onnx_node(cls, op_def, shapes):
122  node_def = NodeProto()
123  node_def.name = op_def.name
124 
125  node_def.op_type = cls._renamed_operators.get(op_def.type, op_def.type)
126 
127  node_def.input.extend(op_def.input)
128  node_def.output.extend(op_def.output)
129 
130  attrs = filter(None, [cls.caffe2_arg_to_onnx_attr(op_def, arg)
131  for arg in op_def.arg])
132  node_def.attribute.extend(attrs)
133 
134  return node_def
135 
136  @classmethod
137  def caffe2_op_to_onnx_node(cls, op_def, shapes):
138  if C.support_onnx_export(op_def.type):
139  node_strs, tensor_strs = C.export_to_onnx(cls._dummy_name, op_def.SerializeToString(), shapes)
140  nodes = []
141  for s in node_strs:
142  node = NodeProto()
143  node.ParseFromString(s)
144  nodes.append(node)
145  const_tensors = []
146  for s in tensor_strs:
147  tensor = TensorProto()
148  tensor.ParseFromString(s)
149  const_tensors.append(tensor)
150  return nodes, const_tensors
151  elif op_def.type in cls._special_operators:
152  translator = getattr(cls, cls._special_operators[op_def.type])
153  else:
154  translator = cls._common_caffe2_op_to_onnx_node
155  nodes = translator(op_def, shapes)
156  const_tensors = []
157  if isinstance(nodes, tuple):
158  nodes, const_tensors = nodes
159  if not isinstance(nodes, container_abcs.Iterable):
160  nodes = [nodes]
161  return nodes, const_tensors
162 
163  @staticmethod
164  def _all_names_in_net(net):
165  if net is None:
166  return set()
167 
168  names = set()
169  names.update(net.external_input)
170  names.update(net.external_output)
171  for op in net.op:
172  names.update(op.input)
173  names.update(op.output)
174  return names
175 
176  @staticmethod
177  def _extract_value_info(tensor):
178  return make_tensor_value_info(
179  name=tensor.name,
180  elem_type=tensor.data_type,
181  shape=tensor.dims)
182 
183  @classmethod
184  def caffe2_net_to_onnx_graph(cls,
185  predict_net,
186  init_net=None,
187  value_info=None):
188  if value_info is None:
189  value_info = {}
190  if not isinstance(value_info, dict):
191  raise ValueError('Please pass value_info as a '
192  'name -> (type, shape) dictionary')
193 
194  cls._filter_fake_init(init_net, value_info)
195  cls._ssa_rewrite(predict_net, init_net, value_info)
196 
197  if init_net:
198  initializer = cls.caffe2_init_net_to_initializer(init_net)
199  value_info.update({init.name: (init.data_type, init.dims)
200  for init in initializer})
201  else:
202  initializer = []
203 
204  # Check if value_info contains the types/shapes of all the blobs, in
205  # which case we don't need to infer them by running the net.
206  run_native_net = False
207  for op in predict_net.op:
208  for name in itertools.chain(op.input, op.output):
209  if name not in value_info:
210  run_native_net = True
211  break
212 
213  # Check whether we have got type shape info of all input
214  missing = (set(list(predict_net.external_input)) -
215  set(value_info.keys()))
216  if missing:
217  raise RuntimeError('Could not find value info of inputs: {}'.format(
218  ', '.join(missing)))
219 
220  ws = None
221  outputs = None
222  if run_native_net:
223  inputs = {}
224  for name in predict_net.external_input:
225  elem_type, shape = value_info[name]
226  inputs[name] = np.random.randn(*shape).astype(
227  mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
228 
229  ws, outputs = c2_native_run_net(
230  init_net,
231  predict_net,
232  inputs)
233 
234  for name in predict_net.external_output:
235  output = outputs[name]
236  elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output.dtype]
237  shape = output.shape
238  value_info[name] = (elem_type, shape)
239 
240  graph_def = GraphProto()
241  graph_def.name = predict_net.name
242  graph_def.initializer.extend(initializer)
243  # This is a mapping from Caffe2 names to ONNX names
244  graph_def.input.extend(
245  make_tensor_value_info(
246  name=name,
247  elem_type=value_info[name][0],
248  shape=value_info[name][1])
249  for name in predict_net.external_input)
250 
251  cls._dummy_name.reset(cls._all_names_in_net(predict_net) | cls._all_names_in_net(init_net))
252 
253  for op in predict_net.op:
254  shapes = {}
255  for name in itertools.chain(op.input, op.output):
256  if ws:
257  blob = ws.FetchBlob(name)
258  if hasattr(blob, 'shape'):
259  shapes[name] = blob.shape
260  else:
261  shapes[name] = value_info[name][1]
262  nodes, const_tensors = cls.caffe2_op_to_onnx_node(op, shapes=shapes)
263  graph_def.node.extend(nodes)
264  graph_def.initializer.extend(const_tensors)
265  graph_def.input.extend([cls._extract_value_info(tensor) for tensor in const_tensors])
266 
267  all_output = set(sum((list(node.output) for node in graph_def.node),
268  [init.name for init in graph_def.initializer]))
269  redundant_output = set(vi.name for vi in graph_def.output) - all_output
270  if redundant_output:
271  logger.warning(
272  'There are graph output not produced by any node or initializer: {}'
273  '! Will drop them.'.format(', '.join(redundant_output)))
274  graph_def.output.extend(
275  make_tensor_value_info(
276  name=name,
277  elem_type=value_info[name][0],
278  shape=value_info[name][1])
279  for name in predict_net.external_output
280  if name in all_output)
281 
282  return graph_def
283 
284  @classmethod
285  def caffe2_init_net_to_initializer(cls, init_net):
286  ws, _ = c2_native_run_net(init_net=None, predict_net=init_net, inputs=[])
287  output_names = []
288  for op in init_net.op:
289  output_names.extend(op.output)
290  initializer = [numpy_helper.from_array(ws.FetchBlob(name), name=name)
291  for name in sorted(set(output_names))]
292  return initializer
293 
294  @classmethod
295  def _filter_fake_init(cls, init_net, value_info):
296  if init_net:
297  fake_inits = [op for op in init_net.op
298  if len(op.output) == 1 and op.output[0] in value_info and
299  re.match('GivenTensor.*Fill|ConstantFill', op.type)]
300  for fake_init in fake_inits:
301  init_net.op.remove(fake_init)
302  del fake_inits[:]
303  del fake_inits
304 
305  @classmethod
306  def ssa_rewrite(cls, net, init_net, value_info):
307  return cls._ssa_rewrite(net, init_net, value_info)
308 
309  @classmethod
310  def _ssa_rewrite(cls, net, init_net, value_info):
311  def ssa_name(name, version, version_cnt=None):
312  if version == 0:
313  return name
314  if version_cnt and len(version_cnt.get(name, {})) <= 1:
315  return name
316  return '{}_{}'.format(name, version)
317 
318  if init_net:
319  for op in init_net.op:
320  assert re.match('GivenTensor.*Fill', op.type), "type is {}, \n{}".format(op.type, op)
321  assert len(op.output) == 1
322 
323  ssa, blob_versions = caffe2_core.get_ssa(net)
324  version_cnt = {}
325  versioned_blobs = []
326  for versioned_input, versioned_output in ssa:
327  versioned_blobs += versioned_input
328  versioned_blobs += versioned_output
329 
330  for (name, version) in versioned_blobs:
331  if name not in version_cnt:
332  version_cnt[name] = {version}
333  else:
334  version_cnt[name].add(version)
335 
336  assert len(net.op) == len(ssa)
337  for op, (versioned_inputs, versioned_outputs) in zip(net.op, ssa):
338  op.input[:] = [ssa_name(name, version, version_cnt)
339  for name, version in versioned_inputs]
340  op.output[:] = [ssa_name(name, version, version_cnt)
341  for name, version in versioned_outputs]
342  net.external_output[:] = [ssa_name(name, blob_versions[name], version_cnt)
343  for name in net.external_output]
344 
345  @classmethod
346  def caffe2_net_to_onnx_model(cls, *args, **kwargs):
347  opset_id = OperatorSetIdProto()
348  opset_id.domain = '' # ONNX default domain
349  opset_id.version = cls.target_opset_version
350  model = make_model(cls.caffe2_net_to_onnx_graph(*args, **kwargs),
351  opset_imports=[opset_id], # current supported opset version
352  producer_name='onnx-caffe2', # producer name
353  )
354  checker.check_model(model)
355  return model
356 
357 
# Module-level aliases so callers can use the converter functions without
# referencing the Caffe2Frontend class directly.
caffe2_net_to_onnx_graph = Caffe2Frontend.caffe2_net_to_onnx_graph
caffe2_net_to_onnx_model = Caffe2Frontend.caffe2_net_to_onnx_model
caffe2_init_net_to_initializer = Caffe2Frontend.caffe2_init_net_to_initializer
ssa_rewrite = Caffe2Frontend.ssa_rewrite
def _common_caffe2_op_to_onnx_node(cls, op_def, shapes)
Definition: frontend.py:121
def caffe2_op_to_onnx_node(cls, op_def, shapes)
Definition: frontend.py:137
def _filter_fake_init(cls, init_net, value_info)
Definition: frontend.py:295
def caffe2_net_to_onnx_graph(cls, predict_net, init_net=None, value_info=None)
Definition: frontend.py:187
def caffe2_init_net_to_initializer(cls, init_net)
Definition: frontend.py:285
def caffe2_arg_to_onnx_attr(cls, op_def, arg)
Definition: frontend.py:117
def _ssa_rewrite(cls, net, init_net, value_info)
Definition: frontend.py:310
def _common_caffe2_arg_to_onnx_attr(cls, op_def, arg)
Definition: frontend.py:86