3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
def create_predict_net(predictor_export_meta):
    """
    Return the input prediction net.

    Builds a fresh ``core.Net`` from ``predictor_export_meta.predict_net``:
    its ops and args are copied; the external inputs are the exported
    inputs followed by the parameters; the external outputs are the
    exported outputs. ``net_type`` and ``num_workers`` are copied onto the
    proto only when they are not None.

    Args:
        predictor_export_meta: export metadata providing ``predict_net``,
            ``inputs``, ``parameters``, ``outputs``, ``net_type`` and
            ``num_workers``.

    Returns:
        The ``NetDef`` proto of the newly built net.
    """
    # Fall back to the name "predict" when the source net is unnamed.
    net = core.Net(predictor_export_meta.predict_net.name or "predict")
    proto = net.Proto()
    proto.op.extend(predictor_export_meta.predict_net.op)
    proto.external_input.extend(
        predictor_export_meta.inputs + predictor_export_meta.parameters)
    proto.external_output.extend(predictor_export_meta.outputs)
    proto.arg.extend(predictor_export_meta.predict_net.arg)
    if predictor_export_meta.net_type is not None:
        proto.type = predictor_export_meta.net_type
    if predictor_export_meta.num_workers is not None:
        proto.num_workers = predictor_export_meta.num_workers
    # NOTE(review): returns the proto (not the core.Net wrapper), matching
    # how the other helpers in this module treat ``net_def`` values; confirm
    # against callers.
    return proto
def create_predict_init_net(ws, predictor_export_meta):
    """
    Return an initialization net that zero-fills all the input and
    output blobs, using the shapes from the provided workspace. This is
    necessary as there is no shape inference functionality in Caffe2.

    For each blob in ``inputs + outputs``, the shape comes from
    ``predictor_export_meta.shapes`` when present. Otherwise it comes from
    the blob currently held in ``ws``. Those blobs are also declared as
    external inputs of the init net. If
    ``predictor_export_meta.extra_init_net`` is set, it is appended. The
    predict net's ``model_id`` (if any) is then copied onto the result via
    ``AddModelIdArg``.

    Args:
        ws: workspace whose ``blobs`` mapping supplies fallback shapes.
        predictor_export_meta: export metadata providing ``inputs``,
            ``outputs``, ``shapes`` and ``extra_init_net``.

    Returns:
        The ``NetDef`` proto of the init net.

    Raises:
        Exception: if a blob has no exported shape and is not in ``ws``.
    """
    net = core.Net("predict-init")

    def zero_fill(blob):
        # An explicitly exported shape wins over the workspace's blob.
        shape = predictor_export_meta.shapes.get(blob)
        if shape is None:
            if blob not in ws.blobs:
                raise Exception(
                    "{} not in workspace but needed for shape: {}".format(
                        blob, ws.blobs))
            shape = ws.blobs[blob].fetch().shape

        # Emit the fill op outside any ambient device scope.
        # NOTE(review): presumably so the device for these fills can be
        # chosen at the net level by the consumer; confirm.
        with scope.EmptyDeviceScope():
            net.ConstantFill([], blob, shape=shape, value=0.0)

    external_blobs = predictor_export_meta.inputs + \
        predictor_export_meta.outputs
    for blob in external_blobs:
        zero_fill(blob)

    net.Proto().external_input.extend(external_blobs)
    if predictor_export_meta.extra_init_net:
        net.AppendNet(predictor_export_meta.extra_init_net)

    # Keep the init net's model_id in sync with the predict net's.
    AddModelIdArg(predictor_export_meta, net.Proto())

    return net.Proto()
def get_comp_name(string, name):
    """Return ``string`` suffixed with ``'_' + name``.

    When ``name`` is empty or None, ``string`` is returned unchanged. This
    avoids a dangling ``'_'`` suffix, or a TypeError when ``name`` is None.
    """
    if not name:
        return string
    return string + '_' + name
def _ProtoMapGet(field, key):
    """
    Given the key, get the value of the repeated field.
    Helper function used by protobuf since it doesn't have map construct.

    ``field`` is an iterable of entries exposing ``.key`` and ``.value``
    (a repeated key/value message standing in for a map). Returns the
    ``.value`` of the first entry whose ``.key`` equals ``key``, or None
    when no entry matches.
    """
    for entry in field:
        if entry.key == key:
            return entry.value
    return None


def GetPlan(meta_net_def, key):
    """Return the plan stored under ``key`` in ``meta_net_def.plans``, or
    None when there is no such entry."""
    return _ProtoMapGet(meta_net_def.plans, key)
