Caffe2 - Python API
A deep learning, cross platform ML framework
predictor_exporter.py
1 ## @package predictor_exporter
2 # Module caffe2.python.predictor.predictor_exporter
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
from builtins import bytes

from caffe2.proto import caffe2_pb2
from caffe2.proto import metanet_pb2
from caffe2.python import workspace, core, scope
from caffe2.python.predictor_constants import predictor_constants
import caffe2.python.predictor.predictor_py_utils as utils
import caffe2.python.predictor.serde as serde
16 
17 
def get_predictor_exporter_helper(submodelNetName):
    """Build a stub PredictorExportMeta for the submodel ``submodelNetName``.

    The stub wraps a freshly created core.Net (no operators are added) and
    has no parameters, inputs or outputs. It is only meant for deriving the
    names of sub-components, e.g. through ``predict_net_name()``.

    Args:
        submodelNetName: name given both to the stub net and to the
            ``name`` field of the returned PredictorExportMeta.

    Returns:
        PredictorExportMeta
    """
    return PredictorExportMeta(
        predict_net=core.Net(submodelNetName),
        parameters=[],
        inputs=[],
        outputs=[],
        shapes=None,
        name=submodelNetName,
        extra_init_net=None,
    )
34 
35 
class PredictorExportMeta(collections.namedtuple(
        'PredictorExportMeta',
        ['predict_net', 'parameters', 'inputs', 'outputs', 'shapes', 'name',
         'extra_init_net', 'net_type', 'num_workers', 'trainer_prefix'])):
    """
    Metadata describing how a net is serialized for prediction.

    Fields:
      predict_net: the prediction net, stored as a caffe2_pb2.NetDef or
        caffe2_pb2.PlanDef. The constructor also accepts core.Net and
        core.Plan and stores their Proto(); any other type fails an
        assertion.
      parameters, inputs, outputs: blob names. BlobReferences or strings
        are accepted; each is stored as a list of str. Inputs must be
        unique, and parameters must be disjoint from both inputs and
        outputs (checked with asserts).
      shapes: stored as given, or {} when falsy.
      name: optional (default ""); identifies one of several prediction
        nets and is folded into every name returned by the *_name methods.
      extra_init_net: stored as given.
      net_type: the type field in a caffe2 NetDef, e.g. 'simple' or 'dag'.
      num_workers: for net type 'dag', how many threads should run ops.
      trainer_prefix: the type of trainer; when set, it is prepended with an
        '_' separator to the train plan names.
    """
    def __new__(
        cls,
        predict_net,
        parameters,
        inputs,
        outputs,
        shapes=None,
        name="",
        extra_init_net=None,
        net_type=None,
        num_workers=None,
        trainer_prefix=None,
    ):
        input_names = [str(blob) for blob in inputs]
        output_names = [str(blob) for blob in outputs]
        assert len(set(input_names)) == len(input_names), (
            "All inputs to the predictor should be unique")

        param_names = [str(blob) for blob in parameters]
        param_set = set(param_names)
        assert param_set.isdisjoint(input_names), (
            "Parameters and inputs are required to be disjoint. "
            "Intersection: {}".format(param_set.intersection(input_names)))
        assert param_set.isdisjoint(output_names), (
            "Parameters and outputs are required to be disjoint. "
            "Intersection: {}".format(param_set.intersection(output_names)))

        # Python-level net/plan wrappers are unwrapped to their protobufs.
        if isinstance(predict_net, (core.Net, core.Plan)):
            predict_net = predict_net.Proto()
        assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))

        return super(PredictorExportMeta, cls).__new__(
            cls, predict_net, param_names, input_names, output_names,
            shapes or {}, name, extra_init_net, net_type, num_workers,
            trainer_prefix)

    def _component_name(self, component_type):
        # Name of a sub-component of type `component_type` for this predictor.
        return utils.get_comp_name(component_type, self.name)

    def _prefixed_plan_name(self, plan_type):
        # Plan name, prefixed with "<trainer_prefix>_" when a prefix is set.
        plan_name = self._component_name(plan_type)
        if self.trainer_prefix:
            return self.trainer_prefix + '_' + plan_name
        return plan_name

    def inputs_name(self):
        return self._component_name(predictor_constants.INPUTS_BLOB_TYPE)

    def outputs_name(self):
        return self._component_name(predictor_constants.OUTPUTS_BLOB_TYPE)

    def parameters_name(self):
        return self._component_name(predictor_constants.PARAMETERS_BLOB_TYPE)

    def global_init_name(self):
        return self._component_name(predictor_constants.GLOBAL_INIT_NET_TYPE)

    def predict_init_name(self):
        return self._component_name(predictor_constants.PREDICT_INIT_NET_TYPE)

    def predict_net_name(self):
        return self._component_name(predictor_constants.PREDICT_NET_TYPE)

    def train_init_plan_name(self):
        return self._prefixed_plan_name(
            predictor_constants.TRAIN_INIT_PLAN_TYPE)

    def train_plan_name(self):
        return self._prefixed_plan_name(predictor_constants.TRAIN_PLAN_TYPE)
