Caffe2 - Python API
A deep-learning, cross-platform ML framework
mobile_exporter.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package mobile_exporter
17 # Module caffe2.python.mobile_exporter
18 
19 from __future__ import absolute_import
20 from __future__ import division
21 from __future__ import print_function
22 from __future__ import unicode_literals
23 from caffe2.python import core, utils
24 from caffe2.proto import caffe2_pb2
25 
26 
# Fill operators for param dtypes that must not go through the default
# "GivenTensorFill". utils.MakeArgument serialises integer and boolean
# arrays into the argument's integer field, whereas "GivenTensorFill"
# builds a float tensor from float values, so these dtypes need the
# matching typed fill op to be reconstructed correctly.
_TYPED_FILL_OPS = {
    "int32": "GivenTensorIntFill",
    "int64": "GivenTensorInt64Fill",
    "bool": "GivenTensorBoolFill",
}


def _tensor_fill_op(blob):
    """Return the name of the GivenTensor*Fill operator used to recreate blob.

    Dtypes listed in _TYPED_FILL_OPS get their typed fill op; every other
    dtype keeps using "GivenTensorFill".
    """
    return _TYPED_FILL_OPS.get(str(blob.dtype), "GivenTensorFill")


def Export(workspace, net, params):
    """Returns init_net and predict_net suitable for writing to disk
    and loading into a Predictor.

    Args:
        workspace: object exposing FetchBlob(name); used to read the
            current value of every blob named in `params`.
        net: a caffe2_pb2.NetDef, or an object whose Proto() method
            returns one, describing the network to export.
        params: iterable of parameter blobs (names or blob references;
            each is converted with str()) whose current values are baked
            into init_net.

    Returns:
        (init_net, predict_net), both caffe2_pb2.NetDef:
          * init_net has one GivenTensor*Fill op per param holding its
            current value, plus a [1, 1] zero placeholder fill for every
            other SSA-version-0 blob so that the blob exists.
          * predict_net is a copy of `net` whose external_input lists those
            non-param version-0 blobs first, followed by the original
            external inputs (each name at most once), and whose
            external_output lists blobs that are written but never read.
    """
    proto = net if isinstance(net, caffe2_pb2.NetDef) else net.Proto()
    predict_net = caffe2_pb2.NetDef()
    predict_net.CopyFrom(proto)
    init_net = caffe2_pb2.NetDef()

    # Materialise param names once: callers may pass strings or blob
    # references (even a one-shot iterator), and the set gives O(1)
    # membership tests below.
    param_names = [str(blob_ref) for blob_ref in params]
    param_set = set(param_names)

    # Populate the init_net.
    ssa, blob_versions = core.get_ssa(net)
    # Every blob name consumed by some operator; only membership matters.
    read_blobs = set()
    for versioned_inputs, _ in ssa:
        read_blobs.update(name for name, _ in versioned_inputs)

    # Non-param blobs still at SSA version 0 (presumably never produced by
    # an op in this net): these are treated as the runtime inputs.
    input_blobs = [blob_name for blob_name, version in blob_versions.items()
                   if version == 0 and blob_name not in param_set]
    # Blobs that are never used as an input to another layer,
    # i.e. strictly output blobs.
    output_blobs = [blob_name for blob_name, version in blob_versions.items()
                    if version != 0 and blob_name not in read_blobs]

    for blob_name in param_names:
        blob = workspace.FetchBlob(blob_name)
        init_net.op.extend([
            core.CreateOperator(
                _tensor_fill_op(blob), [], [blob_name],
                arg=[
                    utils.MakeArgument("shape", blob.shape),
                    utils.MakeArgument("values", blob),
                ],
            )
        ])

    # We have to make sure the blob exists in the namespace
    # and we can do so with fake data. (Which is immediately overwritten
    # by any typical usage)
    for blob_name in input_blobs:
        init_net.op.extend([
            core.CreateOperator(
                "GivenTensorFill", [], [blob_name],
                arg=[
                    utils.MakeArgument("shape", [1, 1]),
                    utils.MakeArgument("values", [0.0]),
                ],
            )
        ])

    # Now we make input/output_blobs line up with what Predictor expects:
    # runtime inputs first, then the original external inputs (which carry
    # the weights). The original list usually names the runtime inputs as
    # well, so skip repeats while keeping first-seen order.
    del predict_net.external_input[:]
    seen = set()
    for blob_name in input_blobs + list(proto.external_input):
        if blob_name not in seen:
            seen.add(blob_name)
            predict_net.external_input.append(blob_name)

    # Ensure the output is also consistent with what we want
    del predict_net.external_output[:]
    predict_net.external_output.extend(output_blobs)
    return init_net, predict_net