Caffe2 - Python API
A deep learning, cross-platform ML framework
visualize.py
1 ## @package visualize
2 # Module caffe2.python.visualize
3 """Functions that could be used to visualize Tensors.
4 
5 This is adapted from the old-time iceberk package that Yangqing wrote... Oh, golden
6 memories. Before decaf and caffe. Why iceberk? Because I was at Berkeley,
7 bears are vegetarian, and iceberg lettuce has layers of leaves.
8 
9 (This joke is so lame.)
10 """
11 
12 import numpy as np
13 from matplotlib import cm, pyplot
14 
15 
def ChannelFirst(arr):
    """Convert a HWC array to CHW.

    Only the last three axes are rearranged, so leading batch axes are kept
    in place (e.g. NHWC becomes NCHW).
    """
    last = arr.ndim - 1
    # Two adjacent swaps bubble the trailing (channel) axis two slots left.
    channel_moved_once = arr.swapaxes(last, last - 1)
    return channel_moved_once.swapaxes(last - 1, last - 2)
20 
21 
def ChannelLast(arr):
    """Convert a CHW array to HWC.

    Only the last three axes are rearranged, so leading batch axes are kept
    in place (e.g. NCHW becomes NHWC).
    """
    channel_axis = arr.ndim - 3
    # Two adjacent swaps bubble the channel axis two slots to the right.
    channel_moved_once = arr.swapaxes(channel_axis, channel_axis + 1)
    return channel_moved_once.swapaxes(channel_axis + 1, channel_axis + 2)
