Caffe2 - Python API
A deep learning, cross-platform ML framework
helpers.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package helpers
17 # Module caffe2.python.tutorials.helpers
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 import numpy as np
23 import skimage.io
24 import skimage.transform
25 
26 
def crop_center(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of ``img``.

    Args:
        img: array whose first two axes are (height, width); any trailing
            axes (e.g. channels) are kept untouched, so both 2-D grayscale
            and 3-D color arrays are accepted.
        cropx: width of the crop, in pixels.
        cropy: height of the crop, in pixels.

    Returns:
        The slice of ``img`` centered on the image. If the requested crop is
        larger than the image along an axis, the full extent of that axis is
        returned.
    """
    # Only the two leading axes matter; slicing [:2] avoids requiring a
    # channel axis to be present.
    height, width = img.shape[:2]
    # Clamp at 0: a negative start would be interpreted as "count from the
    # end" by the slice below and silently produce a wrong window.
    startx = max(width // 2 - cropx // 2, 0)
    starty = max(height // 2 - cropy // 2, 0)
    return img[starty:starty + cropy, startx:startx + cropx]
32 
33 
def rescale(img, input_height, input_width):
    """Resize ``img`` towards the model input size, keeping its aspect ratio.

    The shorter side of the image is scaled to the matching model dimension
    and the longer side is scaled by the same factor, so the result is at
    least as large as a square model input along both axes and can be
    center-cropped afterwards.

    Args:
        img: image array; ``img.shape[0]`` is read as height and
            ``img.shape[1]`` as width (H, W, C layout expected).
        input_height: target height (rows) of the model input.
        input_width: target width (columns) of the model input.

    Returns:
        The resized image as returned by ``skimage.transform.resize``.
    """
    aspect = img.shape[1] / float(img.shape[0])
    # skimage.transform.resize takes output_shape as (rows, cols), i.e.
    # (height, width).
    if aspect > 1:
        # Landscape: height is the short side, so pin it to input_height and
        # let the width grow proportionally.
        output_shape = (input_height, int(aspect * input_height))
    elif aspect < 1:
        # Portrait: width is the short side, so pin it to input_width and
        # let the height grow proportionally.
        output_shape = (int(input_width / aspect), input_width)
    else:
        # Square image: scale straight to the model input.
        output_shape = (input_height, input_width)
    return skimage.transform.resize(img, output_shape, preserve_range=False)
59 
60 
def load(img):
    """Read an image and return it as a float32 array.

    Args:
        img: location of the image, passed straight to ``skimage.io.imread``.

    Returns:
        The image converted with ``skimage.img_as_float`` and cast to
        ``np.float32``.
    """
    raw_pixels = skimage.io.imread(img)
    float_pixels = skimage.img_as_float(raw_pixels)
    return float_pixels.astype(np.float32)
65 
66 
def chw(img):
    """Move axis 2 of ``img`` to the front (intended for HWC -> CHW).

    Args:
        img: array with at least three axes.

    Returns:
        A view of ``img`` with axis 2 first and the other axes in their
        original relative order.
    """
    # Same permutation as swapping axes (1, 2) and then (0, 1).
    return np.moveaxis(img, 2, 0)
71 
72 
def bgr(img):
    """Reverse the order of the first three entries along axis 0.

    Intended for a channel-first image, turning RGB into BGR.

    Args:
        img: 3-D array with at least three entries along axis 0.

    Returns:
        A new array made of entries 2, 1 and 0 of axis 0, in that order.
    """
    reversed_channels = (2, 1, 0)
    return img[reversed_channels, :, :]
77 
78 
def removeMean(img, mean):
    """Scale ``img`` by 255 and subtract ``mean``.

    Args:
        img: image values to scale.
        mean: value (or broadcastable array) subtracted after scaling.

    Returns:
        ``img * 255 - mean``.
    """
    scaled = img * 255
    return scaled - mean
83 
84 
def batch(img):
    """Prepend a batch axis of length 1 and cast to float32.

    Args:
        img: array with at least three axes.

    Returns:
        A float32 copy of ``img`` with one extra leading axis.
    """
    with_batch_axis = img[np.newaxis, :, :, :]
    return with_batch_axis.astype(np.float32)
89 
90 
def parseResults(results):
    """Turn raw classifier scores into a human-readable sentence.

    The scores are flattened, the entry at flat position 1 is dropped, and
    the remaining scores are numbered starting from 1. The number of the
    highest score is then looked up in ``inference_codes.txt`` (opened
    relative to the current working directory), whose lines are split at the
    first ``:`` into a code and a label.

    Args:
        results: array-like of scores produced by the model.

    Returns:
        A sentence naming the matched label and the top score as a
        percentage.

    Raises:
        LookupError: if no line of ``inference_codes.txt`` carries the code
            of the best-scoring entry.
    """
    # NOTE(review): np.delete without an axis flattens the input and removes
    # the single element at flat index 1 — kept as-is, confirm this is the
    # intended handling of the model output.
    scores = np.delete(np.asarray(results), 1)

    index = 0
    highest = 0
    # imagenet index begins with 1!
    for i, r in enumerate(scores, start=1):
        if r > highest:
            highest = r
            index = i

    # Build the (number, score) table in one shot instead of growing it row
    # by row; object dtype keeps the printed representation unchanged.
    ranked = np.column_stack(
        (np.arange(1, scores.size + 1), scores)).astype(object)

    # top 3 results
    print("Raw top 3 results:",
          sorted(ranked, key=lambda row: row[1], reverse=True)[:3])

    # now we can grab the code list
    answer = None
    with open('inference_codes.txt', 'r') as f:
        for line in f:
            code, result = line.partition(":")[::2]
            if code.strip() == str(index):
                # [1:-2] trims the label's opening quote and its trailing
                # quote-plus-comma.
                answer = "The image contains a %s with a %s percent probability." \
                    % (result.strip()[1:-2], highest * 100)

    if answer is None:
        # Previously this fell through to an UnboundLocalError.
        raise LookupError(
            "no entry for code %s found in inference_codes.txt" % index)
    return answer
119 
120 
def loadToNCHW(img, mean, input_size):
    """Load an image and run the full preprocessing chain on it.

    Steps, in order: load, rescale, center-crop to ``input_size`` x
    ``input_size``, reorder to channel-first, switch to BGR, remove the
    mean, and add a batch axis.

    Args:
        img: location of the image, as accepted by ``load``.
        mean: mean value passed to ``removeMean``.
        input_size: side length used for both rescaling and cropping.

    Returns:
        The preprocessed image as returned by ``batch``.
    """
    pixels = load(img)
    resized = rescale(pixels, input_size, input_size)
    cropped = crop_center(resized, input_size, input_size)
    channel_first = bgr(chw(cropped))
    return batch(removeMean(channel_first, mean))