def GetPlanOriginal(meta_net_def, key):
    """Look up the plan registered under ``key`` in ``meta_net_def``.

    Performs the same lookup as ``GetPlan``: the ``plans`` key/value
    entries are searched via ``_ProtoMapGet``, and its result is returned.
    """
    plan_entries = meta_net_def.plans
    return _ProtoMapGet(plan_entries, key)
def GetBlobs(meta_net_def, key):
    """Return the blobs registered under ``key`` in ``meta_net_def.blobs``.

    When the lookup yields None (no entry for ``key``), an empty list is
    returned. Callers can therefore iterate the result without a None check.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, key)
    if blobs is None:
        return []
    return blobs
def GetBlobsByTypePrefix(meta_net_def, blob_type_prefix):
    """Collect the blobs of every ``meta_net_def.blobs`` entry whose key
    starts with ``blob_type_prefix``.

    Duplicates are dropped. The result keeps first-seen order: entries are
    visited in their stored order, and blobs within each entry likewise.

    Returns:
        A list of unique blobs; empty when no key matches.
    """
    seen = set()
    ordered = []
    for entry in meta_net_def.blobs:
        if not entry.key.startswith(blob_type_prefix):
            continue
        for blob in entry.value:
            if blob not in seen:
                seen.add(blob)
                ordered.append(blob)
    return ordered
def GetNet(meta_net_def, key):
    """Look up the net registered under ``key`` in ``meta_net_def``.

    The ``nets`` key/value entries are searched via ``_ProtoMapGet``, and
    its result is returned as-is.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
def GetNetOriginal(meta_net_def, key):
    """Look up the net registered under ``key`` in ``meta_net_def``.

    Performs the same lookup as ``GetNet``: the ``nets`` key/value entries
    are searched via ``_ProtoMapGet``, and its result is returned.
    """
    net_entries = meta_net_def.nets
    return _ProtoMapGet(net_entries, key)
def GetApplicationSpecificInfo(meta_net_def, key):
    """Return the application-specific info value stored under ``key``.

    The ``applicationSpecificInfo`` key/value entries of ``meta_net_def``
    are searched via ``_ProtoMapGet``, and its result is returned as-is.
    """
    info_entries = meta_net_def.applicationSpecificInfo
    return _ProtoMapGet(info_entries, key)
def AddBlobs(meta_net_def, blob_name, blob_def):
    """Append the blobs in ``blob_def`` to the ``blob_name`` entry of
    ``meta_net_def.blobs``, creating that entry first if it does not exist.

    Args:
        meta_net_def: message with a repeated key/value ``blobs`` field.
        blob_name: key of the entry to extend.
        blob_def: iterable of blobs to append, in order.
    """
    blobs = _ProtoMapGet(meta_net_def.blobs, blob_name)
    if blobs is None:
        # No entry yet: add one keyed by blob_name and fill its value list.
        entry = meta_net_def.blobs.add()
        entry.key = blob_name
        blobs = entry.value
    for blob in blob_def:
        blobs.append(blob)
def AddPlan(meta_net_def, plan_name, plan_def):
    """Append a ``(plan_name, plan_def)`` key/value entry to
    ``meta_net_def.plans``.

    No check is made for an existing entry with the same key.
    """
    entry_fields = dict(key=plan_name, value=plan_def)
    meta_net_def.plans.add(**entry_fields)
def AddNet(meta_net_def, net_name, net_def):
    """Append a ``(net_name, net_def)`` key/value entry to
    ``meta_net_def.nets``.

    No check is made for an existing entry with the same key.
    """
    nets = meta_net_def.nets
    nets.add(value=net_def, key=net_name)
def GetArgumentByName(net_def, arg_name):
    """Return the first argument in ``net_def.arg`` whose ``name`` equals
    ``arg_name``, or None when there is no such argument.

    The matching element itself is returned (not a copy), so callers can
    read or update its fields in place.
    """
    for arg in net_def.arg:
        if arg.name == arg_name:
            return arg
    return None
def AddModelIdArg(meta_net_def, net_def):
    """Takes the model_id from the predict_net of meta_net_def (if it is
    populated) and adds it to the net_def passed in. This is intended to be
    called on init_nets, as their model_id is not populated by default, but
    should be the same as that of the predict_net.

    If ``net_def`` already has a ``model_id`` argument, its integer value is
    overwritten rather than adding a duplicate argument. If the predict_net
    has no ``model_id`` argument, ``net_def`` is left untouched.
    """
    model_id_arg = GetArgumentByName(meta_net_def.predict_net, "model_id")
    if model_id_arg is None:
        return
    # The model id is carried in the argument's integer field ``i``.
    model_id = model_id_arg.i

    old_id = GetArgumentByName(net_def, "model_id")
    if old_id is not None:
        old_id.i = model_id
        return

    arg = net_def.arg.add()
    arg.name = "model_id"
    arg.i = model_id