# Python 2 / Python 3 compatible URL helpers.
if sys.version_info[0] == 2:
    from urlparse import urlparse
    from urllib2 import urlopen
else:
    from urllib.request import urlopen
    # urlparse is not used in the visible code; kept so existing importers of
    # this name keep working — verify before removing.
    from urllib.parse import urlparse  # noqa: F401
# Branch used when a github spec does not name one explicitly.
MASTER_BRANCH = 'master'
# Environment variable consulted for the hub cache directory.
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
# Fallback cache directory; the leading `~` is expanded by load().
DEFAULT_TORCH_HUB_DIR = '~/.torch/hub'
# Number of bytes requested per read() when streaming a download to disk.
READ_DATA_CHUNK = 8192
25 def _check_module_exists(name):
26 if sys.version_info >= (3, 4):
28 return importlib.util.find_spec(name)
is not None 29 elif sys.version_info >= (3, 3):
31 import importlib.find_loader
32 return importlib.find_loader(name)
is not None 43 def _remove_if_exists(path):
44 if os.path.exists(path):
45 if os.path.isfile(path):
51 def _git_archive_link(repo, branch):
52 return 'https://github.com/' + repo +
'/archive/' + branch +
'.zip' 55 def _download_url_to_file(url, filename):
56 sys.stderr.write(
'Downloading: \"{}\" to {}'.format(url, filename))
57 response = urlopen(url)
58 with open(filename,
'wb')
as f:
60 data = response.read(READ_DATA_CHUNK)
66 def _load_attr_from_module(module_name, func_name):
67 m = importlib.import_module(module_name)
69 if func_name
not in dir(m):
71 return getattr(m, func_name)
# Hub directory chosen through set_dir(); None means "not set", in which case
# load() falls back to $TORCH_HUB_DIR and then DEFAULT_TORCH_HUB_DIR.
hub_dir = None


def set_dir(d):
    r"""
    Optionally set hub_dir to a local dir to save downloaded models & weights.

    If this is never called, `load` uses the environment variable
    `TORCH_HUB_DIR` when it is set, and otherwise falls back to `~/.torch/hub`,
    creating the directory if needed.

    Args:
        d: path to a local folder to save downloaded models & weights.
    """
    # NOTE(review): only this docstring was recoverable; the function name and
    # body are reconstructed from it — confirm against callers.
    global hub_dir
    hub_dir = d


def load(github, model, force_reload=False, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    The repo is downloaded as a zip archive into the hub directory (see
    `set_dir`) unless a cached copy exists; its directory is then put first on
    `sys.path` and `model` is looked up in the repo's `hubconf.py`.

    Args:
        github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision' or 'pytorch/vision:hub'
        model: Required, a string of entrypoint name defined in repo's hubconf.py
        force_reload: Optional, whether to discard the existing cache and force a fresh download.
            Default is `False`. Because it precedes `*args`, the first extra positional
            argument binds here; pass it explicitly when forwarding positional args to `model`.
        *args: Optional, the corresponding args for callable `model`.
        **kwargs: Optional, the corresponding kwargs for callable `model`.

    Returns:
        the object returned by the `model` entrypoint (expected to be a single model
        with corresponding pretrained weights).

    Raises:
        ValueError: if `model` is not a string, or `github` is not "owner/name[:tag]".
        RuntimeError: if a package listed in hubconf's `dependencies` is missing,
            or `model` is absent from hubconf or not callable.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # Parse and validate the repo spec before touching the filesystem.
    repo_info, _, branch = github.partition(':')
    owner_and_name = repo_info.split('/')
    if len(owner_and_name) != 2 or not all(owner_and_name):
        raise ValueError('Invalid github spec {!r}: expected '
                         '"repo_owner/repo_name[:tag_name]"'.format(github))
    repo_owner, repo_name = owner_and_name
    branch = branch or MASTER_BRANCH

    # Resolve the cache root: set_dir() value, else $TORCH_HUB_DIR, else default.
    cache_root = hub_dir
    if cache_root is None:
        cache_root = os.getenv(ENV_TORCH_HUB_DIR, DEFAULT_TORCH_HUB_DIR)
    cache_root = os.path.expanduser(cache_root)
    if not os.path.exists(cache_root):
        os.makedirs(cache_root)

    # Key the cache on owner, name and branch so forks sharing a repo name
    # never reuse each other's code, and a '/' in the branch cannot create
    # nested paths. (The older "{repo_name}_{branch}" layout is not reused.)
    cache_key = '_'.join([repo_owner, repo_name, branch.replace('/', '_')])
    url = _git_archive_link(repo_info, branch)
    cached_file = os.path.join(cache_root, cache_key + '.zip')
    repo_dir = os.path.join(cache_root, cache_key)

    if not force_reload and os.path.exists(repo_dir):
        sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
    else:
        _remove_if_exists(cached_file)
        _download_url_to_file(url, cached_file)

        # Take the extracted folder name from the archive's first entry rather
        # than predicting it (assumes, as with GitHub archives, that the first
        # entry is the single top-level folder).
        with zipfile.ZipFile(cached_file) as cached_zipfile:
            extracted_name = cached_zipfile.infolist()[0].filename
            extracted_repo = os.path.join(cache_root, extracted_name)
            _remove_if_exists(extracted_repo)
            cached_zipfile.extractall(cache_root)

        # The archive is closed by now, so it can be deleted even on platforms
        # that refuse to remove open files.
        _remove_if_exists(cached_file)
        _remove_if_exists(repo_dir)
        shutil.move(extracted_repo, repo_dir)  # rename to the stable cache name

    # Put this repo first on sys.path (moving it if already listed) and drop
    # any cached `hubconf`, so a previously loaded repo's hubconf is not reused.
    if repo_dir in sys.path:
        sys.path.remove(repo_dir)
    sys.path.insert(0, repo_dir)
    sys.modules.pop('hubconf', None)

    dependencies = _load_attr_from_module('hubconf', 'dependencies')
    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if missing_deps:
            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))

    func = _load_attr_from_module('hubconf', model)
    if func is None:
        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))
    if not callable(func):
        raise RuntimeError('{} is not callable'.format(func))

    return func(*args, **kwargs)