Caffe2 - Python API
A deep learning, cross platform ML framework
compute_histogram_for_blobs.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from caffe2.python import core, schema
7 from caffe2.python.modeling.net_modifier import NetModifier
8 
9 import numpy as np
10 
11 
    """
    This class modifies the net passed in by adding ops that compute a
    normalized histogram for each of the given blobs.

    Args:
        blobs: list of blob names to compute histograms for
        logging_frequency: if >= 1, print each normalized histogram every
            `logging_frequency` runs of the net; values below 1 disable printing
        lower_bound: left (inclusive) boundary of the histogram range
        upper_bound: right (exclusive) boundary of the histogram range; should
            be greater than lower_bound
        num_buckets: number of buckets to use in [lower_bound, upper_bound)
        accumulate: if True, output the histogram accumulated across runs;
            otherwise output the per-batch histogram
    """
25 
26  def __init__(self, blobs, logging_frequency, num_buckets=30,
27  lower_bound=0.0, upper_bound=1.0, accumulate=False):
28  self._blobs = blobs
29  self._logging_frequency = logging_frequency
30  self._accumulate = accumulate
31  if self._accumulate:
32  self._field_name_suffix = '_acc_normalized_hist'
33  else:
34  self._field_name_suffix = '_curr_normalized_hist'
35 
36  self._num_buckets = int(num_buckets)
37  assert self._num_buckets > 0, (
38  "num_buckets need to be greater than 0, got {}".format(num_buckets))
39  self._lower_bound = float(lower_bound)
40  self._upper_bound = float(upper_bound)
41 
42  def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
43  modify_output_record=False):
44  for blob_name in self._blobs:
45  blob = core.BlobReference(blob_name)
46  if not net.BlobIsDefined(blob):
47  raise Exception('blob {0} is not defined in net {1}'.format(
48  blob, net.Name()))
49 
50  blob_float = net.Cast(blob, net.NextScopedBlob(prefix=blob +
51  '_float'), to=core.DataType.FLOAT)
52  curr_hist, acc_hist = net.AccumulateHistogram(
53  [blob_float],
54  [net.NextScopedBlob(prefix=blob + '_curr_hist'),
55  net.NextScopedBlob(prefix=blob + '_acc_hist')],
56  num_buckets=self._num_buckets,
57  lower_bound=self._lower_bound,
58  upper_bound=self._upper_bound)
59 
60  if self._accumulate:
61  hist = net.Cast(
62  acc_hist,
63  net.NextScopedBlob(prefix=blob + '_cast_hist'),
64  to=core.DataType.FLOAT)
65  else:
66  hist = net.Cast(
67  curr_hist,
68  net.NextScopedBlob(prefix=blob + '_cast_hist'),
69  to=core.DataType.FLOAT)
70 
71  normalized_hist = net.NormalizeL1(
72  hist,
73  net.NextScopedBlob(prefix=blob + self._field_name_suffix)
74  )
75 
76  if self._logging_frequency >= 1:
77  net.Print(normalized_hist, [], every_n=self._logging_frequency)
78 
79  if modify_output_record:
80  output_field_name = str(blob) + self._field_name_suffix
81  output_scalar = schema.Scalar((np.float32, (self._num_buckets + 2,)),
82  normalized_hist)
83 
84  if net.output_record() is None:
85  net.set_output_record(
86  schema.Struct((output_field_name, output_scalar))
87  )
88  else:
89  net.AppendOutputRecordField(
90  output_field_name,
91  output_scalar)
92 
93  def field_name_suffix(self):
94  return self._field_name_suffix