1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
14 This class modifies the net passed in by adding ops to compute histogram for 18 blobs: list of blobs to compute histogram for 19 logging_frequency: frequency for printing 20 lower_bound: left boundary of histogram values 21 upper_bound: right boundary of histogram values 22 num_buckets: number of buckets to use in [lower_bound, upper_bound) 23 accumulate: boolean to output accumulate or per-batch histogram 26 def __init__(self, blobs, logging_frequency, num_buckets=30,
27 lower_bound=0.0, upper_bound=1.0, accumulate=
False):
38 "num_buckets need to be greater than 0, got {}".format(num_buckets))
42 def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
43 modify_output_record=
False):
44 for blob_name
in self.
_blobs:
46 if not net.BlobIsDefined(blob):
47 raise Exception(
'blob {0} is not defined in net {1}'.format(
50 blob_float = net.Cast(blob, net.NextScopedBlob(prefix=blob +
51 '_float'), to=core.DataType.FLOAT)
52 curr_hist, acc_hist = net.AccumulateHistogram(
54 [net.NextScopedBlob(prefix=blob +
'_curr_hist'),
55 net.NextScopedBlob(prefix=blob +
'_acc_hist')],
63 net.NextScopedBlob(prefix=blob +
'_cast_hist'),
64 to=core.DataType.FLOAT)
68 net.NextScopedBlob(prefix=blob +
'_cast_hist'),
69 to=core.DataType.FLOAT)
71 normalized_hist = net.NormalizeL1(
79 if modify_output_record:
84 if net.output_record()
is None:
85 net.set_output_record(
89 net.AppendOutputRecordField(
93 def field_name_suffix(self):