Caffe2 - Python API
A deep learning, cross-platform ML framework
visualize.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package visualize
17 # Module caffe2.python.visualize
18 """Functions that could be used to visualize Tensors.
19 
20 This is adapted from the old-time iceberk package that Yangqing wrote... Oh gold
21 memories. Before decaf and caffe. Why iceberk? Because I was at Berkeley,
22 bears are vegetarian, and iceberg lettuce has layers of leaves.
23 
24 (This joke is so lame.)
25 """
26 
27 import numpy as np
28 from matplotlib import cm, pyplot
29 
30 
def ChannelFirst(arr):
    """Reorder the last three axes of an array from HWC to CHW.

    Only the final three axes are touched, so any leading axes (for example
    a batch axis) stay where they are.
    """
    rank = arr.ndim
    last, middle, first = rank - 1, rank - 2, rank - 3
    # Walk the channel axis leftwards one neighbour at a time:
    # (..., H, W, C) -> (..., H, C, W) -> (..., C, H, W).
    channel_before_width = arr.swapaxes(last, middle)
    return channel_before_width.swapaxes(middle, first)
35 
36 
def ChannelLast(arr):
    """Reorder the last three axes of an array from CHW to HWC.

    Only the final three axes are touched, so any leading axes (for example
    a batch axis) stay where they are.
    """
    rank = arr.ndim
    first, middle, last = rank - 3, rank - 2, rank - 1
    # Walk the channel axis rightwards one neighbour at a time:
    # (..., C, H, W) -> (..., H, C, W) -> (..., H, W, C).
    channel_after_height = arr.swapaxes(first, middle)
    return channel_after_height.swapaxes(middle, last)
