from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import json
import os
import re
import shutil
import signal
import sys

from caffe2.proto import caffe2_pb2

# urllib moved between Python 2 and 3; expose one `urllib` module with
# `urlopen`, plus the two error types, under the same names on both.
try:
    import urllib.error as urlliberror
    import urllib.request as urllib
    HTTPError = urlliberror.HTTPError
    URLError = urlliberror.URLError
except ImportError:
    import urllib2 as urllib
    HTTPError = urllib.HTTPError
    URLError = urllib.URLError
# Base URL under which each model's files are published as <name>/<file>.
DOWNLOAD_BASE_URL = "https://s3.amazonaws.com/download.caffe2.ai/models/"
# Width, in characters, of the text progress bar drawn while downloading.
DOWNLOAD_COLUMNS = 70


def signalHandler(signal, frame):
    """Stop the process when the user interrupts a download (SIGINT).

    ``signal`` and ``frame`` are supplied by the ``signal`` module and are
    not used. The process exits with status 1 because the download did not
    complete.
    """
    print("Killing download...")
    # Announcing the kill without exiting would silently swallow Ctrl-C.
    sys.exit(1)


# Registered at import time so Ctrl-C aborts a long-running download.
signal.signal(signal.SIGINT, signalHandler)
def deleteDirectory(top_dir):
    """Recursively delete ``top_dir`` and everything beneath it.

    The top-level directory itself is removed too: callers recreate it with
    ``os.makedirs`` right afterwards, which fails if it still exists.

    Raises:
        OSError: if ``top_dir`` does not exist or cannot be removed.
    """
    # shutil.rmtree does the same bottom-up removal (files, then
    # sub-directories, then top_dir) and unlinks symlinks to directories
    # instead of failing on them the way os.rmdir would.
    shutil.rmtree(top_dir)
def progressBar(percentage, columns=None):
    """Draw a one-line text progress bar on stdout, redrawn in place.

    Args:
        percentage: completion percentage (an int). It is printed as given,
            but the bar itself is clipped to between 0 and ``columns`` cells.
        columns: bar width in characters; defaults to DOWNLOAD_COLUMNS.
    """
    if columns is None:
        columns = DOWNLOAD_COLUMNS
    filled = int(columns * percentage / 100)
    # Keep the bar its nominal width even for out-of-range percentages.
    filled = max(0, min(filled, columns))
    bar = filled * "#" + (columns - filled) * " "
    # ESC[1000D moves the cursor far to the left so each call overwrites
    # the previous bar instead of appending to it.
    sys.stdout.write(u"\u001b[1000D[" + bar + "] " + str(percentage) + "%")
    # No newline is written, so flush or the bar may not be shown at all.
    sys.stdout.flush()
def downloadFromURLToFile(url, filename, show_progress=True):
    """Fetch ``url`` and write the response body to ``filename``.

    Args:
        url: address passed to ``urllib.urlopen``.
        filename: local path to write; an existing file is overwritten.
        show_progress: draw a progress bar while downloading. The bar is
            only drawn when the server reports a non-zero Content-Length.

    Raises:
        Exception: with a readable message when the fetch fails with
            HTTPError or URLError. Other errors propagate unchanged.
    """
    try:
        print("Downloading from {url}".format(url=url))
        response = urllib.urlopen(url)
        try:
            # Content-Length can be absent (e.g. chunked transfer); treat
            # the size as unknown instead of crashing on None.strip().
            length_header = response.info().get('Content-Length')
            size = int(length_header.strip()) if length_header else 0
            chunk = min(size, 8192) if size > 0 else 8192
            track_progress = show_progress and size > 0
            print("Writing to {filename}".format(filename=filename))
            downloaded_size = 0
            if track_progress:
                progressBar(0)
            with open(filename, "wb") as local_file:
                while True:
                    data_chunk = response.read(chunk)
                    if not data_chunk:
                        break
                    local_file.write(data_chunk)
                    if track_progress:
                        downloaded_size += len(data_chunk)
                        # A server may send more than it announced.
                        progressBar(min(100, int(100 * downloaded_size / size)))
            if track_progress:
                # Move off the line the progress bar keeps overwriting.
                print("")
        finally:
            # Release the connection whether or not the copy succeeded.
            response.close()
    except HTTPError as e:
        # HTTPError subclasses URLError, so it has to be handled first.
        raise Exception("Could not download model. [HTTP Error] {code}: {reason}."
                        .format(code=e.code, reason=e.reason))
    except URLError as e:
        raise Exception("Could not download model. [URL Error] {reason}."
                        .format(reason=e.reason))
def getURLFromName(name, filename):
    """Return the download URL of ``filename`` for the model ``name``.

    The URL is DOWNLOAD_BASE_URL followed by ``<name>/<filename>``.
    """
    url_template = "{base_url}{name}/{filename}"
    return url_template.format(name=name,
                               filename=filename,
                               base_url=DOWNLOAD_BASE_URL)
