"""Functions that could be used to visualize Tensors.

This is adapted from the old-time iceberk package that Yangqing wrote... Oh gold
memories. Before decaf and caffe. Why iceberk? Because I was at Berkeley,
bears are vegetarian, and iceberg lettuce has layers of leaves.

(This joke is so lame.)
"""

import numpy as np
from matplotlib import cm, pyplot
def ChannelFirst(arr):
    """Convert a HWC array to CHW.

    Only the last three axes are rearranged, so any leading axes (e.g. a batch
    axis, NHWC -> NCHW) are carried through unchanged.

    Args:
        arr: an array-like object exposing ``ndim`` and ``swapaxes`` (e.g. a
            numpy array) whose last axis is the channel axis.

    Returns:
        A view of ``arr`` with the channel axis moved in front of the two
        spatial axes.

    NOTE: for a 2-D array the two swaps cancel, and the layout is returned
    unchanged.
    """
    ndim = arr.ndim
    # HWC -> HCW (swap the last two axes), then HCW -> CHW.
    return arr.swapaxes(ndim - 1, ndim - 2).swapaxes(ndim - 2, ndim - 3)
def ChannelLast(arr):
    """Convert a CHW array to HWC.

    Only the last three axes are rearranged, so any leading axes (e.g. a batch
    axis, NCHW -> NHWC) are carried through unchanged.

    Args:
        arr: an array-like object exposing ``ndim`` and ``swapaxes`` (e.g. a
            numpy array) whose third-from-last axis is the channel axis.

    Returns:
        A view of ``arr`` with the channel axis moved after the two spatial
        axes.

    NOTE: for a 2-D array the two swaps cancel, and the layout is returned
    unchanged.
    """
    ndim = arr.ndim
    # CHW -> HCW (swap channel with the first spatial axis), then HCW -> HWC.
    return arr.swapaxes(ndim - 3, ndim - 2).swapaxes(ndim - 2, ndim - 1)
class PatchVisualizer(object):
    """PatchVisualizer visualizes patches.

    Patches are drawn with ``pyplot.imshow``. When several patches are shown
    at once, they are tiled into a single image in which neighbouring patches
    are separated by ``gap`` pixels of background.
    """

    def __init__(self, gap=1):
        """Create a visualizer.

        Args:
            gap: number of background pixels placed between tiled patches.
        """
        self.gap = gap

    def ShowSingle(self, patch, cmap=None):
        """Visualizes one single patch.

        The input patch could be a vector (in which case we try to infer the
        shape of the patch), a 2-D matrix, or a 3-D matrix whose 3rd dimension
        has 1 (gray) or 3 (color) channels.

        Args:
            patch: the patch to draw.
            cmap: matplotlib colormap. It defaults to ``cm.gray`` for
                single-channel patches.

        Returns:
            The patch as it was drawn (reshaped to 2-D or H x W x 3).

        Raises:
            ValueError: if the patch shape cannot be displayed.
        """
        if len(patch.shape) == 1:
            patch = patch.reshape(self.get_patch_shape(patch))
        if len(patch.shape) == 3 and patch.shape[2] == 1:
            # A trailing singleton channel is a gray patch; ShowMultiple
            # accepts this layout too, so drop the channel axis.
            patch = patch.reshape(patch.shape[:2])
        elif len(patch.shape) > 3 or (
                len(patch.shape) == 3 and patch.shape[2] != 3):
            raise ValueError("The input patch shape isn't correct.")
        # Gray patches default to a gray colormap rather than matplotlib's.
        if len(patch.shape) == 2 and cmap is None:
            cmap = cm.gray
        pyplot.imshow(patch, cmap=cmap)
        return patch

    def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean):
        """Visualize multiple patches.

        In the passed in patches matrix, each row is a patch, in the shape of
        either n*n, n*n*1 or n*n*3, either in a flattened format (so patches
        would be a 2-D array), or a multi-dimensional tensor. We will try our
        best to figure out automatically the patch size.

        Args:
            patches: stack of patches, indexed by the first axis.
            ncols: number of tile columns. It defaults to
                ceil(sqrt(num_patches)).
            cmap: matplotlib colormap. It defaults to ``cm.gray`` for gray
                patches.
            bg_func: callable applied to all patches to get the background
                value that fills the gaps.

        Returns:
            The tiled image that was drawn.

        Raises:
            ValueError: if 4-D patches have a channel count other than 1 or 3.
        """
        image, is_gray = self._tile_patches(patches, ncols, bg_func)
        if is_gray and cmap is None:
            cmap = cm.gray
        pyplot.imshow(image, cmap=cmap, interpolation='nearest')
        pyplot.axis('off')
        return image

    def _tile_patches(self, patches, ncols=None, bg_func=np.mean):
        """Arrange a stack of patches into one image array (no drawing).

        Returns:
            ``(image, is_gray)``. ``image`` is the tiled float image, and
            ``is_gray`` is False only for 3-channel color patches.

        Raises:
            ValueError: if 4-D patches have a channel count other than 1 or 3.
        """
        num_patches = patches.shape[0]
        if ncols is None:
            ncols = int(np.ceil(np.sqrt(num_patches)))
        nrows = int(np.ceil(num_patches / float(ncols)))
        if len(patches.shape) == 2:
            # Flattened patches: infer every patch's shape from the first one.
            patches = patches.reshape(
                (num_patches, ) + self.get_patch_shape(patches[0]))
        # Each tile occupies its patch size plus the gap. The trailing gap
        # after the last row/column is trimmed off.
        patch_size_expand = np.array(patches.shape[1:3]) + self.gap
        image_size = patch_size_expand * np.array([nrows, ncols]) - self.gap
        is_gray = True
        if len(patches.shape) == 4:
            if patches.shape[3] == 1:
                # gray patches: drop the singleton channel axis
                patches = patches.reshape(patches.shape[:-1])
                image_shape = tuple(image_size)
            elif patches.shape[3] == 3:
                # color patches
                image_shape = tuple(image_size) + (3, )
                is_gray = False
            else:
                raise ValueError("The input patch shape isn't expected.")
        else:
            image_shape = tuple(image_size)
        image = np.ones(image_shape) * bg_func(patches)
        for pid in range(num_patches):
            row = pid // ncols * patch_size_expand[0]
            col = pid % ncols * patch_size_expand[1]
            image[row:row + patches.shape[1], col:col + patches.shape[2]] = \
                patches[pid]
        return image, is_gray

    def ShowImages(self, patches, *args, **kwargs):
        """Similar to ShowMultiple, but always normalize the values between 0
        and 1 for better visualization of image-type data.

        Extra positional and keyword arguments are forwarded to ShowMultiple.
        """
        return self.ShowMultiple(self._normalize(patches), *args, **kwargs)

    @staticmethod
    def _normalize(patches):
        """Return a float64 copy of ``patches`` linearly rescaled into [0, 1].

        Converting to float first avoids two failures of the previous
        in-place version. In-place true division is rejected for integer
        arrays, and ``x - min(x)`` can wrap around for signed integer dtypes.
        """
        patches = np.asarray(patches, dtype=np.float64)
        patches = patches - np.min(patches)
        # eps keeps a constant input (max == 0 after shifting) from dividing
        # by zero.
        return patches / (np.max(patches) + np.finfo(np.float64).eps)

    def ShowChannels(self, patch, cmap=None, bg_func=np.mean):
        """ This function shows the channels of a patch.

        The incoming patch should have shape [w, h, num_channels], and each
        channel will be visualized as a separate gray patch.

        Raises:
            ValueError: if the patch is not 3-D.
        """
        if len(patch.shape) != 3:
            raise ValueError("The input patch shape isn't correct.")
        # patch.T is [c, h, w]. Swapping its last two axes gives [c, w, h],
        # i.e. patch_reordered[i] == patch[:, :, i].
        patch_reordered = np.swapaxes(patch.T, 1, 2)
        return self.ShowMultiple(patch_reordered, cmap=cmap, bg_func=bg_func)

    def get_patch_shape(self, patch):
        """Gets the shape of a single patch.

        Basically it tries to interpret the patch as a square, and also checks
        whether it is in color (3 channels).

        Returns:
            ``(edge, edge)`` for a gray square patch or ``(edge, edge, 3)`` for
            a color one. Edge lengths are Python ints, so the tuple can be
            passed directly to ``reshape``.

        Raises:
            ValueError: if the patch size is neither a perfect square nor three
                times a perfect square.
        """
        size = patch.size
        # Round then verify with integer arithmetic, instead of comparing
        # floats.
        edge_len = int(round(np.sqrt(size)))
        if edge_len * edge_len == size:
            return (edge_len, edge_len)
        # Not a square: we may have been given color patches.
        edge_len = int(round(np.sqrt(size / 3.)))
        if 3 * edge_len * edge_len == size:
            return (edge_len, edge_len, 3)
        raise ValueError("I can't figure out the patch shape.")
