1 from __future__
import absolute_import, division, print_function, unicode_literals
13 from requests.utils
import urlparse
14 from requests
import get
as urlopen
15 requests_available =
True 17 requests_available =
False 18 if sys.version_info[0] == 2:
19 from urlparse
import urlparse
20 from urllib2
import urlopen
22 from urllib.request
import urlopen
23 from urllib.parse
import urlparse
# Pulls the embedded hash out of a ``name-<hex>.ext`` filename: group(1) is the
# run of lowercase hex digits between a '-' and the following '.' (first such
# occurrence found by ``search``). Note that ``*`` also accepts an empty run,
# e.g. "model-.pth" yields "", which effectively disables verification.
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
def load_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file. If the
    filename does not contain a hash in that form, the file is downloaded
    without verification.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap
            storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar
            to stderr

    Returns:
        Whatever ``torch.load`` returns for the cached file.

    Raises:
        OSError: if `model_dir` cannot be created for any reason other than
            it already existing.
        RuntimeError: if the downloaded file's SHA256 digest does not start
            with the hash embedded in the filename.

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Only fall back to the environment-derived location when the caller did
    # not choose a directory explicitly.
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))

    # os.makedirs(exist_ok=True) is Python 3 only; tolerate an existing
    # directory (possibly created concurrently) and re-raise anything else.
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # A filename without the ``-<hex>.`` pattern used to crash here with
        # AttributeError on ``None.group``; treat it as "nothing to verify".
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match else None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)
def _download_url_to_file(url, dst, hash_prefix, progress):
    """Download ``url`` to ``dst``, optionally checking its SHA256 prefix.

    The payload is streamed in 8192-byte chunks into a temporary file created
    in the same directory as ``dst``. The temporary file is only moved onto
    ``dst`` after the download and the hash check succeed, so a failed or
    corrupt download never leaves a file at ``dst``. The temporary file is
    removed on any failure.

    Args:
        url (string): URL to fetch.
        dst (string): destination path for the downloaded file.
        hash_prefix (string or None): expected leading hex digits of the
            file's SHA256 digest; ``None`` skips verification.
        progress (bool): whether to display a progress bar on stderr.

    Raises:
        RuntimeError: if the digest does not start with ``hash_prefix``.
    """
    file_size = None
    if requests_available:
        response = urlopen(url, stream=True)
        # urllib's urlopen raises on HTTP error statuses; make the requests
        # path do the same instead of saving an error page as the file.
        response.raise_for_status()
        # ``hasattr(headers, "Content-Length")`` looks for an *attribute* and
        # is always False; the header must be tested with ``in``.
        if "Content-Length" in response.headers:
            file_size = int(response.headers["Content-Length"])
        u = response.raw
    else:
        u = urlopen(url)
        meta = u.info()
        # Python 2 message objects expose getheaders(), Python 3 get_all().
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

    # Creating the temp file next to dst keeps the final move a same-filesystem
    # rename rather than a cross-device copy.
    dst_dir = os.path.dirname(os.path.abspath(dst))
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
    try:
        sha256 = hashlib.sha256() if hash_prefix is not None else None
        with tqdm(total=file_size, disable=not progress) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if sha256 is not None:
                    sha256.update(buffer)
                pbar.update(len(buffer))

        # Flush and close before moving so dst receives the complete payload.
        f.close()
        if sha256 is not None:
            digest = sha256.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                   .format(hash_prefix, digest))
        shutil.move(f.name, dst)
    finally:
        f.close()
        # Still present only if the move did not happen (error or bad hash).
        if os.path.exists(f.name):
            os.remove(f.name)
        u.close()
128 def __init__(self, total=None, disable=False):
138 if self.
total is None:
139 sys.stderr.write(
"\r{0:.1f} bytes".format(self.
n))
141 sys.stderr.write(
"\r{0:.1f}%".format(100 * self.
n / float(self.
total)))
147 def __exit__(self, exc_type, exc_val, exc_tb):
151 sys.stderr.write(
'\n')