125 
126 
def prepare_prediction_net(filename, db_type, device_option=None):
    '''
    Load a predictor from a DB and return its predict net, ready to run.

    Reads the MetaNetDef with load_from_db, then runs the stored global
    init net and the predict init net once each (in that order) via
    workspace.RunNetOnce, and finally creates the stored predict net with
    workspace.CreateNet.

    Args:
        filename: DB path, forwarded to load_from_db.
        db_type: DB backend name, forwarded to load_from_db.
        device_option: optional DeviceOption, forwarded to load_from_db.

    Returns:
        core.Net wrapping the stored predict net.
    '''
    meta_net_def = load_from_db(filename, db_type, device_option)

    # Init nets must run before the predict net is created.
    init_net_types = (
        predictor_constants.GLOBAL_INIT_NET_TYPE,
        predictor_constants.PREDICT_INIT_NET_TYPE,
    )
    for net_type in init_net_types:
        workspace.RunNetOnce(utils.GetNet(meta_net_def, net_type))

    predict_net_def = utils.GetNet(
        meta_net_def, predictor_constants.PREDICT_NET_TYPE)
    predict_net = core.Net(predict_net_def)
    workspace.CreateNet(predict_net)
    return predict_net
147 
148 
def _global_init_net(predictor_export_meta):
    """Build the NetDef of the "global-init" net for a predictor.

    The net holds a single Load operator reading every parameter blob of
    ``predictor_export_meta`` from the DB reader blob named
    predictor_constants.PREDICTOR_DBREADER. The reader is declared as the
    net's external input and the parameters as its external outputs.

    Returns:
        caffe2_pb2.NetDef of the built net.
    """
    db_reader = predictor_constants.PREDICTOR_DBREADER
    param_names = predictor_export_meta.parameters

    init_net = core.Net("global-init")
    init_net.Load([db_reader], param_names)

    proto = init_net.Proto()
    proto.external_input.extend([db_reader])
    proto.external_output.extend(param_names)

    # Add the model_id in the predict_net to the global_init_net
    utils.AddModelIdArg(predictor_export_meta, proto)
    return proto
160 
161 
def get_meta_net_def(predictor_export_meta, ws=None):
    """Assemble a MetaNetDef describing the predictor in ``predictor_export_meta``.

    Three nets are added, in this order, under names derived from
    ``predictor_export_meta``:

    * the predict-init net built by utils.create_predict_init_net from ``ws``,
    * the global-init net built by _global_init_net,
    * the predict net built by utils.create_predict_net.

    The parameter, input and output blob-name lists are then added, in that
    order, as blob entries.

    Args:
        predictor_export_meta: PredictorExportMeta describing the predictor.
        ws: workspace handed to utils.create_predict_init_net; when falsy,
            workspace.C.Workspace.current is used.

    Returns:
        metanet_pb2.MetaNetDef
    """
    ws = ws or workspace.C.Workspace.current
    meta = predictor_export_meta
    meta_net_def = metanet_pb2.MetaNetDef()

    # Each net is built lazily so name lookup, net construction and AddNet
    # happen strictly one net after another.
    net_builders = (
        (meta.predict_init_name,
         lambda: utils.create_predict_init_net(ws, meta)),
        (meta.global_init_name, lambda: _global_init_net(meta)),
        (meta.predict_net_name, lambda: utils.create_predict_net(meta)),
    )
    for get_name, build_net in net_builders:
        utils.AddNet(meta_net_def, get_name(), build_net())

    blob_groups = (
        (meta.parameters_name, meta.parameters),
        (meta.inputs_name, meta.inputs),
        (meta.outputs_name, meta.outputs),
    )
    for get_name, blob_names in blob_groups:
        utils.AddBlobs(meta_net_def, get_name(), blob_names)

    return meta_net_def
183 
184 
def set_model_info(meta_net_def, project_str, model_class_str, version):
    """Fill the ``modelInfo`` sub-message of ``meta_net_def`` in place.

    Args:
        meta_net_def: metanet_pb2.MetaNetDef to update (asserted).
        project_str: value for modelInfo.project.
        model_class_str: value for modelInfo.modelClass.
        version: value for modelInfo.version.
    """
    assert isinstance(meta_net_def, metanet_pb2.MetaNetDef)
    model_info = meta_net_def.modelInfo
    model_info.project = project_str
    model_info.modelClass = model_class_str
    model_info.version = version
