Caffe2 - Python API
A deep learning, cross-platform ML framework
model_zoo.py
1 from __future__ import absolute_import, division, print_function, unicode_literals
2 import torch
3 
4 import errno
5 import hashlib
6 import os
7 import re
8 import shutil
9 import sys
10 import tempfile
11 
12 try:
13  from requests.utils import urlparse
14  from requests import get as urlopen
15  requests_available = True
16 except ImportError:
17  requests_available = False
18  if sys.version_info[0] == 2:
19  from urlparse import urlparse # noqa f811
20  from urllib2 import urlopen # noqa f811
21  else:
22  from urllib.request import urlopen
23  from urllib.parse import urlparse
24 try:
25  from tqdm import tqdm
26 except ImportError:
27  tqdm = None # defined below
28 
# Pulls the hex hash out of a checkpoint filename: the leftmost "-" that is
# directly followed by lowercase hex digits and then a ".". For
# "resnet18-bfd8deac.pth", group(1) is "bfd8deac". The digit run may be empty,
# so "model-.pth" matches with group(1) == "".
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
31 
32 
def load_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file. A filename
    without such a hash is downloaded without verification. Only freshly
    downloaded files are verified; a file already present in `model_dir` is
    loaded as-is.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Returns:
        The object returned by ``torch.load`` for the cached file.

    Raises:
        RuntimeError: if a freshly downloaded file's SHA256 digest does not
            start with the hash embedded in its filename.

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))

    try:
        os.makedirs(model_dir)
    except OSError as e:
        # ``exist_ok`` is not available on Python 2, so tolerate an existing
        # directory manually and re-raise anything else.
        if e.errno != errno.EEXIST:
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Filenames without a ``-<hash>.`` marker are downloaded without
        # verification instead of failing on ``None.group``.
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match is not None else None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)
78 
79 
def _download_url_to_file(url, dst, hash_prefix, progress):
    """Download ``url`` to the local path ``dst``, optionally verifying it.

    The payload is streamed into a temporary file created in the same
    directory as ``dst``. That file is moved onto ``dst`` only after it has
    been fully written and, when ``hash_prefix`` is given, verified. An
    interrupted or mismatching download therefore never leaves a partial
    file at ``dst``.

    Args:
        url (string): URL to fetch.
        dst (string): destination file path.
        hash_prefix (string or None): expected leading hex digits of the
            SHA256 digest of the payload, or ``None`` to skip verification.
        progress (bool): whether to display a progress bar on stderr.

    Raises:
        RuntimeError: if the SHA256 digest does not start with ``hash_prefix``.
        HTTPError: (from requests or urllib, whichever is in use) if the
            server responds with an error status.
    """
    file_size = None
    if requests_available:
        response = urlopen(url, stream=True)
    else:
        response = urlopen(url)
    try:
        if requests_available:
            # requests does not raise on 4xx/5xx responses by itself (urllib's
            # urlopen does). Without this check an error page would be saved,
            # and cached, under ``dst`` as if it were the requested file.
            response.raise_for_status()
            # ``headers`` is a case-insensitive mapping, so look the key up
            # instead of testing for an attribute named "Content-Length".
            content_length = response.headers.get("Content-Length")
            if content_length is not None:
                file_size = int(content_length)
            stream = response.raw
        else:
            meta = response.info()
            # Python 2 message objects expose getheaders(); Python 3's
            # expose get_all(). Both return a list of header values.
            if hasattr(meta, 'getheaders'):
                content_length = meta.getheaders("Content-Length")
            else:
                content_length = meta.get_all("Content-Length")
            if content_length is not None and len(content_length) > 0:
                file_size = int(content_length[0])
            stream = response

        # Create the temporary file next to ``dst`` so the final move is a
        # same-filesystem rename rather than a copy that could be interrupted
        # midway and leave a truncated file at ``dst``.
        dst_dir = os.path.dirname(os.path.abspath(dst))
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(total=file_size, disable=not progress) as pbar:
                while True:
                    buffer = stream.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            # Closing twice is harmless. The temp file still exists here only
            # if the download, verification or move failed.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        # Release the HTTP connection whether or not the download succeeded.
        response.close()
122 
123 
if tqdm is None:
    # Minimal stand-in for ``tqdm.tqdm`` used when the tqdm package is not
    # installed. It implements only the subset used in this module: the
    # ``total``/``disable`` constructor arguments, ``update()``, and use as a
    # context manager.
    class tqdm(object):
        """Bare-bones progress indicator that redraws a single stderr line."""

        def __init__(self, total=None, disable=False):
            # total: expected number of units, or None when unknown.
            self.total = total
            # disable: when true, update() and __exit__ write nothing.
            self.disable = disable
            # n: units reported so far via update().
            self.n = 0

        def update(self, n):
            """Add ``n`` to the running count and redraw the progress line."""
            if self.disable:
                return

            self.n += n
            if not self.total:
                # Unknown (None) or zero total: a percentage is meaningless,
                # and dividing by zero would raise, so show the raw count.
                sys.stderr.write("\r{0:.1f} bytes".format(self.n))
            else:
                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
            sys.stderr.flush()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Finish the "\r"-redrawn line. Returning None lets any exception
            # raised inside the ``with`` block propagate.
            if self.disable:
                return

            sys.stderr.write('\n')