3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.proto
import caffe2_pb2
9 from caffe2.proto
import metanet_pb2
14 from builtins
import bytes
18 def get_predictor_exporter_helper(submodelNetName):
19 """ constracting stub for the PredictorExportMeta 20 Only used to construct names to subfields, 21 such as calling to predict_net_name 23 submodelNetName - name of the model 25 stub_net = core.Net(submodelNetName)
26 pred_meta = PredictorExportMeta(predict_net=stub_net,
36 class PredictorExportMeta(collections.namedtuple(
37 'PredictorExportMeta',
38 'predict_net, parameters, inputs, outputs, shapes, name, \ 39 extra_init_net, net_type, num_workers, trainer_prefix')):
41 Metadata to be used for serializaing a net. 43 parameters, inputs, outputs could be either BlobReference or blob's names 45 predict_net can be either core.Net, NetDef, PlanDef or object 47 Override the named tuple to provide optional name parameter. 48 name will be used to identify multiple prediction nets. 50 net_type is the type field in caffe2 NetDef - can be 'simple', 'dag', etc. 52 num_workers specifies for net type 'dag' how many threads should run ops 54 trainer_prefix specifies the type of trainer. 69 inputs = [str(i)
for i
in inputs]
70 outputs = [str(o)
for o
in outputs]
71 assert len(set(inputs)) == len(inputs), (
72 "All inputs to the predictor should be unique")
73 parameters = [str(p)
for p
in parameters]
74 assert set(parameters).isdisjoint(inputs), (
75 "Parameters and inputs are required to be disjoint. " 76 "Intersection: {}".format(set(parameters).intersection(inputs)))
77 assert set(parameters).isdisjoint(outputs), (
78 "Parameters and outputs are required to be disjoint. " 79 "Intersection: {}".format(set(parameters).intersection(outputs)))
83 predict_net = predict_net.Proto()
85 assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
86 return super(PredictorExportMeta, cls).__new__(
87 cls, predict_net, parameters, inputs, outputs, shapes, name,
88 extra_init_net, net_type, num_workers, trainer_prefix)
90 def inputs_name(self):
91 return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
94 def outputs_name(self):
95 return utils.get_comp_name(predictor_constants.OUTPUTS_BLOB_TYPE,
98 def parameters_name(self):
99 return utils.get_comp_name(predictor_constants.PARAMETERS_BLOB_TYPE,
102 def global_init_name(self):
103 return utils.get_comp_name(predictor_constants.GLOBAL_INIT_NET_TYPE,
106 def predict_init_name(self):
107 return utils.get_comp_name(predictor_constants.PREDICT_INIT_NET_TYPE,
110 def predict_net_name(self):
111 return utils.get_comp_name(predictor_constants.PREDICT_NET_TYPE,
114 def train_init_plan_name(self):
115 plan_name = utils.get_comp_name(predictor_constants.TRAIN_INIT_PLAN_TYPE,
117 return self.trainer_prefix +
'_' + plan_name \
118 if self.trainer_prefix
else plan_name
120 def train_plan_name(self):
121 plan_name = utils.get_comp_name(predictor_constants.TRAIN_PLAN_TYPE,
123 return self.trainer_prefix +
'_' + plan_name \
124 if self.trainer_prefix
else plan_name
127 def prepare_prediction_net(filename, db_type, device_option=None):
129 Helper function which loads all required blobs from the db 130 and returns prediction net ready to be used 132 metanet_def = load_from_db(filename, db_type, device_option)
134 global_init_net = utils.GetNet(
135 metanet_def, predictor_constants.GLOBAL_INIT_NET_TYPE)
136 workspace.RunNetOnce(global_init_net)
138 predict_init_net = utils.GetNet(
139 metanet_def, predictor_constants.PREDICT_INIT_NET_TYPE)
140 workspace.RunNetOnce(predict_init_net)
143 utils.GetNet(metanet_def, predictor_constants.PREDICT_NET_TYPE))
144 workspace.CreateNet(predict_net)
149 def _global_init_net(predictor_export_meta):
152 [predictor_constants.PREDICTOR_DBREADER],
153 predictor_export_meta.parameters)
154 net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER])
155 net.Proto().external_output.extend(predictor_export_meta.parameters)
158 utils.AddModelIdArg(predictor_export_meta, net.Proto())
162 def get_meta_net_def(predictor_export_meta, ws=None):
166 ws = ws
or workspace.C.Workspace.current
167 meta_net_def = metanet_pb2.MetaNetDef()
170 utils.AddNet(meta_net_def, predictor_export_meta.predict_init_name(),
171 utils.create_predict_init_net(ws, predictor_export_meta))
172 utils.AddNet(meta_net_def, predictor_export_meta.global_init_name(),
173 _global_init_net(predictor_export_meta))
174 utils.AddNet(meta_net_def, predictor_export_meta.predict_net_name(),
175 utils.create_predict_net(predictor_export_meta))
176 utils.AddBlobs(meta_net_def, predictor_export_meta.parameters_name(),
177 predictor_export_meta.parameters)
178 utils.AddBlobs(meta_net_def, predictor_export_meta.inputs_name(),
179 predictor_export_meta.inputs)
180 utils.AddBlobs(meta_net_def, predictor_export_meta.outputs_name(),
181 predictor_export_meta.outputs)
def set_model_info(meta_net_def, project_str, model_class_str, version):
    """Record project, model class and version on a MetaNetDef's ``modelInfo``.

    The message is updated in place; nothing is returned.

    Args:
        meta_net_def: ``metanet_pb2.MetaNetDef`` instance to update.
        project_str: value stored in ``modelInfo.project``.
        model_class_str: value stored in ``modelInfo.modelClass``.
        version: value stored in ``modelInfo.version``.

    Raises:
        AssertionError: if ``meta_net_def`` is not a ``MetaNetDef``.
    """
    assert isinstance(meta_net_def, metanet_pb2.MetaNetDef)
    # Grab the nested message once; writes through it land on meta_net_def.
    model_info = meta_net_def.modelInfo
    model_info.project = project_str
    model_info.modelClass = model_class_str
    model_info.version = version
192 def save_to_db(db_type, db_destination, predictor_export_meta, use_ideep = False):
193 meta_net_def = get_meta_net_def(predictor_export_meta)
194 device_type = caffe2_pb2.IDEEP
if use_ideep
else caffe2_pb2.CPU
195 with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
197 predictor_constants.META_NET_DEF,
198 serde.serialize_protobuf_struct(meta_net_def)
201 blobs_to_save = [predictor_constants.META_NET_DEF] + \
202 predictor_export_meta.parameters
203 op = core.CreateOperator(
206 device_option = core.DeviceOption(device_type),
208 db=db_destination, db_type=db_type)
210 workspace.RunOperatorOnce(op)
213 def load_from_db(filename, db_type, device_option=None):
216 create_db = core.CreateOperator(
219 db=filename, db_type=db_type)
220 assert workspace.RunOperatorOnce(create_db), (
221 'Failed to create db {}'.format(filename))
224 load_meta_net_def = core.CreateOperator(
228 assert workspace.RunOperatorOnce(load_meta_net_def)
230 blob = workspace.FetchBlob(predictor_constants.META_NET_DEF)
231 meta_net_def = serde.deserialize_protobuf_struct(
232 blob
if isinstance(blob, bytes)
233 else str(blob).encode(
'utf-8'),
234 metanet_pb2.MetaNetDef)
236 if device_option
is None:
237 device_option = scope.CurrentDeviceScope()
239 if device_option
is not None:
241 for kv
in meta_net_def.nets:
244 op.device_option.CopyFrom(device_option)