Caffe2 - Python API
A deep learning, cross-platform ML framework
predictor_py_utils.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package predictor_py_utils
17 # Module caffe2.python.predictor.predictor_py_utils
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, scope
24 
25 
def create_predict_net(predictor_export_meta):
    """
    Build a fresh prediction NetDef from ``predictor_export_meta``.

    A brand-new ``core.Net`` is constructed rather than reusing the exported
    proto, so that no existing settings carry over. It is named after the
    exported predict net, or "predict" when that name is empty. The new net
    receives:
      * the exported ops and net-level args,
      * ``inputs`` followed by ``parameters`` as external inputs,
      * ``outputs`` as external outputs,
      * ``net_type`` and ``num_workers``, only when they are not None.

    Returns:
        The NetDef proto of the newly built net.
    """
    source = predictor_export_meta.predict_net
    net = core.Net(source.name or "predict")
    proto = net.Proto()

    proto.op.extend(source.op)
    proto.external_input.extend(
        predictor_export_meta.inputs + predictor_export_meta.parameters)
    proto.external_output.extend(predictor_export_meta.outputs)
    proto.arg.extend(source.arg)

    # Optional execution settings are copied only when explicitly provided.
    net_type = predictor_export_meta.net_type
    if net_type is not None:
        proto.type = net_type
    num_workers = predictor_export_meta.num_workers
    if num_workers is not None:
        proto.num_workers = num_workers
    return proto
42 
43 
def create_predict_init_net(ws, predictor_export_meta):
    """
    Build an initialization NetDef that zero-fills every input and output blob.

    For each blob in ``predictor_export_meta.inputs`` followed by
    ``predictor_export_meta.outputs``, a ``ConstantFill`` op with value 0.0 is
    emitted. The fill shape is taken from ``predictor_export_meta.shapes``
    when it has an entry for the blob. Otherwise it is taken from the shape
    of the blob's current value in the workspace ``ws``. The same blobs are
    declared as external inputs, ``extra_init_net`` (if truthy) is appended,
    and the predict net's ``model_id`` argument, if any, is copied over.

    Raises:
        Exception: if a blob has no entry in ``shapes`` and is also missing
            from ``ws.blobs``.

    Returns:
        The NetDef proto of the init net.
    """
    net = core.Net("predict-init")
    external_blobs = predictor_export_meta.inputs + \
        predictor_export_meta.outputs

    for blob_name in external_blobs:
        blob_shape = predictor_export_meta.shapes.get(blob_name)
        if blob_shape is None:
            # No recorded shape: fall back to the live workspace value.
            if blob_name not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob_name, ws.blobs))
            blob_shape = ws.blobs[blob_name].fetch().shape

        # Explicitly null-out the scope so users (e.g. PredictorGPU) can
        # control, at a Net-global level, the DeviceOption of these fills.
        with scope.EmptyDeviceScope():
            net.ConstantFill([], blob_name, shape=blob_shape, value=0.0)

    net.Proto().external_input.extend(external_blobs)
    extra_init_net = predictor_export_meta.extra_init_net
    if extra_init_net:
        net.AppendNet(extra_init_net)

    # The init net inherits the model_id of the predict net.
    init_proto = net.Proto()
    AddModelIdArg(predictor_export_meta, init_proto)
    return net.Proto()
81 
82 
def get_comp_name(string, name):
    """
    Return ``string`` suffixed with an underscore and ``name``.

    When ``name`` is falsy (e.g. None or ""), ``string`` is returned
    unchanged.
    """
    return string + '_' + name if name else string
87 
88 
89 def _ProtoMapGet(field, key):
90  '''
91  Given the key, get the value of the repeated field.
92  Helper function used by protobuf since it doesn't have map construct
93  '''
94  for v in field:
95  if (v.key == key):
96  return v.value
97  return None
98 
99 
def GetPlan(meta_net_def, key):
    """Return the plan stored under ``key`` in ``meta_net_def.plans``.

    Returns None when no entry has that key. If several entries share the
    key, the first one is returned.
    """
    return _ProtoMapGet(field=meta_net_def.plans, key=key)
