from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging
import six

# Imports assumed by the code below; only a few of the original import lines
# appear in this excerpt.
from caffe2.python import core, scope, workspace, helpers
from caffe2.python.modeling import parameter_info
from caffe2.python.modeling.parameter_sharing import (
    parameter_sharing_context,
)
from caffe2.python.optimizer_context import (
    OptimizerContext,
    DEFAULT_OPTIM,
)
from caffe2.python.regularizer_context import RegularizerContext

from future.utils import viewitems, viewkeys
from itertools import chain
_known_working_ops = [
    # (whitelist of operator names, elided in this excerpt)
]
class ModelHelper(object):
    """A helper model so we can manage models more easily. It contains net def
    and parameter storages. You can add an Operator yourself, e.g.

        model = model_helper.ModelHelper(name="train_net")
        # init your weight and bias as w and b
        w = model.param_init_net.XavierFill(...)
        b = model.param_init_net.ConstantFill(...)
        fc1 = model.FC([input, w, b], output, **kwargs)

    or you can use helper functions in the brew module without manually
    defining parameter initializations and operators.

        model = model_helper.ModelHelper(name="train_net")
        fc1 = brew.fc(model, input, output, dim_in, dim_out, **kwargs)
    """

    def __init__(self, name=None, init_params=True, allow_not_known_ops=True,
                 skip_sparse_optim=False, param_model=None, arg_scope=None):
        self.name = name or "model"
        # ...
        if param_model is not None:
            # ...
            self.params = param_model.params
        # ...
        self._arg_scope = {
            # ...
            'cudnn_exhaustive_search': False,
        }
        if arg_scope is not None:
            self._arg_scope.update(arg_scope)
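    # Example (hedged sketch): the brew-based flow from the class docstring run
    # end to end; blob names, shapes and values below are illustrative only.
    #
    #   import numpy as np
    #   from caffe2.python import brew, model_helper, workspace
    #
    #   model = model_helper.ModelHelper(name="train_net")
    #   fc1 = brew.fc(model, "data", "fc1", dim_in=784, dim_out=10)
    #   softmax = model.net.Softmax(fc1, "softmax")
    #
    #   workspace.FeedBlob("data", np.random.rand(16, 784).astype(np.float32))
    #   workspace.RunNetOnce(model.param_init_net)  # run initializers once
    #   workspace.CreateNet(model.net)
    #   workspace.RunNet(model.net)                 # one forward pass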
    def _infer_param_shape(self, param):
        for op in self.param_init_net.Proto().op:
            if str(param) in op.output:
                for arg in op.arg:
                    if arg.name == "shape":
                        return list(arg.ints)
        return None
    def _update_param_info_deprecated(self):
        # ... (validates each new param in self.params)
                raise ValueError(
                    "Param %s must be a BlobReference!" % str(param))
            self._param_info_deprecated.append(parameter_info.ParameterInfo(
                # ...
            ))
        for info in self._param_info_deprecated:
            info.grad = self.param_to_grad.get(info.name)
    def _normalize_tags(self, tags):
        tags = tags or []
        return set(tags) if isinstance(tags, list) else set([tags])
    def create_param(self, param_name, shape, initializer, tags=None):
        """
        Creates parameter with a given name and initializer.

        If param_name is an instance of BlobReference - then this blob will be
        used to store the parameter (no other logic will affect its location).

        If param_name is an instance of a string type, then the final blob will
        be created in the CurrentNameScope with respect to all parameter
        sharing logic, i.e. 'resolved_name_scope/param_name'.

        Parameter sharing logic is going to override CurrentNameScope according
        to the rules that are specified through ParameterSharing contexts,
        all ParameterSharing contexts are applied recursively until there are no
        extra overrides present, where on each step the best match will be
        applied.

        The following examples should clarify the way ParameterSharing logic
        works:

        As an example, if this function is called with parameter 'w':
        a. Call from some scope 'global_scope' with no ParameterSharing:
          'global_scope/w'
        b. Call from scope 'scope_b', with override {'scope_b': 'scope_a'}:
          'scope_a/w'
        c. Call from scope 'scope_a', with override {'scope_a': ''}:
          'w'
        d. Call from scope 'scope_b/shared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/w'
        e. Call from scope 'scope_b/unshared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/unshared/w'
        """
        if isinstance(param_name, core.BlobReference):
            param_name = str(param_name)
        elif isinstance(param_name, six.string_types):
            # Resolve the name with respect to parameter sharing of the scopes.
            param_name = parameter_sharing_context.get_parameter_name(
                param_name)
        else:
            raise TypeError("Unsupported type for param_name")
        # ...
        param_info = initializer.create_param(
            # ...
        )
        optim_context = OptimizerContext.current()
        for tag in self._normalize_tags(tags):
            if optim_context.has_optimizer(tag):
                param_info.optimizer = optim_context.get_optimizer(tag)
        if not param_info.optimizer and optim_context.has_optimizer(DEFAULT_OPTIM):
            param_info.optimizer = optim_context.get_optimizer(DEFAULT_OPTIM)

        reg_context = RegularizerContext.current()
        param_info.regularizer = reg_context
        # ...
        return param_info.blob
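    # Example (hedged sketch): creating a parameter directly with an explicit
    # initializer. The Initializer class is assumed to come from
    # caffe2.python.modeling.initializers; the resulting blob name follows the
    # scope resolution rules described in the docstring above.
    #
    #   from caffe2.python.modeling.initializers import Initializer
    #   w = model.create_param(
    #       param_name='w',
    #       shape=[10, 784],
    #       initializer=Initializer("XavierFill"),
    #       tags=parameter_info.ParameterTags.WEIGHT,
    #   )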
    def get_param_info(self, param):
        assert isinstance(param, core.BlobReference), \
            "Param {} is not a BlobReference".format(param)
        return self._parameters_info.get(param, None)
    def add_param_DEPRECATED(self, param, key=None, shape=None, length=None):
        logging.warning("add_param method is DEPRECATED")
        # ...
        if key is not None and self.net.input_record() is not None:
            idx = self.net.input_record().field_blobs().index(key)
            key = self.net.input_record().field_names()[idx]
        # ...
        if not isinstance(param, core.BlobReference):
            raise ValueError("Param %s must be a BlobReference!" % str(param))
        self._param_info_deprecated.append(parameter_info.ParameterInfo(
            # ...
        ))
    def AddParameter(self, param, tags=None):
        tags = self._normalize_tags(tags)
        if parameter_info.ParameterTags.COMPUTED_PARAM in tags:
            self._computed_params.append(param)
        else:
            self.params.append(param)

        if parameter_info.ParameterTags.WEIGHT in tags:
            self.weights.append(param)
        if parameter_info.ParameterTags.BIAS in tags:
            self.biases.append(param)
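    # Example (hedged sketch): how brew-style layers typically register their
    # blobs with AddParameter; the blob references w and b are illustrative.
    #
    #   model.AddParameter(w, tags=parameter_info.ParameterTags.WEIGHT)
    #   model.AddParameter(b, tags=parameter_info.ParameterTags.BIAS)
    #   # w and b now appear in model.params / model.weights / model.biases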
    @staticmethod
    def _NormalizeNamescope(namescope):
        if namescope is None:
            return scope.CurrentNameScope()
        elif namescope == '' or namescope.endswith(scope._NAMESCOPE_SEPARATOR):
            return namescope
        else:
            return namescope + scope._NAMESCOPE_SEPARATOR
    def GetParams(self, namescope=None, top_scope=False):
        '''
        Returns the params in current namescope
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self.params[:]
        else:
            return [p for p in self.params if
                    p.GetNameScope().startswith(namescope)]
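    # Example (hedged sketch): GetParams filters by name scope, so with
    # parameters created under 'fc1/' and 'fc2/' one can do:
    #
    #   with core.NameScope('fc1'):
    #       fc1_params = model.GetParams()      # only 'fc1/...' blobs
    #   all_params = model.GetParams(namescope='')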
    def Proto(self):
        return self.net.Proto()

    def InitProto(self):
        return self.param_init_net.Proto()
    def RunAllOnGPU(self, *args, **kwargs):
        self.param_init_net.RunAllOnGPU(*args, **kwargs)
        self.net.RunAllOnGPU(*args, **kwargs)
    def CreateDB(self, blob_out, db, db_type, **kwargs):
        dbreader = self.param_init_net.CreateDB(
            [], blob_out, db=db, db_type=db_type, **kwargs)
        return dbreader
    def AddGradientOperators(self, *args, **kwargs):
        if self.gradient_ops_added:
            raise RuntimeError("You cannot run AddGradientOperators twice.")
        # ...
        self.gradient_ops_added = True
        self.grad_map = self.net.AddGradientOperators(*args, **kwargs)
        self.param_to_grad = self.get_param_to_grad(self.params)

        # Record the gradient blob on each ParameterInfo so optimizers can use it.
        for param, grad in self.param_to_grad.items():
            param_info = self.get_param_info(param)
            if param_info:
                param_info.grad = grad
            # ...

        return self.grad_map
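    # Example (hedged sketch): after the forward net defines a scalar 'loss'
    # blob, gradients are added in one call; the blob name is illustrative.
    #
    #   grad_map = model.AddGradientOperators(["loss"])
    #   # grad_map / model.param_to_grad now map blob names to gradient blobs,
    #   # and each ParameterInfo carries its .grad for the optimizers.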
    def get_param_to_grad(self, params):
        '''
        Given a list of parameters returns a dict from a parameter
        to a corresponding gradient
        '''
        if not self.gradient_ops_added:
            raise RuntimeError("You need to run AddGradientOperators first.")
        param_to_grad = {}
        for p in params:
            if str(p) in self.grad_map:
                param_to_grad[p] = self.grad_map[str(p)]
        return param_to_grad
    def GetOptimizationParamInfo(self, params=None):
        '''
        Returns a map for param => grad.
        If params is not specified, all parameters will be considered.
        '''
        if not self.gradient_ops_added:
            raise RuntimeError("Need to call AddGradientOperators first")

        param_to_grad = self.param_to_grad
        if params:
            param_to_grad = self.get_param_to_grad(params)

        return [
            self.get_param_info(param) for param, grad in viewitems(param_to_grad)
            if (
                not self.skip_sparse_optim or
                not isinstance(grad, core.GradientSlice)
            )
        ]
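    # Example (hedged sketch): a hand-rolled SGD step over the optimizable
    # parameters, in the style of the classic caffe2 tutorials. ONE and LR are
    # illustrative constant blobs created on param_init_net.
    #
    #   ONE = model.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
    #   LR = model.param_init_net.ConstantFill([], "LR", shape=[1], value=-0.01)
    #   for param_info in model.GetOptimizationParamInfo():
    #       model.WeightedSum(
    #           [param_info.blob, ONE, param_info.grad, LR], param_info.blob)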
    def Validate(self):
        '''
        Check for duplicate params
        '''
        params_list = [str(p) for p in self.params]
        params_set = set(params_list)

        dupes = []
        if len(params_set) != len(params_list):
            params_list = sorted(params_list)
            for j, p in enumerate(params_list):
                if j > 0 and params_list[j - 1] == p:
                    dupes.append(p)

        assert dupes == [], "Duplicate params: {}".format(dupes)
    def GetComputedParams(self, namescope=None):
        '''
        Returns the computed params in current namescope. 'Computed params'
        are such parameters that are not optimized via gradient descent but are
        directly computed from data, such as the running mean and variance
        of Spatial Batch Normalization.
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self._computed_params[:]
        else:
            return [p for p in self._computed_params
                    if p.GetNameScope().startswith(namescope)]
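    # Example (hedged sketch): brew.spatial_bn registers its running mean and
    # variance as computed params, so they show up here rather than in
    # GetParams(); blob names and dimensions are illustrative.
    #
    #   bn = brew.spatial_bn(model, "conv1", "bn1", dim_in=32, is_test=False)
    #   # model.GetComputedParams() now includes the running mean/var blobs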
    def GetAllParams(self, namescope=None):
        return self.GetParams(namescope) + self.GetComputedParams(namescope)
    def TensorProtosDBInput(
        self, unused_blob_in, blob_out, batch_size, db, db_type, **kwargs
    ):
        """TensorProtosDBInput."""
        assert len(unused_blob_in) == 0, \
            """You cannot pass reader to model_helper.TensorProtosDBInput.
               Use model.net.TensorProtosDBInput instead to create the op."""

        return helpers.db_input.db_input(
            self, blob_out, batch_size, db, db_type, **kwargs)
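    # Example (hedged sketch): reading (data, label) TensorProtos pairs from a
    # minidb; the file name and blob names are illustrative.
    #
    #   data, label = model.TensorProtosDBInput(
    #       [], ["data", "label"], batch_size=64,
    #       db="train.minidb", db_type="minidb")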
    def GetDevices(self):
        assert len(self._devices) > 0, \
            "Use data_parallel_model to run model on multiple GPUs."
        return self._devices

    def __getattr__(self, op_type):
        """Catch-all for all other operators, mostly those without params."""
        if op_type.startswith('__'):
            raise AttributeError(op_type)

        if not core.IsOperator(op_type):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ','.join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        if op_type not in _known_working_ops:
            if not self.allow_not_known_ops:
                raise AttributeError(
                    "Operator {} is not known to be safe".format(op_type))

            logging.warning("You are creating an op that the ModelHelper "
                            "does not recognize: {}.".format(op_type))
        return self.net.__getattr__(op_type)
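    # Example (hedged sketch): because of this catch-all, parameter-free ops can
    # be added straight on the model and are forwarded to model.net:
    #
    #   relu1 = model.Relu("fc1", "relu1")    # same as model.net.Relu(...)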
    def __dir__(self):
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            _known_working_ops,
        )))
    def GetCompleteNet(self):
        r""" Return param_init_net + net Net.
        Returns:
          'core.Net' containing param_init_net and net
        """
        new_net = self.param_init_net.Clone(
            self.name + "_complete_net", keep_schema=True)
        # Tag the init ops so they can be split out again later.
        for op in new_net.Proto().op:
            op.debug_info = op.debug_info + "/param_init_net"
        new_net.AppendNet(self.net)
        # Keep the execution optimization type of the train net.
        if self.net.Proto().HasField("type"):
            new_net.Proto().type = self.net.Proto().type
        return new_net
    def ConstructInitTrainNetfromNet(self, net):
        r""" construct init net and train net from complete_net
        Args:
          net: 'core.Net' containing param_init_net and train net
        """
        param_op_mask = []
        train_op_mask = []
        for idx, op in enumerate(net.Proto().op):
            if op.debug_info.endswith("/param_init_net"):
                param_op_mask.append(idx)
            else:
                train_op_mask.append(idx)

        self.param_init_net = net.Clone(
            net.Name() + "/generated_param_init_net",
            keep_schema=True,
            op_id_mask=param_op_mask,
            update_external_list=True,
        )
        self.net = net.Clone(
            net.Name() + "/generated_net",
            keep_schema=True,
            op_id_mask=train_op_mask,
            update_external_list=True,
        )
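    # Example (hedged sketch): round-tripping a model through a single combined
    # net, e.g. to serialize it as one NetDef and split it again later.
    #
    #   complete = model.GetCompleteNet()
    #   restored = model_helper.ModelHelper(name="restored")
    #   restored.ConstructInitTrainNetfromNet(complete)
    #   # restored.param_init_net / restored.net are rebuilt from the op masks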
def ExtractPredictorNet(
    net_proto,
    input_blobs,
    output_blobs,
    device=None,
    renames=None,
    disabled_inputs=None,
):
    '''
    Takes a model net for training and returns a net which can be
    used for prediction. For example, all gradient operators and
    input operators are removed.
    @param net_proto protobuf of the net you want to process (net.Proto())
    @param input_blobs list/set of blob names that are the inputs of predictor
    @param output_blobs list/set of blob names that are outputs of predictor
    @param device optional device option that is assigned
    @param renames dictionary of blob name to a new name (optional)
    @param disabled_inputs optional set of blobs that are 'switched off'. This
                will cause branches with those blobs as inputs to be removed
    '''
    predict_net = core.Net(net_proto.name + "_predict")
    predict_proto = predict_net.Proto()

    orig_external_inputs = set(net_proto.external_input)
    orig_external_outputs = set(net_proto.external_output)
    input_blobs = {str(b) for b in input_blobs}
    known_blobs = set(orig_external_inputs).union(input_blobs)
    output_blobs = {str(b) for b in output_blobs}
    external_inputs = set(input_blobs)
    external_outputs = set(output_blobs)

    if renames is None:
        renames = {}

    if disabled_inputs is not None:
        known_blobs = known_blobs - set(disabled_inputs)

    ops = list(net_proto.op)

    # Find the first op that consumes one of the input blobs and the last op
    # that produces one of the output blobs; only that range is considered.
    try:
        first_op_with_input = min(
            [
                j for j in range(len(ops))
                if input_blobs.intersection(ops[j].input) and ops[j].type !=
                'StopGradient'
            ]
        )
    except ValueError:
        raise Exception("No ops with input={}".format(input_blobs))
    try:
        last_op_with_output = max(
            [
                j for j in range(len(ops))
                if output_blobs.intersection(ops[j].output)
            ]
        )
    except ValueError:
        raise Exception("No ops with output={}".format(output_blobs))

    def validate_op(op):
        # Catch ops that still have is_test=0 set (e.g. SpatialBN in train mode).
        for arg in op.arg:
            if arg.name == "is_test" and arg.i == 0:
                raise Exception(
                    "An operator had is_test=0, did you try to extract a " +
                    "predictor from a train model (instead of test model)?" +
                    " Op was: {}".format(str(op))
                )

    def rename_list(proto_list):
        # Proto lists don't support item assignment, so rebuild the list.
        new_list = proto_list[:]
        for j, b in enumerate(new_list):
            if b in renames:
                new_list[j] = renames[b]

        del proto_list[:]
        proto_list.extend(new_list)

    # Iterate through the ops and include only those whose inputs we can satisfy.
    for op in ops[first_op_with_input:(last_op_with_output + 1)]:
        if known_blobs.issuperset(op.input):

            # Special handling for recurrent nets.
            if op.type == 'RecurrentNetwork':
                for arg in op.arg:
                    if arg.name == 'backward_step_net':
                        arg.ClearField(str('n'))
                    elif arg.name == 'step_net':
                        for step_op in arg.n.op:
                            rename_list(step_op.input)
                            rename_list(step_op.output)
                            if device is not None:
                                step_op.device_option.device_type = device.device_type
                                step_op.device_option.device_id = device.device_id

                        rename_list(arg.n.external_input)
                        rename_list(arg.n.external_output)

                        # Add the step net's external inputs as well.
                        external_inputs.update(
                            set(arg.n.external_input).intersection(
                                orig_external_inputs
                            )
                        )

            if device is not None:
                op.device_option.device_type = device.device_type
                op.device_option.device_id = device.device_id
            validate_op(op)
            predict_proto.op.extend([op])
            known_blobs.update(op.output)
            external_inputs.update(
                set(op.input).intersection(orig_external_inputs)
            )
            external_outputs.update(
                set(op.output).intersection(orig_external_outputs)
            )
        else:
            logging.debug(
                "Op {} had unknown inputs: {}".format(
                    op.type, set(op.input).difference(known_blobs)
                )
            )

    # The predictor net's external inputs and outputs include only those that
    # are part of this net.
    predict_proto.external_input.extend(external_inputs)
    predict_proto.external_output.extend(external_outputs)

    rename_list(predict_proto.external_input)
    rename_list(predict_proto.external_output)

    renamed_input_blobs = []
    for b in input_blobs:
        if b in renames:
            renamed_input_blobs.append(renames[b])
        else:
            renamed_input_blobs.append(b)

    for op in predict_proto.op:
        rename_list(op.input)
        rename_list(op.output)

    return predict_net, list(
        set(predict_proto.external_input) - set(renamed_input_blobs)
    )
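# Example (hedged sketch, not part of the original module): a minimal use of
# ExtractPredictorNet. The model layout, blob names and the _demo_* function
# name are illustrative only.
def _demo_extract_predictor_net():
    from caffe2.python import brew

    model = ModelHelper(name="train_net")
    fc1 = brew.fc(model, "data", "fc1", dim_in=784, dim_out=10)
    softmax = model.net.Softmax(fc1, "softmax")
    model.net.AddExternalOutput(softmax)

    predict_net, export_blobs = ExtractPredictorNet(
        net_proto=model.net.Proto(),
        input_blobs=["data"],
        output_blobs=["softmax"],
    )
    # export_blobs lists the external inputs (typically the learned weights)
    # that must be shipped alongside predict_net.
    return predict_net, export_blobs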
 