1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
This class modifies the net passed in by adding ops to compute statistics
for certain blobs. For each blob in the list, its min, max, mean and standard
deviation will be computed.

blobs: list of blobs to compute statistics for
logging_frequency: frequency for printing statistics to logs
23 def __init__(self, blobs, logging_frequency):
28 def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
29 modify_output_record=
False):
31 for blob_name
in self.
_blobs:
33 if not net.BlobIsDefined(blob):
34 raise Exception(
'blob {0} is not defined in net {1}'.format(
37 cast_blob = net.Cast(blob, to=core.DataType.FLOAT)
39 stats = net.Summarize(cast_blob, stats_name, to_file=0)
42 if modify_output_record:
46 if net.output_record()
is None:
47 net.set_output_record(
51 net.AppendOutputRecordField(
55 def field_name_suffix(self):