def downloadModel(model, args):
    """Download ``model``'s predict_net.pb and init_net.pb into a folder.

    Args:
        model: model name; used as the folder name and as the URL path
            component passed to getURLFromName.
        args: parsed command-line namespace.
            ``args.install`` puts the folder next to this script and
            symlinks ``__sym_init__.py`` there as its ``__init__.py``.
            ``args.force`` replaces an existing file or folder without asking.

    Returns without downloading if the user declines to overwrite an existing
    folder. Exits the process with status 1 if any download fails (after
    removing the partially downloaded folder).

    Raises:
        Exception: if a plain file occupies the folder name and
            ``args.force`` is not set.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if args.install:
        # Plain string formatting (not os.path.join) keeps the folder under
        # dir_path even if `model` starts with "/".
        model_folder = '{dir_path}/{folder}'.format(dir_path=dir_path,
                                                    folder=model)
    else:
        model_folder = model

    # A regular file already uses the folder's name.
    if os.path.exists(model_folder) and not os.path.isdir(model_folder):
        if not args.force:
            raise Exception("Cannot create folder for storing the model, "
                            "there exists a file of the same name.")
        print("Overwriting existing file! ({filename})"
              .format(filename=model_folder))
        os.remove(model_folder)

    if os.path.isdir(model_folder):
        if not args.force:
            query = "Model already exists, continue? [y/N] "
            try:
                response = raw_input(query)  # Python 2
            except NameError:
                response = input(query)  # Python 3
            # Default is No: only an explicit yes may destroy existing data.
            if response.strip().lower() not in ('y', 'yes'):
                print("Cancelling download...")
                return
        print("Overwriting existing folder! ({filename})".format(filename=model_folder))
        deleteDirectory(model_folder)

    os.makedirs(model_folder)
    for f in ['predict_net.pb', 'init_net.pb']:
        try:
            downloadFromURLToFile(getURLFromName(model, f),
                                  '{folder}/{f}'.format(folder=model_folder,
                                                        f=f))
        except Exception as e:
            print("Abort: {reason}".format(reason=str(e)))
            print("Cleaning up...")
            deleteDirectory(model_folder)
            # Continuing would write into the folder just deleted.
            sys.exit(1)

    if args.install:
        # Makes the installed model folder importable as a package.
        os.symlink("{folder}/__sym_init__.py".format(folder=dir_path),
                   "{folder}/__init__.py".format(folder=model_folder))
def validModelName(name):
    """Return True if ``name`` is acceptable as a model name.

    A valid name:
    - is non-empty;
    - uses only ASCII letters, digits, ``_``, ``-`` and ``/``;
    - is not reserved (``__init__``);
    - does not start with ``/``.

    A leading slash would make the local model folder an absolute path,
    which the download code may delete when overwriting.
    """
    invalid_names = ['__init__']
    if name in invalid_names:
        return False
    if name.startswith('/'):
        return False
    # \Z instead of $: $ would also accept a trailing newline.
    return re.match(r"^[/0-9a-zA-Z_-]+\Z", name) is not None
def _model_dir(self, model):
    """Return the local directory that holds (or will hold) ``model``.

    The models root is $CAFFE2_MODELS if it is set; otherwise it is
    ``<home>/models``, where home is $CAFFE2_HOME or ``~/.caffe2``
    (user-expanded).
    """
    home = os.path.expanduser(os.getenv('CAFFE2_HOME', '~/.caffe2'))
    fallback_root = os.path.join(home, 'models')
    return os.path.join(os.getenv('CAFFE2_MODELS', fallback_root), model)
def _download(self, model):
    """Download ``model``'s nets and value_info.json into its model dir.

    The directory returned by ``self._model_dir(model)`` must not exist yet;
    it is created here. If any file fails to download, the partially filled
    directory is removed and the exception is re-raised to the caller.
    """
    model_dir = self._model_dir(model)
    assert not os.path.exists(model_dir)
    os.makedirs(model_dir)
    for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
        url = getURLFromName(model, f)
        dest = os.path.join(model_dir, f)
        try:
            downloadFromURLToFile(url, dest, show_progress=False)
        except Exception as e:
            print("Abort: {reason}".format(reason=e))
            print("Cleaning up...")
            deleteDirectory(model_dir)
            # Stop here: later files would be written into a deleted dir.
            raise
def get_c2_model(self, model_name):
    """Load a model's init net, predict net and value info from disk.

    The model is downloaded first (via ``self._download``) if its directory
    does not exist yet.

    Returns:
        A tuple ``(init_net, predict_net, value_info)``:
        - ``init_net`` and ``predict_net`` are ``caffe2_pb2.NetDef`` protos
          named ``<model_name>_init`` and ``<model_name>`` respectively;
        - ``value_info`` is the object parsed from ``value_info.json``.
    """
    model_dir = self._model_dir(model_name)
    if not os.path.exists(model_dir):
        self._download(model_name)

    def load_net(file_name, net_name):
        # Parse a serialized NetDef from the model dir and rename it.
        net = caffe2_pb2.NetDef()
        with open(os.path.join(model_dir, file_name), 'rb') as f:
            net.ParseFromString(f.read())
        net.name = net_name
        return net

    c2_predict_net = load_net('predict_net.pb', model_name)
    c2_init_net = load_net('init_net.pb', model_name + '_init')
    with open(os.path.join(model_dir, 'value_info.json')) as f:
        value_info = json.load(f)
    return c2_init_net, c2_predict_net, value_info
if __name__ == "__main__":
    # Command-line entry point: download (or install) each named model.
    parser = argparse.ArgumentParser(
        description='Download or install pretrained models.')
    parser.add_argument('model', nargs='+',
                        help='Model to download/install.')
    parser.add_argument('-i', '--install', action='store_true',
                        help='Install the model.')
    parser.add_argument('-f', '--force', action='store_true',
                        help='Force a download/installation.')
    args = parser.parse_args()
    for model in args.model:
        if validModelName(model):
            downloadModel(model, args)
        else:
            # Report only the names that were actually rejected.
            print("'{}' is not a valid model name.".format(model))
# def _download(self, model)
# def _model_dir(self, model)