Caffe2 - Python API
A deep learning, cross platform ML framework
predictor_py_utils.py
1 ## @package predictor_py_utils
2 # Module caffe2.python.predictor.predictor_py_utils
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, scope
9 
10 
def create_predict_net(predictor_export_meta):
    """
    Build a fresh prediction NetDef from the export metadata.

    A brand-new ``core.Net`` is constructed (named after the exported predict
    net, or "predict" when that name is empty) so that no settings from the
    original net carry over. Only the ops, args, external inputs (inputs
    followed by parameters), external outputs and, when set, the net type
    and worker count are copied in.

    Returns:
        The resulting NetDef proto.
    """
    meta = predictor_export_meta
    source_net = meta.predict_net
    net = core.Net(source_net.name or "predict")
    proto = net.Proto()

    proto.op.extend(source_net.op)
    # Parameters are exposed as external inputs alongside the real inputs.
    proto.external_input.extend(meta.inputs + meta.parameters)
    proto.external_output.extend(meta.outputs)
    proto.arg.extend(source_net.arg)

    # Only override the proto defaults when the metadata specifies a value.
    if meta.net_type is not None:
        proto.type = meta.net_type
    if meta.num_workers is not None:
        proto.num_workers = meta.num_workers
    return proto
27 
28 
def create_predict_init_net(ws, predictor_export_meta):
    """
    Build an init NetDef that zero-fills every input and output blob.

    The fill shape for each blob comes from ``predictor_export_meta.shapes``
    when an entry exists. Otherwise it is taken from the blob currently held
    in the workspace ``ws``. The original rationale is that Caffe2 has no
    shape inference for this. Any ``extra_init_net`` on the metadata is
    appended afterwards. The predict net's ``model_id`` argument, if it has
    one, is copied onto the result.

    Raises:
        Exception: if a blob has no recorded shape and is not in ``ws``.

    Returns:
        The resulting NetDef proto.
    """
    net = core.Net("predict-init")
    meta = predictor_export_meta

    def _shape_of(blob_name):
        # Prefer the exported shape; fall back to the live workspace blob.
        known = meta.shapes.get(blob_name)
        if known is not None:
            return known
        if blob_name not in ws.blobs:
            raise Exception(
                "{} not in workspace but needed for shape: {}".format(
                    blob_name, ws.blobs))
        return ws.blobs[blob_name].fetch().shape

    external_blobs = meta.inputs + meta.outputs
    for blob_name in external_blobs:
        shape = _shape_of(blob_name)
        # Null out the device scope so users (e.g. PredictorGPU) can pick the
        # DeviceOption of these fill ops at the Net-global level.
        with scope.EmptyDeviceScope():
            net.ConstantFill([], blob_name, shape=shape, value=0.0)

    net.Proto().external_input.extend(external_blobs)
    if meta.extra_init_net:
        net.AppendNet(meta.extra_init_net)

    # Init nets have no model_id by default; mirror the predict net's.
    AddModelIdArg(meta, net.Proto())

    return net.Proto()
66 
67 
def get_comp_name(string, name):
    """Return ``string + '_' + name``, or ``string`` unchanged if ``name`` is falsy."""
    return string + '_' + name if name else string
72 
73 
74 def _ProtoMapGet(field, key):
75  '''
76  Given the key, get the value of the repeated field.
77  Helper function used by protobuf since it doesn't have map construct
78  '''
79  for v in field:
80  if (v.key == key):
81  return v.value
82  return None
83 
84 
def GetPlan(meta_net_def, key):
    """Return the plan stored under ``key`` in ``meta_net_def.plans``, or None."""
    plans = meta_net_def.plans
    return _ProtoMapGet(plans, key)
87 
88 
def GetPlanOriginal(meta_net_def, key):
    """Return the plan stored under ``key`` in ``meta_net_def.plans``, or None.

    Performs the same direct lookup as ``GetPlan``.
    """
    plans = meta_net_def.plans
    return _ProtoMapGet(plans, key)
91 
92 
def GetBlobs(meta_net_def, key):
    """Return the blob names stored under ``key`` in ``meta_net_def.blobs``.

    Returns an empty list when there is no entry for ``key``.
    """
    found = _ProtoMapGet(meta_net_def.blobs, key)
    return [] if found is None else found
