Caffe2 - Python API
A deep learning, cross-platform ML framework
download.py
1 ## @package download
2 # Module caffe2.python.models.download
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import argparse
8 import os
9 import sys
10 import signal
11 import re
12 import json
13 
14 from caffe2.proto import caffe2_pb2
15 
16 # Import urllib
17 try:
18  import urllib.error as urlliberror
19  import urllib.request as urllib
20  HTTPError = urlliberror.HTTPError
21  URLError = urlliberror.URLError
22 except ImportError:
23  import urllib2 as urllib
24  HTTPError = urllib.HTTPError
25  URLError = urllib.URLError
26 
# Models are fetched from the S3 bucket directly rather than through the
# vanity URL: following that URL's redirect would take extra urllib work.
DOWNLOAD_BASE_URL = "https://s3.amazonaws.com/download.caffe2.ai/models/"

# Width, in characters, of the text progress bar drawn by progressBar().
DOWNLOAD_COLUMNS = 70
30 
31 
32 # Don't let urllib hang up on big downloads
33 def signalHandler(signal, frame):
34  print("Killing download...")
35  exit(0)
36 
37 
38 signal.signal(signal.SIGINT, signalHandler)
39 
40 
def deleteDirectory(top_dir):
    """Recursively delete ``top_dir`` and everything inside it.

    Symbolic links found inside the tree are removed as links; their targets
    are left untouched.

    Args:
        top_dir: Path of the directory to remove.

    Raises:
        OSError: If ``top_dir`` or an entry inside it cannot be removed,
            including when ``top_dir`` does not exist or is itself a
            symbolic link.
    """
    # shutil.rmtree unlinks symlinked sub-directories instead of calling
    # rmdir() on them (which fails with ENOTDIR). It also refuses to follow a
    # symlinked ``top_dir``, whereas a bare os.walk() would descend into the
    # link target and empty it before failing.
    import shutil
    shutil.rmtree(top_dir)
48 
49 
def progressBar(percentage):
    """Redraw a one-line text progress bar on stdout.

    The line starts with a cursor-back escape sequence so that each call
    overwrites the previous bar instead of appending a new line.

    Args:
        percentage: Completion percentage to display; the bar itself is
            ``DOWNLOAD_COLUMNS`` characters wide.
    """
    filled = int(DOWNLOAD_COLUMNS * percentage / 100)
    pieces = [
        u"\u001b[1000D[",
        "#" * filled,
        " " * (DOWNLOAD_COLUMNS - filled),
        "] ",
        str(percentage),
        "%",
    ]
    out = sys.stdout
    out.write("".join(pieces))
    out.flush()
55 
56 
def downloadFromURLToFile(url, filename, show_progress=True):
    """Fetch ``url`` and write the response body to ``filename``.

    Args:
        url: URL to open with ``urllib.urlopen``.
        filename: Local path to write to. It is opened in ``"wb"`` mode, so
            an existing file is overwritten.
        show_progress: When True and the response reports a non-zero
            Content-Length, draw a progress bar on stdout while downloading.

    Raises:
        Exception: With a readable message when the request fails with an
            HTTP error or a URL error. Any other error propagates unchanged.
    """
    try:
        print("Downloading from {url}".format(url=url))
        response = urllib.urlopen(url)
        try:
            # Content-Length is optional (e.g. chunked transfer encoding).
            # Without it no percentage can be computed, so the bar is
            # skipped instead of crashing on None.strip().
            length_header = response.info().get('Content-Length')
            size = int(length_header.strip()) if length_header else None
            show_bar = bool(show_progress and size)
            print("Writing to {filename}".format(filename=filename))
            downloaded_size = 0
            if show_bar:
                progressBar(0)
            with open(filename, "wb") as local_file:
                while True:
                    data_chunk = response.read(8192)
                    if not data_chunk:
                        break
                    local_file.write(data_chunk)
                    if show_bar:
                        downloaded_size += len(data_chunk)
                        progressBar(int(100 * downloaded_size / size))
        finally:
            # Release the connection / file handle even if writing fails.
            response.close()
        print("")  # New line to fix for progress bar
    except HTTPError as e:
        # HTTPError subclasses URLError, so it must be handled first.
        raise Exception("Could not download model. [HTTP Error] {code}: {reason}."
                        .format(code=e.code, reason=e.reason))
    except URLError as e:
        raise Exception("Could not download model. [URL Error] {reason}."
                        .format(reason=e.reason))
85 
86 
def getURLFromName(name, filename):
    """Return the download URL of ``filename`` for the model ``name``.

    The URL is ``DOWNLOAD_BASE_URL``, then ``name``, a ``/``, and
    ``filename``.
    """
    return "{0}{1}/{2}".format(DOWNLOAD_BASE_URL, name, filename)
