3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
9 from caffe2.proto
import caffe2_pb2
10 from onnx.backend.base
import BackendRep, namedtupledict
# NOTE(review): excerpt of Caffe2Rep.__init__; the class header is not shown.
# Only the signature and the super() call are visible. The embedded numbering
# jumps from 14 to 25, so the rest of the body is elided. Presumably it stores
# init_net / predict_net / workspace / uninitialized as attributes read by the
# other methods; confirm against the full source.
13 def __init__(self, init_net, predict_net, workspace, uninitialized):
14 super(Caffe2Rep, self).__init__()
# NOTE(review): excerpt of Caffe2Rep._name_scope.
# Visible behavior: when the predict net's device option is CUDA, this returns
# 'gpu_<device_id>'.
# The embedded numbering jumps from 27 to 30, so whatever follows the CUDA branch
# is elided. If there really were no fallback, non-CUDA device types would
# implicitly return None; confirm against the full source.
25 def _name_scope(self):
26 if self.predict_net.device_option.device_type == caffe2_pb2.CUDA:
27 return 'gpu_{}'.format(self.predict_net.device_option.device_id)
# NOTE(review): excerpt of Caffe2Rep.run. The embedded numbering has gaps
# (33->35, 37->39, 40->44, 44->51, 51->55, 55->57), so this is NOT the
# complete method body; consult the full source before relying on it.
# Visible behavior:
#   - feeds `inputs` (a dict, or a list/tuple) into self.workspace;
#   - creates and runs self.init_net;
#   - runs self.predict_net;
#   - returns the predict net's external outputs wrapped via namedtupledict('Outputs', ...).
# NOTE(review): `core` is not imported in the lines shown. Presumably it comes
# from an elided import line; confirm.
30 def run(self, inputs, **kwargs):
31 super(Caffe2Rep, self).run(inputs, **kwargs)
# Feed and execute under the predict net's device option.
32 with core.DeviceScope(self.predict_net.device_option):
33 if isinstance(inputs, dict):
# dict input: each key is used directly as the blob name passed to FeedBlob.
35 for key, value
in inputs.items():
36 self.workspace.FeedBlob(key, value)
# list/tuple input. The condition guarding the raise below is elided; the
# message reports expected vs. received counts of uninitialized graph inputs.
37 elif isinstance(inputs, list)
or isinstance(inputs, tuple):
39 raise RuntimeError(
'Expected {} values for uninitialized ' 40 'graph inputs ({}), but got {}.'.format(
# The format() arguments and this loop's body are elided. Presumably each value
# is fed positionally, likely paired with `uninitialized` from __init__; confirm.
44 for i, value
in enumerate(inputs):
# Only the init_net CreateNet call is visible; the surrounding lines are elided.
51 self.workspace.CreateNet(self.
init_net)
55 self.workspace.RunNet(self.init_net.name)
# Execute the prediction graph, then fetch every declared external output in order.
57 self.workspace.RunNet(self.predict_net.name)
58 output_values = [self.workspace.FetchBlob(name)
59 for name
in self.predict_net.external_output]
# namedtupledict (from onnx.backend.base) is called with the type name 'Outputs'
# and the external output names. The result is then called with the fetched
# values in the same order; presumably this yields a namedtuple-like record.
60 return namedtupledict(
'Outputs',
61 self.predict_net.external_output)(*output_values)