98 
99 
def GetBlobsByTypePrefix(meta_net_def, blob_type_prefix):
    """
    Collect the blobs of every ``meta_net_def.blobs`` entry whose key starts
    with ``blob_type_prefix``.

    Blobs are de-duplicated and kept in first-seen order. That order is entry
    order first, then order within each entry's value list.

    Args:
        meta_net_def: object whose ``blobs`` is an iterable of entries with
            ``key`` (str) and ``value`` (iterable of blob names).
        blob_type_prefix: prefix (or tuple of prefixes, as accepted by
            ``str.startswith``) that an entry key must start with.

    Returns:
        list: the unique blob names in first-seen order.
    """
    # A seen-set plus a list preserves insertion order in a single linear
    # pass. It does not rely on dict ordering, which isn't guaranteed on
    # the Python 2 interpreters this module still supports.
    seen = set()
    ordered = []
    for entry in meta_net_def.blobs:
        if not entry.key.startswith(blob_type_prefix):
            continue
        for blob in entry.value:
            if blob not in seen:
                seen.add(blob)
                ordered.append(blob)
    return ordered
108 
109 
def GetNet(meta_net_def, key):
    """Return the net stored under ``key`` in ``meta_net_def.nets``, or None."""
    nets = meta_net_def.nets
    return _ProtoMapGet(nets, key)
112 
113 
def GetNetOriginal(meta_net_def, key):
    """Return the net stored under ``key`` in ``meta_net_def.nets``, or None.

    Performs the same direct lookup as ``GetNet``.
    """
    nets = meta_net_def.nets
    return _ProtoMapGet(nets, key)
116 
117 
def GetApplicationSpecificInfo(meta_net_def, key):
    """Return the ``applicationSpecificInfo`` value for ``key``, or None."""
    info = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info, key)
120 
121 
def AddBlobs(meta_net_def, blob_name, blob_def):
    """
    Append every name in ``blob_def`` to the ``blob_name`` entry of
    ``meta_net_def.blobs``.

    The entry is created (with ``key = blob_name``) when it does not exist.
    Otherwise the names are appended to the existing entry's value list.
    """
    target = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if target is None:
        entry = meta_net_def.blobs.add()
        entry.key = blob_name
        target = entry.value
    for name in blob_def:
        target.append(name)
130 
131 
def AddPlan(meta_net_def, plan_name, plan_def):
    """Append a new ``(plan_name, plan_def)`` entry to ``meta_net_def.plans``.

    No de-duplication is done. An existing entry with the same key stays in
    place, and a first-match lookup will still return the earlier entry.
    """
    plans = meta_net_def.plans
    plans.add(key=plan_name, value=plan_def)
134 
135 
def AddNet(meta_net_def, net_name, net_def):
    """Append a new ``(net_name, net_def)`` entry to ``meta_net_def.nets``.

    No de-duplication is done. An existing entry with the same key stays in
    place, and a first-match lookup will still return the earlier entry.
    """
    nets = meta_net_def.nets
    nets.add(key=net_name, value=net_def)
138 
139 
def GetArgumentByName(net_def, arg_name):
    """Return the first argument in ``net_def.arg`` named ``arg_name``, or None."""
    matches = (arg for arg in net_def.arg if arg.name == arg_name)
    return next(matches, None)
145 
146 
def AddModelIdArg(meta_net_def, net_def):
    """Copy the predict net's integer ``model_id`` argument onto ``net_def``.

    Intended for init nets. Their model_id is not populated by default, but
    it should match the predict net's. Nothing happens when
    ``meta_net_def.predict_net`` has no ``model_id`` argument. If
    ``net_def`` already has one, its ``i`` value is overwritten. Otherwise a
    new ``model_id`` argument is appended.

    NOTE(review): only the integer field ``i`` is read and written. A
    model_id stored in a different field (e.g. as a string) would
    presumably not be carried over correctly.
    """
    source_arg = GetArgumentByName(meta_net_def.predict_net, "model_id")
    if source_arg is None:
        return
    model_id = source_arg.i

    # Reuse an existing model_id argument if present; otherwise create one.
    target_arg = GetArgumentByName(net_def, "model_id")
    if target_arg is None:
        target_arg = net_def.arg.add()
        target_arg.name = "model_id"
    target_arg.i = model_id