16 from __future__
import absolute_import
17 from __future__
import division
18 from __future__
import print_function
19 from __future__
import unicode_literals
29 This class modifies the net passed in by adding ops to get a certain entry 33 blobs: list of blobs to get entry from 34 logging_frequency: frequency for printing entry values to logs 35 i1, i2: the first, second dimension of the blob. (currently, we assume 36 the blobs to be 2-dimensional blobs). When i2 = -1, print all entries 40 def __init__(self, blobs, logging_frequency, i1=0, i2=0):
46 else '_{0}_all'.format(i1)
48 def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
49 modify_output_record=
False):
51 i1, i2 = [self.
_i1, self.
_i2]
53 raise ValueError(
'index is out of range')
55 for blob_name
in self.
_blobs:
57 if not net.BlobIsDefined(blob):
58 raise Exception(
'blob {0} is not defined in net {1}'.format(
61 blob_i1 = net.Slice([blob], starts=[i1, 0], ends=[i1 + 1, -1])
63 blob_i1_i2 = net.Copy([blob_i1],
64 [net.NextScopedBlob(prefix=blob +
'_{0}_all'.format(i1))])
66 blob_i1_i2 = net.Slice([blob_i1],
67 net.NextScopedBlob(prefix=blob +
'_{0}_{1}'.format(i1, i2)),
68 starts=[0, i2], ends=[-1, i2 + 1])
73 if modify_output_record:
77 if net.output_record()
is None:
78 net.set_output_record(
82 net.AppendOutputRecordField(output_field_name, output_scalar)
84 def field_name_suffix(self):