Caffe2 - Python API
A deep-learning, cross-platform ML framework
backend_rep.py
1 ## @package onnx
2 # Module caffe2.python.onnx.backend_rep
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core
9 from caffe2.proto import caffe2_pb2
10 from onnx.backend.base import BackendRep, namedtupledict
11 
class Caffe2Rep(BackendRep):
    """Prepared model representation that executes on a Caffe2 workspace.

    Holds an init net, a predict net and the workspace they run in. Each call
    to ``run`` feeds the given inputs into the workspace, creates both nets on
    first use, runs the init net once (on the first call only), then runs the
    predict net and returns the values of its external outputs.
    """

    def __init__(self, init_net, predict_net, workspace, uninitialized):
        """Store the nets and workspace used by ``run``.

        Args:
            init_net: net executed once, on the first ``run`` call, before
                the predict net.
            predict_net: net executed on every ``run`` call. Its
                ``external_output`` names the returned outputs, and its
                ``device_option`` selects the device scope used by ``run``.
            workspace: object exposing ``FeedBlob``, ``CreateNet``,
                ``RunNet`` and ``FetchBlob``; inputs are fed into it and the
                nets are created and run in it.
            uninitialized: sequence of external-input names that are not
                initialized in ``workspace``. Positional inputs passed to
                ``run`` (a list/tuple, or a single value) are paired with
                these names in order.
        """
        super(Caffe2Rep, self).__init__()
        self.init_net = init_net
        self.predict_net = predict_net
        self.workspace = workspace
        # The list of uninitialized external_inputs in workspace; needed to
        # pair the names with inputs given as a sequence.
        self.uninitialized = uninitialized
        # Lazy-initialization flags flipped by run(): the nets are created,
        # and the init net is executed, only on the first call.
        self.nets_created = False
        self.ran_init_net = False

    @property
    def _name_scope(self):
        """Name scope used when feeding inputs given as a dict.

        Returns ``'gpu_<device_id>'`` when ``predict_net`` is placed on a
        CUDA device, and ``''`` otherwise.
        """
        if self.predict_net.device_option.device_type == caffe2_pb2.CUDA:
            return 'gpu_{}'.format(self.predict_net.device_option.device_id)
        return ''

    def _input_count_error(self, got):
        """Build the message for a mismatched number of positional inputs."""
        return ('Expected {} values for uninitialized '
                'graph inputs ({}), but got {}.'.format(
                    len(self.uninitialized),
                    ', '.join(self.uninitialized),
                    got))

    def run(self, inputs, **kwargs):
        """Feed ``inputs``, execute the nets and return the outputs.

        Args:
            inputs: one of

                * a dict mapping blob name to value; every entry is fed
                  inside ``core.NameScope(self._name_scope)``;
                * a list or tuple holding exactly one value per name in
                  ``self.uninitialized``, fed positionally;
                * any other object, fed as the value of the first name in
                  ``self.uninitialized``.
            **kwargs: forwarded to ``BackendRep.run``.

        Returns:
            The result of ``namedtupledict('Outputs', external_output)``
            applied to the blobs fetched for each name in
            ``predict_net.external_output``, in that order.

        Raises:
            RuntimeError: if a list/tuple does not contain one value per
                uninitialized input, or if a single value is given while
                there are no uninitialized inputs to bind it to.
        """
        super(Caffe2Rep, self).run(inputs, **kwargs)
        with core.DeviceScope(self.predict_net.device_option):
            if isinstance(inputs, dict):
                with core.NameScope(self._name_scope):
                    for key, value in inputs.items():
                        self.workspace.FeedBlob(key, value)
            elif isinstance(inputs, (list, tuple)):
                if len(self.uninitialized) != len(inputs):
                    raise RuntimeError(self._input_count_error(len(inputs)))
                # namescope already baked into protobuf
                for name, value in zip(self.uninitialized, inputs):
                    self.workspace.FeedBlob(name, value)
            else:
                # Single input: bound to the first uninitialized graph input.
                # Guard explicitly so an empty list yields a clear error
                # instead of a bare IndexError.
                if not self.uninitialized:
                    raise RuntimeError(self._input_count_error(1))
                self.workspace.FeedBlob(self.uninitialized[0], inputs)
            if not self.nets_created:
                self.workspace.CreateNet(self.init_net)
                self.workspace.CreateNet(self.predict_net)
                self.nets_created = True
            if not self.ran_init_net:
                self.workspace.RunNet(self.init_net.name)
                self.ran_init_net = True
            self.workspace.RunNet(self.predict_net.name)
        output_values = [self.workspace.FetchBlob(name)
                         for name in self.predict_net.external_output]
        return namedtupledict('Outputs',
                              self.predict_net.external_output)(*output_values)
59  for name in self.predict_net.external_output]
60  return namedtupledict('Outputs',
61  self.predict_net.external_output)(*output_values)