41 
42 
class PatchVisualizer(object):
    """Draws one or many image patches with matplotlib.

    Patches are expected in channel-last layout. ``gap`` is the number of
    background pixels left between neighbouring patches when several patches
    are tiled into one image by ``ShowMultiple``.
    """

    def __init__(self, gap=1):
        self.gap = gap

    def ShowSingle(self, patch, cmap=None):
        """Visualizes one single patch.

        The input patch could be a vector (in which case we try to infer the
        shape of the patch), a 2-D matrix, or a 3-D matrix whose 3rd dimension
        has 3 channels.

        Args:
            patch: numpy array holding the patch.
            cmap: matplotlib colormap; defaults to gray for 2-D patches.

        Returns:
            The patch that was drawn (reshaped if a vector was passed in).

        Raises:
            ValueError: if the patch has more than two dimensions and its
                third dimension is not 3, or if a vector's shape cannot be
                inferred.
        """
        if len(patch.shape) == 1:
            patch = patch.reshape(self.get_patch_shape(patch))
        elif len(patch.shape) > 2 and patch.shape[2] != 3:
            raise ValueError("The input patch shape isn't correct.")
        # determine color
        if len(patch.shape) == 2 and cmap is None:
            cmap = cm.gray
        pyplot.imshow(patch, cmap=cmap)
        return patch

    def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean):
        """Visualize multiple patches tiled into a single image.

        In the passed in patches matrix, each row is a patch, in the shape of
        either n*n, n*n*1 or n*n*3, either in a flattened format (so patches
        would be a 2-D array), or a multi-dimensional tensor. We will try our
        best to figure out automatically the patch size.

        Args:
            patches: numpy array whose first axis indexes the patches.
            ncols: number of patches per row; defaults to
                ceil(sqrt(num_patches)).
            cmap: matplotlib colormap; defaults to gray for single-channel
                patches.
            bg_func: callable applied to the whole patches array; its result
                fills the gaps and any unused grid cells.

        Returns:
            The tiled image that was drawn.

        Raises:
            ValueError: if 4-D patches have a channel count other than 1 or 3,
                or if the shape of flattened patches cannot be inferred.
        """
        num_patches = patches.shape[0]
        if ncols is None:
            ncols = int(np.ceil(np.sqrt(num_patches)))
        nrows = int(np.ceil(num_patches / float(ncols)))
        if len(patches.shape) == 2:
            patches = patches.reshape(
                (patches.shape[0], ) + self.get_patch_shape(patches[0])
            )
        # Every grid cell is one patch plus one gap; the trailing gap on the
        # last row/column is trimmed off the overall image size.
        patch_size_expand = np.array(patches.shape[1:3]) + self.gap
        image_size = patch_size_expand * np.array([nrows, ncols]) - self.gap
        if len(patches.shape) == 4:
            if patches.shape[3] == 1:
                # gray patches
                patches = patches.reshape(patches.shape[:-1])
                image_shape = tuple(image_size)
                if cmap is None:
                    cmap = cm.gray
            elif patches.shape[3] == 3:
                # color patches
                image_shape = tuple(image_size) + (3, )
            else:
                raise ValueError("The input patch shape isn't expected.")
        else:
            image_shape = tuple(image_size)
            if cmap is None:
                cmap = cm.gray
        image = np.ones(image_shape) * bg_func(patches)
        for pid in range(num_patches):
            row = pid // ncols * patch_size_expand[0]
            col = pid % ncols * patch_size_expand[1]
            image[row:row+patches.shape[1], col:col+patches.shape[2]] = \
                patches[pid]
        pyplot.imshow(image, cmap=cmap, interpolation='nearest')
        pyplot.axis('off')
        return image

    def ShowImages(self, patches, *args, **kwargs):
        """Similar to ShowMultiple, but always normalize the values between 0
        and 1 for better visualization of image-type data.

        The input array is not modified. Extra positional and keyword
        arguments are forwarded to ShowMultiple, whose return value is
        returned.
        """
        patches = patches - np.min(patches)
        # In-place true division is rejected by numpy for integer arrays
        # (the float result cannot be cast back), so promote those first.
        # Floating-point inputs keep their original dtype.
        if not np.issubdtype(patches.dtype, np.floating):
            patches = patches.astype(np.float64)
        # eps keeps the division defined when all values are identical.
        patches /= np.max(patches) + np.finfo(np.float64).eps
        return self.ShowMultiple(patches, *args, **kwargs)

    def ShowChannels(self, patch, cmap=None, bg_func=np.mean):
        """ This function shows the channels of a patch.

        The incoming patch should have shape [w, h, num_channels], and each
        channel will be visualized as a separate gray patch.

        Raises:
            ValueError: if the patch is not 3-dimensional.
        """
        if len(patch.shape) != 3:
            raise ValueError("The input patch shape isn't correct.")
        # .T gives [num_channels, h, w]; swapping the last two axes yields
        # [num_channels, w, h], i.e. one [w, h] patch per channel.
        patch_reordered = np.swapaxes(patch.T, 1, 2)
        return self.ShowMultiple(patch_reordered, cmap=cmap, bg_func=bg_func)

    def get_patch_shape(self, patch):
        """Gets the shape of a single patch.

        Basically it tries to interpret the patch as a square, and also check
        if it is in color (3 channels).

        Args:
            patch: numpy array; only its total element count is used.

        Returns:
            (n, n) when the element count is a perfect square, otherwise
            (n, n, 3) when it is three times a perfect square. All entries
            are plain ints so the tuple can be passed straight to reshape.

        Raises:
            ValueError: if neither interpretation fits.
        """
        size = patch.size
        # Verify the rounded root by squaring it back, so the decision is
        # made with exact integer arithmetic rather than float equality.
        edge_len = int(round(np.sqrt(size)))
        if edge_len * edge_len == size:
            return (edge_len, edge_len)
        if size % 3 == 0:
            # we are given color patches
            edge_len = int(round(np.sqrt(size // 3)))
            if edge_len * edge_len * 3 == size:
                return (edge_len, edge_len, 3)
        raise ValueError("I can't figure out the patch shape.")
146 
147 
# Shared module-level visualizer, built with the default gap; the NHWC and
# NCHW wrapper classes below forward their calls to this instance.
_default_visualizer = PatchVisualizer()
"""Utility functions that directly point to functions in the default visualizer.

These functions don't return anything, so you won't see annoying printouts of
the visualized images. If you want to save the images for example, you should
explicitly instantiate a patch visualizer, and call those functions.
"""
155 
156 
class NHWC(object):
    """Static shortcuts to the module's default visualizer.

    Arguments are forwarded untouched. The visualizer's return value is
    deliberately dropped, so every method here returns None.
    """

    @staticmethod
    def ShowSingle(*args, **kwargs):
        """Forward to the default visualizer's ShowSingle."""
        show = _default_visualizer.ShowSingle
        show(*args, **kwargs)

    @staticmethod
    def ShowMultiple(*args, **kwargs):
        """Forward to the default visualizer's ShowMultiple."""
        show = _default_visualizer.ShowMultiple
        show(*args, **kwargs)

    @staticmethod
    def ShowImages(*args, **kwargs):
        """Forward to the default visualizer's ShowImages."""
        show = _default_visualizer.ShowImages
        show(*args, **kwargs)

    @staticmethod
    def ShowChannels(*args, **kwargs):
        """Forward to the default visualizer's ShowChannels."""
        show = _default_visualizer.ShowChannels
        show(*args, **kwargs)
173 
174 
class NCHW(object):
    """Static shortcuts to the module's default visualizer for CHW data.

    The first argument is passed through ChannelLast before being handed to
    the visualizer; all remaining arguments are forwarded untouched. The
    visualizer's return value is deliberately dropped, so every method here
    returns None.
    """

    @staticmethod
    def ShowSingle(patch, *args, **kwargs):
        """Apply ChannelLast to ``patch`` and forward to ShowSingle."""
        channel_last = ChannelLast(patch)
        _default_visualizer.ShowSingle(channel_last, *args, **kwargs)

    @staticmethod
    def ShowMultiple(patch, *args, **kwargs):
        """Apply ChannelLast to ``patch`` and forward to ShowMultiple."""
        channel_last = ChannelLast(patch)
        _default_visualizer.ShowMultiple(channel_last, *args, **kwargs)

    @staticmethod
    def ShowImages(patch, *args, **kwargs):
        """Apply ChannelLast to ``patch`` and forward to ShowImages."""
        channel_last = ChannelLast(patch)
        _default_visualizer.ShowImages(channel_last, *args, **kwargs)

    @staticmethod
    def ShowChannels(patch, *args, **kwargs):
        """Apply ChannelLast to ``patch`` and forward to ShowChannels."""
        channel_last = ChannelLast(patch)
        _default_visualizer.ShowChannels(channel_last, *args, **kwargs)
def ShowChannels(self, patch, cmap=None, bg_func=np.mean)
Definition: visualize.py:119
def ShowSingle(self, patch, cmap=None)
Definition: visualize.py:50
def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean)
Definition: visualize.py:67
def ShowImages(self, patches, *args, **kwargs)
Definition: visualize.py:111