90 
91 
def downloadModel(model, args):
    """Download a model's ``predict_net.pb`` and ``init_net.pb`` into a folder.

    The target folder is ``model``, relative to the current directory. When
    ``args.install`` is set, it is ``<directory of this file>/<model>``
    instead, and an ``__init__.py`` symlink to this directory's
    ``__sym_init__.py`` is created inside it.

    Args:
        model: Model name. It is used as the folder name and as the URL path
            segment.
        args: Parsed command-line namespace providing the boolean attributes
            ``install`` and ``force``.

    Raises:
        Exception: If a regular file occupies the target path and
            ``args.force`` is not set.

    The process exits with status 0 if the user declines to overwrite an
    existing folder, and with status 1 if a download fails.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if args.install:
        model_folder = '{dir_path}/{folder}'.format(dir_path=dir_path,
                                                    folder=model)
    else:
        model_folder = '{folder}'.format(folder=model)

    # A regular file is in the way of the folder we need to create.
    if os.path.exists(model_folder) and not os.path.isdir(model_folder):
        if not args.force:
            # Implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            raise Exception("Cannot create folder for storing the model, "
                            "there exists a file of the same name.")
        print("Overwriting existing file! ({filename})"
              .format(filename=model_folder))
        os.remove(model_folder)

    if os.path.isdir(model_folder):
        if not args.force:
            query = "Model already exists, continue? [y/N] "
            try:
                response = raw_input(query)  # Python 2
            except NameError:
                response = input(query)  # Python 3
            # Only an explicit yes overwrites. The default (empty) and any
            # other answer, such as "no", cancel.
            if response.strip().lower() not in ('y', 'yes'):
                print("Cancelling download...")
                sys.exit(0)
        print("Overwriting existing folder! ({filename})".format(filename=model_folder))
        deleteDirectory(model_folder)

    # Now we can safely create the folder and download the model
    os.makedirs(model_folder)
    for f in ['predict_net.pb', 'init_net.pb']:
        try:
            downloadFromURLToFile(getURLFromName(model, f),
                                  '{folder}/{f}'.format(folder=model_folder,
                                                        f=f))
        except Exception as e:
            print("Abort: {reason}".format(reason=str(e)))
            print("Cleaning up...")
            deleteDirectory(model_folder)
            # Non-zero status so scripts can detect the failed download.
            sys.exit(1)

    if args.install:
        os.symlink("{folder}/__sym_init__.py".format(folder=dir_path),
                   "{folder}/__init__.py".format(folder=model_folder))
139 
140 
def validModelName(name):
    """Return True if ``name`` is acceptable as a model name.

    A valid name is one or more ``/``-separated segments. Each segment is
    made of ASCII letters, digits, ``_`` or ``-``. Names with a leading,
    trailing or doubled ``/`` are rejected, as is the reserved name
    ``__init__``.

    Args:
        name: Candidate model name (a string).

    Returns:
        True if the name is valid, False otherwise.
    """
    invalid_names = ['__init__']
    if name in invalid_names:
        return False
    # \Z rather than $: "$" also matches just before a trailing newline, so
    # the old pattern accepted e.g. "squeezenet\n".
    return re.match(r"[0-9a-zA-Z_-]+(?:/[0-9a-zA-Z_-]+)*\Z", name) is not None
148 
150  def _model_dir(self, model):
151  caffe2_home = os.path.expanduser(os.getenv('CAFFE2_HOME', '~/.caffe2'))
152  models_dir = os.getenv('CAFFE2_MODELS', os.path.join(caffe2_home, 'models'))
153  return os.path.join(models_dir, model)
154 
def _download(self, model):
    """Download ``model``'s files into its (not yet existing) cache directory.

    Fetches ``predict_net.pb``, ``init_net.pb`` and ``value_info.json`` without
    a progress bar. On any failure the partially populated directory is
    removed and the process exits with status 1.

    Args:
        model: Model name, used both for the cache path and the download URL.
    """
    model_dir = self._model_dir(model)
    assert not os.path.exists(model_dir)
    os.makedirs(model_dir)
    for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
        url = getURLFromName(model, f)
        dest = os.path.join(model_dir, f)
        try:
            # downloadFromURLToFile in this module accepts show_progress, so
            # the old TypeError fallback is dropped. It also re-ran the
            # download whenever a genuine TypeError was raised inside it.
            downloadFromURLToFile(url, dest, show_progress=False)
        except Exception as e:
            print("Abort: {reason}".format(reason=e))
            print("Cleaning up...")
            deleteDirectory(model_dir)
            sys.exit(1)
176 
def get_c2_model(self, model_name):
    """Load a model's nets and value info from the local cache.

    If the model's cache directory does not exist yet, it is downloaded
    first.

    Args:
        model_name: Name of the model to load.

    Returns:
        A tuple ``(init_net, predict_net, value_info)``. The two nets are
        ``caffe2_pb2.NetDef`` messages named ``<model_name>_init`` and
        ``<model_name>``. ``value_info`` is the parsed content of
        ``value_info.json``.
    """
    model_dir = self._model_dir(model_name)
    if not os.path.exists(model_dir):
        self._download(model_name)

    def read_net(file_name, net_name):
        # Parse one serialized NetDef from the cache dir and rename it.
        net = caffe2_pb2.NetDef()
        with open(os.path.join(model_dir, file_name), 'rb') as fh:
            net.ParseFromString(fh.read())
        net.name = net_name
        return net

    c2_predict_net = read_net('predict_net.pb', model_name)
    c2_init_net = read_net('init_net.pb', model_name + '_init')

    with open(os.path.join(model_dir, 'value_info.json')) as fh:
        value_info = json.load(fh)
    return c2_init_net, c2_predict_net, value_info
196 
197 
if __name__ == "__main__":
    # Command-line entry point: download (or install) each requested model,
    # skipping and reporting names that fail validation.
    parser = argparse.ArgumentParser(
        description='Download or install pretrained models.')
    parser.add_argument('model', nargs='+', help='Model to download/install.')
    parser.add_argument('-i', '--install', action='store_true',
                        help='Install the model.')
    parser.add_argument('-f', '--force', action='store_true',
                        help='Force a download/installation.')
    args = parser.parse_args()

    for model in args.model:
        if not validModelName(model):
            print("'{}' is not a valid model name.".format(model))
            continue
        downloadModel(model, args)