15 def __init__(self, fn):
18 def get(self, name, _):
22 def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
24 Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built. 27 helper (caffe2.python.model_helper.ModelHelper): the model helper where 28 this imported network should be inserted 29 model (torch.nn.Module): the model to be exported 30 sample_arguments (tuple of arguments): the inputs to 31 the model, e.g., such that ``model(*args)`` is a valid 32 invocation of the model. Any non-Variable arguments will 33 be hard-coded into the exported model; any Variable arguments 34 will become inputs of the exported model, in the order they 35 occur in sample_arguments. If sample_arguments is a Variable, this is equivalent 36 to having called it with a 1-ary tuple of that Variable. 37 (Note: passing keyword arguments to the model is not currently 38 supported. Give us a shout if you need it.) 39 caffe2_inputs (list of str or caffe2.python.core.BlobReference): the 40 caffe2 Blobs that should be inputs to this network. Must be 41 the same length as sample_arguments (more precisely, as the number of non-initializer inputs of the exported graph; a ValueError is raised otherwise) 42 prefix_name: prefix prepended to the name of every imported blob that is not one of caffe2_inputs; if None then 43 a fresh prefix pytorch_import_N/ is used 45 A tuple of caffe2.python.core.BlobReference objects referring to the 46 model's outputs, or a single BlobReference when the model returns a single 49 if prefix_name
is None:
51 prefix_name =
'pytorch_import_' + str(_next_idx) +
'/' 58 model, sample_arguments, f, export_params=
True)
59 onnx_model = onnx.load(io.BytesIO(f.getvalue()))
60 init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(
63 initialized = set([x.name
for x
in onnx_model.graph.initializer])
64 uninitialized_inputs = {x.name: i
for i, x
in enumerate(
65 onnx_model.graph.input)
if x.name
not in initialized}
67 if(len(uninitialized_inputs) != len(caffe2_inputs)):
68 raise ValueError(
'Expected {} inputs but found {}'.format(
69 len(uninitialized_inputs), len(caffe2_inputs)))
71 def remap_blob_name(name):
72 if name
in uninitialized_inputs:
73 idx = uninitialized_inputs[name]
74 return str(caffe2_inputs[idx])
75 return prefix_name + name
77 predict_net =
Net(predict_net).Clone(
'anon',
_FakeDict(remap_blob_name))
78 helper.net.AppendNet(predict_net)
80 init_net =
Net(init_net).Clone(
'anon',
_FakeDict(remap_blob_name))
81 helper.param_init_net.AppendNet(init_net)
83 results = tuple([
BlobReference(remap_blob_name(x.name), helper.net)
84 for x
in onnx_model.graph.output])