from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, scope, workspace, helpers
from caffe2.python.modeling import parameter_info
from caffe2.python.modeling.parameter_sharing import (
    parameter_sharing_context,
)
from caffe2.python.optimizer_context import (
    OptimizerContext,
    DEFAULT_OPTIM,
)
from caffe2.python.regularizer_context import RegularizerContext

from future.utils import viewitems, viewkeys
from itertools import chain

import logging
import six

# Operators that ModelHelper knows how to handle without special care.
_known_working_ops = [
    # (operator name list elided)
]


class ModelHelper(object):
    """A helper model so we can manage models more easily. It contains net def
    and parameter storages. You can add an Operator yourself, e.g.

        model = model_helper.ModelHelper(name="train_net")
        # init your weight and bias as w and b
        w = model.param_init_net.XavierFill(...)
        b = model.param_init_net.ConstantFill(...)
        fc1 = model.FC([input, w, b], output, **kwargs)

    or you can use helper functions in the brew module without manually
    defining parameter initializations and operators.

        model = model_helper.ModelHelper(name="train_net")
        fc1 = brew.fc(model, input, output, dim_in, dim_out, **kwargs)
    """

    def __init__(self, name=None, init_params=True, allow_not_known_ops=True,
                 skip_sparse_optim=False, param_model=None, arg_scope=None):
        self.name = name or "model"
        self.net = core.Net(self.name)

        if param_model is not None:
            self.param_init_net = param_model.param_init_net
            self.param_to_grad = param_model.param_to_grad
            self.params = param_model.params
            self._parameters_info = param_model._parameters_info
            self._computed_params = param_model._computed_params
        else:
            self.param_init_net = core.Net(self.name + '_init')
            self.param_to_grad = {}
            self.params = []
            self._parameters_info = {}
            self._computed_params = []

        self._param_info_deprecated = []
        self._devices = []
        self.gradient_ops_added = False
        self.init_params = init_params
        self.allow_not_known_ops = allow_not_known_ops
        self.skip_sparse_optim = skip_sparse_optim
        self.weights = []
        self.biases = []

        self._arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': False,
        }
        if arg_scope is not None:
            self._arg_scope.update(arg_scope)
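
    # Example (sketch, not from the original source): arg_scope entries simply
    # update the defaults above at construction time; the values shown here
    # are illustrative.
    #
    #     model = ModelHelper(name="train_net", arg_scope={"order": "NHWC"})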

    def _infer_param_shape(self, param):
        for op in self.param_init_net.Proto().op:
            if str(param) in op.output:
                for arg in op.arg:
                    if arg.name == "shape":
                        return list(arg.ints)
        return None

    def _update_param_info_deprecated(self):
        assert len(self._param_info_deprecated) <= len(self.params)
        for param in self.params[len(self._param_info_deprecated):]:
            if not isinstance(param, core.BlobReference):
                raise ValueError(
                    "Param %s must be a BlobReference!" % str(param))
            self._param_info_deprecated.append(parameter_info.ParameterInfo(
                param_id=len(self._param_info_deprecated),
                param=param,
                shape=self._infer_param_shape(param)))
        for info in self._param_info_deprecated:
            info.grad = self.param_to_grad.get(info.name)

    def _normalize_tags(self, tags):
        tags = tags or []
        return set(tags) if isinstance(tags, list) else set([tags])

    def create_param(self, param_name, shape, initializer, tags=None):
        """
        Creates a parameter with a given name and initializer.

        If param_name is an instance of BlobReference, then this blob will be
        used to store the parameter (no other logic will affect its location).

        If param_name is an instance of a string type, then the final blob will
        be created in the CurrentNameScope with respect to all parameter
        sharing logic, i.e. 'resolved_name_scope/param_name'.

        Parameter sharing logic is going to override CurrentNameScope according
        to the rules that are specified through ParameterSharing contexts;
        all ParameterSharing contexts are applied recursively until there are no
        extra overrides present, where on each step the best match will be
        applied first.

        The following examples should clarify the way ParameterSharing logic
        works. As an example, if this function is called with parameter 'w':
        a. Call from some scope 'global_scope' with no parameter sharing:
        b. Call from scope 'scope_b', with override {'scope_b': 'scope_a'}:
        c. Call from scope 'scope_a', with override {'scope_a': ''}:
        d. Call from scope 'scope_b/shared', with overrides
           {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
        e. Call from scope 'scope_b/unshared', with overrides
           {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
        """
        if isinstance(param_name, core.BlobReference):
            param_name = str(param_name)
        elif isinstance(param_name, six.string_types):
            # Resolve the name against the current ParameterSharing context.
            param_name = parameter_sharing_context.get_parameter_name(
                param_name)
        else:
            raise TypeError("Unsupported type for param_name")

        param_info = initializer.create_param(
            param_name=core.BlobReference(param_name),
            init_net=self.param_init_net,
            shape=shape,
        )

        optim_context = OptimizerContext.current()
        for tag in self._normalize_tags(tags):
            if optim_context.has_optimizer(tag):
                param_info.optimizer = optim_context.get_optimizer(tag)
        if not param_info.optimizer and optim_context.has_optimizer(DEFAULT_OPTIM):
            param_info.optimizer = optim_context.get_optimizer(DEFAULT_OPTIM)

        reg_context = RegularizerContext.current()
        param_info.regularizer = reg_context

        self._parameters_info[param_name] = param_info
        # Add param to legacy structs as well, so all other functions for
        # parameters still work.
        self.AddParameter(param_info.blob, tags)
        return param_info.blob
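
    # Example (sketch): creating a weight blob that obeys any active
    # ParameterSharing context. The Initializer import and the blob/dim names
    # below are illustrative assumptions, not part of this file.
    #
    #     from caffe2.python.modeling.initializers import Initializer
    #     w = model.create_param(
    #         param_name="fc_w",
    #         shape=[out_dim, in_dim],
    #         initializer=Initializer("XavierFill"),
    #         tags=parameter_info.ParameterTags.WEIGHT,
    #     )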

    def get_param_info(self, param):
        assert isinstance(param, core.BlobReference), \
            "Param {} is not a BlobReference".format(param)
        return self._parameters_info.get(param, None)

    # Deprecated: use create_param instead, which also performs parameter
    # initialization when needed.
    def add_param_DEPRECATED(self, param, key=None, shape=None, length=None):
        logging.warning("add_param method is DEPRECATED")
        self._update_param_info_deprecated()
        self.AddParameter(param)
        if key is not None and self.net.input_record() is not None:
            idx = self.net.input_record().field_blobs().index(key)
            key = self.net.input_record().field_names()[idx]
        shape = shape if shape is not None else self._infer_param_shape(param)
        if not isinstance(param, core.BlobReference):
            raise ValueError("Param %s must be a BlobReference!" % str(param))
        self._param_info_deprecated.append(parameter_info.ParameterInfo(
            param_id=len(self._param_info_deprecated),
            param=param,
            shape=shape,
            key=key,
            length=length,
        ))
        return self._param_info_deprecated[-1]

    def AddParameter(self, param, tags=None):
        assert isinstance(param, core.BlobReference)
        tags = self._normalize_tags(tags)
        if parameter_info.ParameterTags.COMPUTED_PARAM in tags:
            self._computed_params.append(param)
        else:
            self.params.append(param)

        if parameter_info.ParameterTags.WEIGHT in tags:
            self.weights.append(param)
        if parameter_info.ParameterTags.BIAS in tags:
            self.biases.append(param)

    @staticmethod
    def _NormalizeNamescope(namescope):
        if namescope is None:
            return scope.CurrentNameScope()
        elif namescope == '' or namescope.endswith(scope._NAMESCOPE_SEPARATOR):
            return namescope
        else:
            return namescope + scope._NAMESCOPE_SEPARATOR

    def GetParams(self, namescope=None, top_scope=False):
        '''
        Returns the params in current namescope
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self.params[:]
        else:
            return [p for p in self.params if
                    p.GetNameScope().startswith(namescope)]

    def Proto(self):
        return self.net.Proto()

    def InitProto(self):
        return self.param_init_net.Proto()

    def RunAllOnGPU(self, *args, **kwargs):
        self.param_init_net.RunAllOnGPU(*args, **kwargs)
        self.net.RunAllOnGPU(*args, **kwargs)

    def CreateDB(self, blob_out, db, db_type, **kwargs):
        dbreader = self.param_init_net.CreateDB(
            [], blob_out, db=db, db_type=db_type, **kwargs)
        return dbreader

    def AddGradientOperators(self, *args, **kwargs):
        if self.gradient_ops_added:
            raise RuntimeError("You cannot run AddGradientOperators twice.")
        self.Validate()

        self.gradient_ops_added = True
        self.grad_map = self.net.AddGradientOperators(*args, **kwargs)
        self.param_to_grad = self.get_param_to_grad(self.params)

        # Attach gradient blob information to the existing ParameterInfo
        # entries so optimizers can use it.
        for param, grad in self.param_to_grad.items():
            param_info = self.get_param_info(param)
            if param_info:
                param_info.grad = grad

        return self.grad_map
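
    # Example (sketch): typical training-time flow once the forward net has
    # been built; the `loss` blob name is illustrative.
    #
    #     grad_map = model.AddGradientOperators([loss])
    #     for info in model.GetOptimizationParamInfo():
    #         ...  # hand each param/grad pair to an optimizer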

    def get_param_to_grad(self, params):
        '''
        Given a list of parameters returns a dict from a parameter
        to a corresponding gradient
        '''
        if not self.gradient_ops_added:
            raise RuntimeError("You need to run AddGradientOperators first.")

        param_to_grad = {}
        for p in params:
            if str(p) in self.grad_map:
                param_to_grad[p] = self.grad_map[str(p)]
        return param_to_grad

    def GetOptimizationParamInfo(self, params=None):
        '''
        Returns a map for param => grad.
        If params is not specified, all parameters will be considered.
        '''
        if not self.gradient_ops_added:
            raise RuntimeError("Need to call AddGradientOperators first")

        param_to_grad = self.param_to_grad
        if params:
            param_to_grad = self.get_param_to_grad(params)

        return [
            self.get_param_info(param)
            for param, grad in viewitems(param_to_grad)
            if (
                not self.skip_sparse_optim or
                not isinstance(grad, core.GradientSlice)
            )
        ]

    def _Validate(self):
        '''
        Check for duplicate params
        '''
        params_list = [str(p) for p in self.params]
        params_set = set(params_list)

        dupes = []
        if len(params_set) != len(params_list):
            params_list = sorted(params_list)
            for j, p in enumerate(params_list):
                if j > 0 and params_list[j - 1] == p:
                    if p not in dupes:
                        dupes.append(p)

        return dupes

    def Validate(self):
        dupes = self._Validate()
        assert dupes == [], "Duplicate params: {}".format(dupes)

    def GetComputedParams(self, namescope=None):
        '''
        Returns the computed params in current namescope. 'Computed params'
        are such parameters that are not optimized via gradient descent but are
        directly computed from data, such as the running mean and variance
        of Spatial Batch Normalization.
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self._computed_params[:]
        else:
            return [p for p in self._computed_params
                    if p.GetNameScope().startswith(namescope)]

    def GetAllParams(self, namescope=None):
        return self.GetParams(namescope) + self.GetComputedParams(namescope)

    def TensorProtosDBInput(
        self, unused_blob_in, blob_out, batch_size, db, db_type, **kwargs
    ):
        """TensorProtosDBInput."""
        assert len(unused_blob_in) == 0, \
            """You cannot pass reader to model_helper.TensorProtosDBInput.
               Use model.net.TensorProtosDBInput instead to create the op."""

        return helpers.db_input.db_input(
            self, blob_out, batch_size, db, db_type, **kwargs)
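
    # Example (sketch): reading batches from a TensorProtos DB; the path and
    # blob names are illustrative.
    #
    #     data, label = model.TensorProtosDBInput(
    #         [], ["data", "label"], batch_size=64,
    #         db="/path/to/train.minidb", db_type="minidb")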

    def GetDevices(self):
        assert len(self._devices) > 0, \
            "Use data_parallel_model to run model on multiple GPUs."
        return self._devices

    def __getattr__(self, op_type):
        """Catch-all for all other operators, mostly those without params."""
        if op_type.startswith('__'):
            raise AttributeError(op_type)

        if not core.IsOperator(op_type):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ','.join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        if op_type not in _known_working_ops:
            if not self.allow_not_known_ops:
                raise AttributeError(
                    "Operator {} is not known to be safe".format(op_type))

            logging.warning("You are creating an op that the ModelHelper "
                            "does not recognize: {}.".format(op_type))
        return self.net.__getattr__(op_type)

    def __dir__(self):
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            _known_working_ops
        )))

    def GetCompleteNet(self):
        r""" Return param_init_net + net Net.
        Returns:
          'core.Net' containing param_init_net and net
        """
        new_net = self.param_init_net.Clone(
            self.name + "_complete_net", keep_schema=True)
        # Tag init ops in debug_info so they can be separated out later.
        for op in new_net.Proto().op:
            op.debug_info = op.debug_info + "/param_init_net"
        new_net.AppendNet(self.net)
        # Keep the execution optimization (net type) of the train net.
        if self.net.Proto().HasField("type"):
            new_net.Proto().type = self.net.Proto().type
        return new_net
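
    # Example (sketch): serialize a single combined net and later split it
    # back into init and train nets; variable names are illustrative.
    #
    #     complete = model.GetCompleteNet()
    #     ...  # e.g. persist complete.Proto() somewhere
    #     new_model = ModelHelper(name="restored")
    #     new_model.ConstructInitTrainNetfromNet(complete)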

    def ConstructInitTrainNetfromNet(self, net):
        r""" construct init net and train net from complete_net
        Args:
          net: 'core.Net' containing param_init_net and train net
        """
        param_op_mask = []
        train_op_mask = []
        for idx, op in enumerate(net.Proto().op):
            if op.debug_info.endswith("/param_init_net"):
                param_op_mask.append(idx)
            else:
                train_op_mask.append(idx)

        self.param_init_net = net.Clone(
            net.Name() + "/generated_param_init_net",
            keep_schema=True,
            op_id_mask=param_op_mask,
            update_external_list=True,
        )
        self.net = net.Clone(
            net.Name() + "/generated_net",
            keep_schema=True,
            op_id_mask=train_op_mask,
            update_external_list=True,
        )


def ExtractPredictorNet(
    net_proto,
    input_blobs,
    output_blobs,
    device=None,
    renames=None,
    disabled_inputs=None,
):
    '''
    Takes a model net for training and returns a net which can be
    used for prediction. For example, all gradient operators and
    input operators are removed.
    @param net_proto protobuf of the net you want to process (net.Proto())
    @param input_blobs list/set of blob names that are the inputs of predictor
    @param output_blobs list/set of blob names that are outputs of predictor
    @param device optional device option that is assigned
    @param renames dictionary of blob name to a new name (optional)
    @param disabled_inputs optional set of blobs that are 'switched off'. This
                will cause branches with those blobs as inputs to be removed
    '''
    predict_net = core.Net(net_proto.name + "_predict")
    predict_proto = predict_net.Proto()

    orig_external_inputs = set(net_proto.external_input)
    orig_external_outputs = set(net_proto.external_output)
    input_blobs = {str(b) for b in input_blobs}
    known_blobs = set(orig_external_inputs).union(input_blobs)
    output_blobs = {str(b) for b in output_blobs}
    external_inputs = set(input_blobs)
    external_outputs = set(output_blobs)

    if renames is None:
        renames = {}

    if disabled_inputs is not None:
        known_blobs = known_blobs - set(disabled_inputs)

    ops = list(net_proto.op)

    # Find the range of ops to include: from the first op that consumes one
    # of the input blobs to the last op that produces one of the output blobs.
    try:
        first_op_with_input = min(
            [
                j for j in range(len(ops))
                if input_blobs.intersection(ops[j].input) and ops[j].type !=
                'StopGradient'
            ]
        )
    except ValueError:
        raise Exception("No ops with input={}".format(input_blobs))
    try:
        last_op_with_output = max(
            [
                j for j in range(len(ops))
                if output_blobs.intersection(ops[j].output)
            ]
        )
    except ValueError:
        raise Exception("No ops with output={}".format(output_blobs))

    def validate_op(op):
        # Check that the op does not have is_test=0 set. This is a common
        # pitfall with the SpatialBN op, at least.
        for arg in op.arg:
            if arg.name == "is_test" and arg.i == 0:
                raise Exception(
                    "An operator had is_test=0, did you try to extract a " +
                    "predictor from a train model (instead of test model)?" +
                    " Op was: {}".format(str(op))
                )

    def rename_list(proto_list):
        # proto lists don't support item assignment, so rebuild the list
        new_list = proto_list[:]
        for j, b in enumerate(new_list):
            if b in renames:
                new_list[j] = renames[b]

        del proto_list[:]
        proto_list.extend(new_list)

    # Iterate through the ops and only include those whose inputs
    # we can satisfy.
    for op in ops[first_op_with_input:(last_op_with_output + 1)]:
        if known_blobs.issuperset(op.input):

            # Special handling for recurrent nets
            if op.type == 'RecurrentNetwork':
                for arg in op.arg:
                    if arg.name == 'backward_step_net':
                        arg.ClearField(str('n'))
                    elif arg.name == 'step_net':
                        for step_op in arg.n.op:
                            rename_list(step_op.input)
                            rename_list(step_op.output)
                            if device is not None:
                                step_op.device_option.device_type = device.device_type
                                step_op.device_option.device_id = device.device_id

                        rename_list(arg.n.external_input)
                        rename_list(arg.n.external_output)

                        # Add additional external inputs
                        external_inputs.update(
                            set(arg.n.external_input).intersection(
                                orig_external_inputs
                            )
                        )

            if device is not None:
                op.device_option.device_type = device.device_type
                op.device_option.device_id = device.device_id
            validate_op(op)
            predict_proto.op.extend([op])
            known_blobs.update(op.output)
            external_inputs.update(
                set(op.input).intersection(orig_external_inputs)
            )
            external_outputs.update(
                set(op.output).intersection(orig_external_outputs)
            )

        else:
            logging.debug(
                "Op {} had unknown inputs: {}".format(
                    op.type, set(op.input).difference(known_blobs)
                )
            )

    # The predictor net's external inputs and outputs include only those
    # that are part of this net.
    predict_proto.external_input.extend(external_inputs)
    predict_proto.external_output.extend(external_outputs)

    rename_list(predict_proto.external_input)
    rename_list(predict_proto.external_output)

    renamed_input_blobs = []
    for b in input_blobs:
        if b in renames:
            renamed_input_blobs.append(renames[b])
        else:
            renamed_input_blobs.append(b)

    for op in predict_proto.op:
        rename_list(op.input)
        rename_list(op.output)

    return predict_net, list(
        set(predict_proto.external_input) - set(renamed_input_blobs)
    )
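

# Example (sketch): extracting a deploy net from a trained test-mode model;
# the blob names and renames mapping below are illustrative.
#
#     predict_net, export_blobs = ExtractPredictorNet(
#         net_proto=test_model.net.Proto(),
#         input_blobs=["data"],
#         output_blobs=["softmax"],
#         renames={"data": "input"},
#     )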