# Caffe2 - Python API
# A deep learning, cross platform ML framework
# pytorch_helper.py
1 import io
2 import torch.onnx
3 import onnx
4 from caffe2.python.onnx.backend import Caffe2Backend
5 from caffe2.python.core import BlobReference, Net
6 
7 
8 _next_idx = 0
9 # Clone net takes a dict instead of a lambda
10 # It should probably take a lambda, it is more flexible
11 # We fake dict here
12 
13 
14 class _FakeDict(object):
15  def __init__(self, fn):
16  self.fn = fn
17 
18  def get(self, name, _):
19  return self.fn(name)
20 
21 
def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
    """
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.

    Arguments:
        helper (caffe2.python.core.ModelHelper): the model helper where
            this imported network should be inserted
        model (torch.nn.Module): the model to be exported
        sample_arguments (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model. Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args. If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported. Give us a shout if you need it.)
        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
            caffe2 Blobs that should be inputs to this network. Must be
            the same length as sample_arguments
        prefix_name: prefix name to add to each member of the blob, if None then
            a fresh prefix pytorch_import_N/ is used
    Returns:
        A tuple of caffe2.python.core.BlobReference objects referring to the
        models outputs, or a single BlobReference when the model returns a single
        value.
    """
    if prefix_name is None:
        # Use a module-level counter so repeated imports get distinct,
        # non-colliding blob-name prefixes.
        global _next_idx
        prefix_name = 'pytorch_import_' + str(_next_idx) + '/'
        _next_idx += 1

    # TODO: handle the case where model cannot be exported
    # and embed as a Python op in Caffe2
    # Export the PyTorch model to an in-memory ONNX protobuf, then parse it
    # back so the graph can be translated into Caffe2 nets.
    f = io.BytesIO()
    torch.onnx.export(
        model, sample_arguments, f, export_params=True)
    onnx_model = onnx.load(io.BytesIO(f.getvalue()))
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(
        onnx_model)

    # Graph inputs that have no initializer are the true runtime inputs;
    # record their position so they can be wired to the caller's blobs.
    initialized = {x.name for x in onnx_model.graph.initializer}
    uninitialized_inputs = {x.name: i for i, x in enumerate(
        onnx_model.graph.input) if x.name not in initialized}

    if len(uninitialized_inputs) != len(caffe2_inputs):
        raise ValueError('Expected {} inputs but found {}'.format(
            len(uninitialized_inputs), len(caffe2_inputs)))

    def remap_blob_name(name):
        # Runtime inputs map onto the supplied caffe2 blobs; everything
        # else (weights, intermediates, outputs) gets the fresh prefix.
        if name in uninitialized_inputs:
            idx = uninitialized_inputs[name]
            return str(caffe2_inputs[idx])
        return prefix_name + name

    predict_net = Net(predict_net).Clone('anon', _FakeDict(remap_blob_name))
    helper.net.AppendNet(predict_net)

    init_net = Net(init_net).Clone('anon', _FakeDict(remap_blob_name))
    helper.param_init_net.AppendNet(init_net)

    results = tuple(BlobReference(remap_blob_name(x.name), helper.net)
                    for x in onnx_model.graph.output)
    return results
# (removed stray documentation cross-reference text: "def export(args, kwargs)
#  Definition: __init__.py:25" — extraction artifact, not part of this module)