190 
191 
def save_to_db(db_type, db_destination, predictor_export_meta, use_ideep=False):
    """Serialize a predictor (its MetaNetDef plus parameter blobs) into a DB.

    The MetaNetDef built by get_meta_net_def is serialized and fed into the
    workspace as the blob predictor_constants.META_NET_DEF. A Save operator
    then writes that blob followed by every parameter blob.

    Args:
        db_type: DB backend name, passed to Save as ``db_type``.
        db_destination: DB path, passed to Save as ``db`` with
            ``absolute_path=True``.
        predictor_export_meta: PredictorExportMeta describing the predictor.
        use_ideep: when True the Save operator gets an IDEEP device option;
            otherwise it gets a CPU one.

    Raises:
        RuntimeError: if workspace.RunOperatorOnce reports that the Save
            operator failed.
    """
    meta_net_def = get_meta_net_def(predictor_export_meta)
    device_type = caffe2_pb2.IDEEP if use_ideep else caffe2_pb2.CPU

    # The serialized MetaNetDef is always fed under a CPU device scope,
    # independent of the device chosen for the Save operator.
    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
        workspace.FeedBlob(
            predictor_constants.META_NET_DEF,
            serde.serialize_protobuf_struct(meta_net_def)
        )

    # META_NET_DEF is listed first, ahead of the parameter blobs.
    blobs_to_save = ([predictor_constants.META_NET_DEF] +
                     predictor_export_meta.parameters)
    op = core.CreateOperator(
        "Save",
        blobs_to_save, [],
        device_option=core.DeviceOption(device_type),
        absolute_path=True,
        db=db_destination, db_type=db_type)

    # RunOperatorOnce returns a success flag (load_from_db checks it too);
    # a failed save must not be silently ignored.
    if not workspace.RunOperatorOnce(op):
        raise RuntimeError(
            'Failed to save predictor to db {}'.format(db_destination))
211 
212 
def load_from_db(filename, db_type, device_option=None):
    """Open a predictor DB and return its deserialized MetaNetDef.

    Runs a CreateDB operator that produces the reader blob
    predictor_constants.PREDICTOR_DBREADER (the saved global_init_net loads
    parameters from it). A Load operator then reads the
    predictor_constants.META_NET_DEF blob through that reader, and the blob
    is parsed into a MetaNetDef.

    Args:
        filename: DB path, passed to CreateDB as ``db``.
        db_type: DB backend name, passed to CreateDB as ``db_type``.
        device_option: optional DeviceOption copied onto every operator of
            every net in the returned MetaNetDef. When None, the current
            device scope (scope.CurrentDeviceScope()) is used instead; if
            that is None too, the nets are left unchanged.

    Returns:
        metanet_pb2.MetaNetDef

    Raises:
        AssertionError: if CreateDB or Load report failure.
    """
    db_reader = core.BlobReference(predictor_constants.PREDICTOR_DBREADER)

    # The operators are executed outside any `assert` statement: asserts are
    # stripped under `python -O`, which previously meant the DB was never
    # opened. AssertionError is raised explicitly to keep the exception type
    # callers already observe.
    create_db = core.CreateOperator(
        'CreateDB', [], [db_reader],
        db=filename, db_type=db_type)
    if not workspace.RunOperatorOnce(create_db):
        raise AssertionError('Failed to create db {}'.format(filename))

    # predictor_constants.META_NET_DEF is always stored before the parameters
    load_meta_net_def = core.CreateOperator(
        'Load',
        [db_reader],
        [core.BlobReference(predictor_constants.META_NET_DEF)])
    if not workspace.RunOperatorOnce(load_meta_net_def):
        raise AssertionError('Failed to load {} from db {}'.format(
            predictor_constants.META_NET_DEF, filename))

    blob = workspace.FetchBlob(predictor_constants.META_NET_DEF)
    # The fetched blob may not already be bytes (e.g. a str); normalize it to
    # UTF-8 bytes before parsing.
    serialized = blob if isinstance(blob, bytes) else str(blob).encode('utf-8')
    meta_net_def = serde.deserialize_protobuf_struct(
        serialized, metanet_pb2.MetaNetDef)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option is not None:
        # Pin every operator of every stored net to device_option.
        for kv in meta_net_def.nets:
            for op in kv.value.op:
                op.device_option.CopyFrom(device_option)

    return meta_net_def