# Utility functions that directly point to functions in the default visualizer.
#
# These functions don't return anything, so you won't see annoying printouts of
# the visualized images. If you want to save the images for example, you should
# explicitly instantiate a patch visualizer, and call those functions.
_default_visualizer = PatchVisualizer()


class NHWC(object):
    """Wrappers for patches that are already channel-last (HWC / NHWC).

    Every argument is forwarded unchanged to the shared default visualizer,
    and nothing is returned.
    """

    @staticmethod
    def ShowSingle(*args, **kwargs):
        """Forward to ``_default_visualizer.ShowSingle``."""
        _default_visualizer.ShowSingle(*args, **kwargs)

    @staticmethod
    def ShowMultiple(*args, **kwargs):
        """Forward to ``_default_visualizer.ShowMultiple``."""
        _default_visualizer.ShowMultiple(*args, **kwargs)

    @staticmethod
    def ShowImages(*args, **kwargs):
        """Forward to ``_default_visualizer.ShowImages``."""
        _default_visualizer.ShowImages(*args, **kwargs)

    @staticmethod
    def ShowChannels(*args, **kwargs):
        """Forward to ``_default_visualizer.ShowChannels``."""
        _default_visualizer.ShowChannels(*args, **kwargs)
class NCHW(object):
    """Wrappers for channel-first (CHW / NCHW) patches.

    Each wrapper converts its first argument with ChannelLast before handing
    it to the shared default visualizer. That visualizer works on
    channel-last data. Nothing is returned.
    """

    @staticmethod
    def ShowSingle(patch, *args, **kwargs):
        """Show one channel-first patch via the default visualizer."""
        _default_visualizer.ShowSingle(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowMultiple(patch, *args, **kwargs):
        """Show a channel-first stack of patches via the default visualizer."""
        _default_visualizer.ShowMultiple(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowImages(patch, *args, **kwargs):
        """Normalize and show channel-first images via the default visualizer."""
        _default_visualizer.ShowImages(ChannelLast(patch), *args, **kwargs)

    @staticmethod
    def ShowChannels(patch, *args, **kwargs):
        """Show each channel of a channel-first patch as a gray tile."""
        _default_visualizer.ShowChannels(ChannelLast(patch), *args, **kwargs)
def ShowChannels(self, patch, cmap=None, bg_func=np.mean)
def ShowSingle(self, patch, cmap=None)
def get_patch_shape(self, patch)
def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean)
def ShowImages(self, patches, args, kwargs)