102 
103 
def GetPlanOriginal(meta_net_def, key):
    """Return the plan stored under ``key`` in ``meta_net_def.plans``.

    Performs the same direct lookup as ``GetPlan``: None when the key is
    absent, the first matching entry when the key is duplicated.
    """
    return _ProtoMapGet(field=meta_net_def.plans, key=key)
106 
107 
def GetBlobs(meta_net_def, key):
    """Return the blobs stored under ``key`` in ``meta_net_def.blobs``.

    When no entry has that key, an empty list is returned instead of None.
    An existing but empty entry is returned as-is.
    """
    found = _ProtoMapGet(meta_net_def.blobs, key)
    return [] if found is None else found
113 
114 
def GetNet(meta_net_def, key):
    """Return the net stored under ``key`` in ``meta_net_def.nets``.

    Returns None when no entry has that key. If several entries share the
    key, the first one is returned.
    """
    return _ProtoMapGet(field=meta_net_def.nets, key=key)
117 
118 
def GetNetOriginal(meta_net_def, key):
    """Return the net stored under ``key`` in ``meta_net_def.nets``.

    Performs the same direct lookup as ``GetNet``: None when the key is
    absent, the first matching entry when the key is duplicated.
    """
    return _ProtoMapGet(field=meta_net_def.nets, key=key)
121 
122 
def GetApplicationSpecificInfo(meta_net_def, key):
    """Return the value stored under ``key`` in
    ``meta_net_def.applicationSpecificInfo``, or None when absent.
    """
    return _ProtoMapGet(field=meta_net_def.applicationSpecificInfo, key=key)
125 
126 
def AddBlobs(meta_net_def, blob_name, blob_def):
    """
    Append every item of ``blob_def`` to the ``blobs`` entry keyed by
    ``blob_name`` in ``meta_net_def``.

    If no entry with that key exists yet, a new one is added with its
    ``key`` set to ``blob_name``. ``blob_def`` is iterated and its items are
    appended one by one to the entry's repeated ``value`` field.
    """
    target = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if target is None:
        # No entry for this name yet: create one and fill its value field.
        entry = meta_net_def.blobs.add()
        entry.key = blob_name
        target = entry.value
    for item in blob_def:
        target.append(item)
135 
136 
def AddPlan(meta_net_def, plan_name, plan_def):
    """Append a new ``plans`` entry mapping ``plan_name`` to ``plan_def``.

    Existing entries with the same key are not replaced. Because lookups
    return the first matching entry, a later duplicate is shadowed.
    """
    meta_net_def.plans.add(
        key=plan_name,
        value=plan_def,
    )
139 
140 
def AddNet(meta_net_def, net_name, net_def):
    """Append a new ``nets`` entry mapping ``net_name`` to ``net_def``.

    Existing entries with the same key are not replaced. Because lookups
    return the first matching entry, a later duplicate is shadowed.
    """
    meta_net_def.nets.add(
        key=net_name,
        value=net_def,
    )
143 
144 
def GetArgumentByName(net_def, arg_name):
    """Return the first argument in ``net_def.arg`` whose ``name`` equals
    ``arg_name``, or None when there is no such argument.
    """
    matches = (arg for arg in net_def.arg if arg.name == arg_name)
    return next(matches, None)
150 
151 
def AddModelIdArg(meta_net_def, net_def):
    """Copy the ``model_id`` argument of ``meta_net_def.predict_net`` onto
    ``net_def``.

    Does nothing when the predict net has no ``model_id`` argument. If
    ``net_def`` already has a ``model_id`` argument, its ``i`` field is
    overwritten; otherwise a new integer ``model_id`` argument is appended.
    This is intended for init nets, whose model_id is not populated by
    default but should match that of the predict net.
    """
    source_arg = GetArgumentByName(meta_net_def.predict_net, "model_id")
    if source_arg is None:
        return
    # The model_id is assumed to live in the integer field ``i``.
    model_id = source_arg.i

    target_arg = GetArgumentByName(net_def, "model_id")
    if target_arg is None:
        # No model_id on the target yet: append a fresh argument.
        target_arg = net_def.arg.add()
        target_arg.name = "model_id"
    target_arg.i = model_id