26 
27 
class PatchVisualizer(object):
    """PatchVisualizer visualizes patches.

    Patches are drawn with ``pyplot.imshow``. When several patches are shown
    at once they are tiled into a single image, separated by ``gap`` pixels
    filled with a background value.
    """

    def __init__(self, gap=1):
        # Width, in pixels, of the background strip between tiled patches.
        self.gap = gap

    def ShowSingle(self, patch, cmap=None):
        """Visualizes one single patch.

        The input patch could be a vector (in which case we try to infer the
        shape of the patch), a 2-D matrix, or a 3-D matrix whose 3rd dimension
        has 3 channels.

        Args:
            patch: numpy array holding the patch.
            cmap: matplotlib colormap; defaults to gray for 2-D patches.

        Returns:
            The (possibly reshaped) patch that was drawn.

        Raises:
            ValueError: if the patch shape cannot be interpreted.
        """
        if len(patch.shape) == 1:
            patch = patch.reshape(self.get_patch_shape(patch))
        elif len(patch.shape) > 2 and patch.shape[2] != 3:
            raise ValueError("The input patch shape isn't correct.")
        # Single-channel patches default to a gray colormap.
        if len(patch.shape) == 2 and cmap is None:
            cmap = cm.gray
        pyplot.imshow(patch, cmap=cmap)
        return patch

    def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean):
        """Visualize multiple patches.

        In the passed in patches matrix, each row is a patch, in the shape of
        either n*n, n*n*1 or n*n*3, either in a flattened format (so patches
        would be a 2-D array), or a multi-dimensional tensor. We will try our
        best to figure out automatically the patch size.

        Args:
            patches: array whose first axis indexes the patches.
            ncols: number of tile columns; defaults to
                ceil(sqrt(num_patches)).
            cmap: matplotlib colormap; defaults to gray for single-channel
                patches.
            bg_func: callable applied to all patches to compute the value
                that fills the gaps between tiles.

        Returns:
            The tiled image that was drawn.

        Raises:
            ValueError: if there are no patches, or the patch shape cannot be
                interpreted.
        """
        num_patches = patches.shape[0]
        if num_patches == 0:
            # Without this guard the grid computation divides by zero.
            raise ValueError("There are no patches to visualize.")
        if ncols is None:
            ncols = int(np.ceil(np.sqrt(num_patches)))
        nrows = int(np.ceil(num_patches / float(ncols)))
        if len(patches.shape) == 2:
            # Flattened patches: infer the per-patch shape from the first one.
            patches = patches.reshape(
                (num_patches, ) + self.get_patch_shape(patches[0])
            )
        # Each tile takes the patch size plus one gap; the trailing gap after
        # the last row/column is dropped.
        patch_size_expand = np.array(patches.shape[1:3]) + self.gap
        image_size = patch_size_expand * np.array([nrows, ncols]) - self.gap
        if len(patches.shape) == 4:
            if patches.shape[3] == 1:
                # gray patches stored with a singleton channel axis
                patches = patches.reshape(patches.shape[:-1])
                image_shape = tuple(image_size)
                if cmap is None:
                    cmap = cm.gray
            elif patches.shape[3] == 3:
                # color patches
                image_shape = tuple(image_size) + (3, )
            else:
                raise ValueError("The input patch shape isn't expected.")
        else:
            image_shape = tuple(image_size)
            if cmap is None:
                cmap = cm.gray
        image = np.ones(image_shape) * bg_func(patches)
        for pid in range(num_patches):
            # Tiles are laid out row-major: left to right, then top to bottom.
            row = pid // ncols * patch_size_expand[0]
            col = pid % ncols * patch_size_expand[1]
            image[row:row + patches.shape[1], col:col + patches.shape[2]] = \
                patches[pid]
        pyplot.imshow(image, cmap=cmap, interpolation='nearest')
        pyplot.axis('off')
        return image

    def ShowImages(self, patches, *args, **kwargs):
        """Similar to ShowMultiple, but always normalize the values between 0
        and 1 for better visualization of image-type data.

        Extra positional and keyword arguments are forwarded to ShowMultiple.
        The input array is not modified.
        """
        normalized = self._normalize_to_unit_range(patches)
        return self.ShowMultiple(normalized, *args, **kwargs)

    @staticmethod
    def _normalize_to_unit_range(patches):
        """Linearly rescale ``patches`` to (almost) [0, 1] as a new float array."""
        # Converting to float first keeps integer inputs (e.g. uint8 images)
        # from failing an in-place integer division.
        shifted = np.asarray(patches, dtype=np.float64)
        shifted = shifted - np.min(shifted)
        # eps avoids division by zero when all values are equal.
        return shifted / (np.max(shifted) + np.finfo(np.float64).eps)

    def ShowChannels(self, patch, cmap=None, bg_func=np.mean):
        """ This function shows the channels of a patch.

        The incoming patch should have shape [rows, cols, num_channels] (HWC),
        and each channel will be visualized as a separate gray patch.

        Raises:
            ValueError: if the patch is not 3-D.
        """
        if len(patch.shape) != 3:
            raise ValueError("The input patch shape isn't correct.")
        # HWC -> CHW so that each channel becomes one patch for ShowMultiple.
        patch_reordered = np.transpose(patch, (2, 0, 1))
        return self.ShowMultiple(patch_reordered, cmap=cmap, bg_func=bg_func)

    def get_patch_shape(self, patch):
        """Gets the shape of a single patch.

        Basically it tries to interpret the patch as a square, and also checks
        whether it is in color (3 channels).

        Returns:
            ``(edge, edge)`` for a gray square patch or ``(edge, edge, 3)`` for
            a color one. Edge lengths are Python ints, so the result can be
            passed directly to ``reshape``.

        Raises:
            ValueError: if the size is neither a perfect square nor three
                times a perfect square.
        """
        edge_len = self._exact_isqrt(patch.size)
        if edge_len is not None:
            return (edge_len, edge_len)
        # Not a perfect square: try interpreting it as a 3-channel color patch.
        if patch.size % 3 == 0:
            edge_len = self._exact_isqrt(patch.size // 3)
            if edge_len is not None:
                return (edge_len, edge_len, 3)
        raise ValueError("I can't figure out the patch shape.")

    @staticmethod
    def _exact_isqrt(value):
        """Return int sqrt(value) if ``value`` is a perfect square, else None."""
        root = int(round(np.sqrt(value)))
        # Integer re-check avoids relying on exact float equality of sqrt.
        return root if root * root == value else None
131 
132 
## Shared PatchVisualizer behind the NHWC / NCHW utility functions.
#
# Those utility functions don't return anything, so you won't see annoying
# printouts of the visualized images in an interactive session. If you want
# to keep the images (for example to save them), explicitly instantiate a
# PatchVisualizer and call its methods, which return the drawn arrays.
_default_visualizer = PatchVisualizer(gap=1)
140 
141 
class NHWC(object):
    """Shortcuts to the default visualizer for channel-last (HWC / NHWC) data.

    Every method forwards its arguments unchanged to the matching
    PatchVisualizer method and discards the result, so each returns None.
    """

    @staticmethod
    def ShowSingle(*args, **kwargs):
        """Forward to PatchVisualizer.ShowSingle."""
        show = _default_visualizer.ShowSingle
        show(*args, **kwargs)

    @staticmethod
    def ShowMultiple(*args, **kwargs):
        """Forward to PatchVisualizer.ShowMultiple."""
        show = _default_visualizer.ShowMultiple
        show(*args, **kwargs)

    @staticmethod
    def ShowImages(*args, **kwargs):
        """Forward to PatchVisualizer.ShowImages."""
        show = _default_visualizer.ShowImages
        show(*args, **kwargs)

    @staticmethod
    def ShowChannels(*args, **kwargs):
        """Forward to PatchVisualizer.ShowChannels."""
        show = _default_visualizer.ShowChannels
        show(*args, **kwargs)
158 
159 
class NCHW(object):
    """Shortcuts to the default visualizer for channel-first (CHW / NCHW) data.

    Each method converts its first argument to channel-last layout with
    ChannelLast, then forwards it together with any remaining arguments to
    the matching PatchVisualizer method. Results are discarded, so each
    returns None.
    """

    @staticmethod
    def ShowSingle(patch, *args, **kwargs):
        """Show one channel-first patch."""
        show = _default_visualizer.ShowSingle
        show(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowMultiple(patch, *args, **kwargs):
        """Tile a channel-first batch of patches."""
        show = _default_visualizer.ShowMultiple
        show(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowImages(patch, *args, **kwargs):
        """Tile a channel-first batch after normalizing it to [0, 1]."""
        show = _default_visualizer.ShowImages
        show(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowChannels(patch, *args, **kwargs):
        """Show every channel of a channel-first patch as a gray tile."""
        show = _default_visualizer.ShowChannels
        show(ChannelLast(patch), *args, **kwargs)
def ShowChannels(self, patch, cmap=None, bg_func=np.mean)
Definition: visualize.py:104
def ShowSingle(self, patch, cmap=None)
Definition: visualize.py:35
def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean)
Definition: visualize.py:52
def ShowImages(self, patches, *args, **kwargs)
Definition: visualize.py:96