Caffe2 - Python API
A deep learning, cross-platform ML framework
compute_norm_for_blobs.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from caffe2.python import core, schema, muji
7 from caffe2.python.modeling.net_modifier import NetModifier
8 
9 
10 import numpy as np
11 
12 
    """
    This class modifies the net passed in by adding ops that compute norms
    for the given blobs.

    Args:
        blobs: list of blob names to compute norms for.
        logging_frequency: print each norm once every `logging_frequency`
            runs of the net; a value below 1 disables printing.
        p: order of the Lp norm. Currently only p=1 and p=2 are supported.
        compute_averaged_norm: if True, compute the averaged norm
            (averaged_norm = norm / size of the blob) instead of the norm.
    """
24 
25  def __init__(self, blobs, logging_frequency, p=2, compute_averaged_norm=False):
26  self._blobs = blobs
27  self._logging_frequency = logging_frequency
28  self._p = p
29  self._compute_averaged_norm = compute_averaged_norm
30  self._field_name_suffix = '_l{}_norm'.format(p)
31  if compute_averaged_norm:
32  self._field_name_suffix = '_averaged' + self._field_name_suffix
33 
34  def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
35  modify_output_record=False):
36 
37  p = self._p
38  compute_averaged_norm = self._compute_averaged_norm
39 
40  CPU = muji.OnCPU()
41  # if given, blob_to_device is a map from blob to device_option
42  blob_to_device = blob_to_device or {}
43  for blob_name in self._blobs:
44  blob = core.BlobReference(blob_name)
45  if not net.BlobIsDefined(blob):
46  raise Exception('blob {0} is not defined in net {1}'.format(
47  blob, net.Name()))
48  if blob in blob_to_device:
49  device = blob_to_device[blob]
50  else:
51  device = CPU
52 
53  with core.DeviceScope(device):
54  norm_name = net.NextScopedBlob(prefix=blob + self._field_name_suffix)
55  cast_blob = net.Cast(
56  blob,
57  net.NextScopedBlob(prefix=blob + '_float'),
58  to=core.DataType.FLOAT
59  )
60  norm = net.LpNorm(
61  cast_blob, norm_name, p=p, average=compute_averaged_norm
62  )
63 
64  if self._logging_frequency >= 1:
65  net.Print(norm, [], every_n=self._logging_frequency)
66 
67  if modify_output_record:
68  output_field_name = str(blob) + self._field_name_suffix
69  output_scalar = schema.Scalar((np.float, (1,)), norm)
70 
71  if net.output_record() is None:
72  net.set_output_record(
73  schema.Struct((output_field_name, output_scalar))
74  )
75  else:
76  net.AppendOutputRecordField(
77  output_field_name,
78  output_scalar)
79 
80  def field_name_suffix(self):
81  return self._field_name_suffix