1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
15 This class modifies the net passed in by adding ops to compute norms for 19 blobs: list of blobs to compute norm for 20 logging_frequency: frequency for printing norms to logs 21 p: type of norm. Currently it supports p=1 or p=2 22 compute_averaged_norm: norm or averaged_norm (averaged_norm = norm/size) 25 def __init__(self, blobs, logging_frequency, p=2, compute_averaged_norm=False):
31 if compute_averaged_norm:
34 def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
35 modify_output_record=
False):
42 blob_to_device = blob_to_device
or {}
43 for blob_name
in self.
_blobs:
45 if not net.BlobIsDefined(blob):
46 raise Exception(
'blob {0} is not defined in net {1}'.format(
48 if blob
in blob_to_device:
49 device = blob_to_device[blob]
53 with core.DeviceScope(device):
57 net.NextScopedBlob(prefix=blob +
'_float'),
58 to=core.DataType.FLOAT
61 cast_blob, norm_name, p=p, average=compute_averaged_norm
67 if modify_output_record:
71 if net.output_record()
is None:
72 net.set_output_record(
76 net.AppendOutputRecordField(
80 def field_name_suffix(self):