Caffe2 - Python API
A deep learning, cross-platform ML framework
hub.py
1 import importlib
2 import os
3 import shutil
4 import sys
5 import tempfile
6 import zipfile
7 
8 if sys.version_info[0] == 2:
9  from urlparse import urlparse
10  from urllib2 import urlopen # noqa f811
11 else:
12  from urllib.request import urlopen
13  from urllib.parse import urlparse
14 
15 import torch
16 import torch.utils.model_zoo as model_zoo
17 
# Branch/tag used by load() when the `github` argument has no ':tag_name' suffix.
MASTER_BRANCH = 'master'
# Environment variable consulted for the hub directory when set_dir() was not called.
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
# Fallback hub directory; a leading '~' is expanded by load() before use.
DEFAULT_TORCH_HUB_DIR = '~/.torch/hub'
# Number of bytes read from the HTTP response per chunk while downloading.
READ_DATA_CHUNK = 8 * 1024
# Directory holding downloaded repos; None until set_dir() or load() assigns it.
hub_dir = None
23 
24 
def _check_module_exists(name):
    """Return True if the module ``name`` can be imported, False otherwise.

    The module itself is only located, not executed; on Python 3 the parent
    packages of a dotted name are imported as part of the lookup.

    Args:
        name: module name to look up, e.g. ``'numpy'`` or ``'scipy.ndimage'``.
            On Python 2 only top-level (dot-free) names are resolved correctly.

    Returns:
        bool: whether the module was found.
    """
    if sys.version_info >= (3, 4):
        # NB: this binds `importlib` as a *local* name for the whole function,
        # which is why the 3.3 branch below imports it again explicitly.
        import importlib.util
        try:
            return importlib.util.find_spec(name) is not None
        except ImportError:
            # find_spec imports the parent of a dotted name; a missing parent
            # raises ModuleNotFoundError instead of returning None.
            return False
        except ValueError:
            # Raised when the module is already in sys.modules but has no
            # __spec__ -- it is importable by definition.
            return True
    elif sys.version_info >= (3, 3):
        # Special case for python3.3: find_loader is a function on the
        # importlib package, not a submodule, so it cannot be imported as
        # `importlib.find_loader`.
        import importlib
        try:
            return importlib.find_loader(name) is not None
        except ImportError:
            return False
    else:
        # NB: imp doesn't handle hierarchical module names (names contains dots).
        try:
            import imp
            imp.find_module(name)
        except Exception:
            return False
        return True
41 
42 
def _remove_if_exists(path):
    """Delete whatever exists at ``path``; do nothing if nothing is there.

    Regular files and symbolic links -- including dangling links and links
    that point at directories -- are unlinked; the link target is never
    touched. Real directories are removed recursively.

    Args:
        path: filesystem path to delete.
    """
    # lexists (unlike exists) is True for dangling symlinks, so they get cleaned up too.
    if not os.path.lexists(path):
        return
    if os.path.islink(path) or os.path.isfile(path):
        os.remove(path)
    else:
        # shutil.rmtree refuses symlinks, which were handled above.
        shutil.rmtree(path)
49 
50 
def _git_archive_link(repo, branch):
    """Return the GitHub URL of the zip archive of ``branch`` in ``repo``.

    Args:
        repo: ``'owner/name'`` string.
        branch: branch or tag name.

    Example: ``('pytorch/vision', 'master')`` gives
    ``'https://github.com/pytorch/vision/archive/master.zip'``.
    """
    pieces = ('https://github.com/', repo, '/archive/', branch, '.zip')
    return ''.join(pieces)
53 
54 
def _download_url_to_file(url, filename):
    """Stream the resource at ``url`` into the local file ``filename``.

    A progress line is written to stderr first. The response is always
    closed, and if the transfer fails part-way the incomplete ``filename`` is
    deleted before the exception propagates, so no truncated file is left
    behind.

    Args:
        url: address passed to ``urlopen``.
        filename: destination path; created or overwritten.
    """
    sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
    response = urlopen(url)
    try:
        with open(filename, 'wb') as f:
            # Copy in READ_DATA_CHUNK-sized pieces instead of reading the whole body at once.
            shutil.copyfileobj(response, f, READ_DATA_CHUNK)
    except BaseException:
        if os.path.exists(filename):
            os.remove(filename)
        raise
    finally:
        response.close()
64 
65 
def _load_attr_from_module(module_name, func_name):
    """Import ``module_name`` and return its attribute ``func_name``.

    Returns:
        The attribute, or None when ``func_name`` is not among the names
        listed by ``dir()`` of the module.
    """
    module = importlib.import_module(module_name)
    # Only names the module lists are looked up; anything else maps to None.
    return getattr(module, func_name) if func_name in dir(module) else None
72 
73 
def set_dir(d):
    r"""
    Set the local directory where :func:`load` saves downloaded repos, models & weights.

    Calling this is optional. When it has not been called, ``load`` uses the
    value of the ``TORCH_HUB_DIR`` environment variable if set, and otherwise
    falls back to ``~/.torch/hub``; the chosen directory is created if missing.

    Args:
        d: path to a local folder to save downloaded models & weights.
    """
    global hub_dir
    hub_dir = d
86 
87 
def _setup_hub_dir():
    """Resolve the module-level hub_dir (set_dir > $TORCH_HUB_DIR > default), expand '~' and create it."""
    global hub_dir
    if hub_dir is None:
        hub_dir = os.getenv(ENV_TORCH_HUB_DIR, DEFAULT_TORCH_HUB_DIR)
    # expanduser is a no-op unless the path starts with '~'.
    hub_dir = os.path.expanduser(hub_dir)
    if not os.path.exists(hub_dir):
        os.makedirs(hub_dir)


def _parse_repo_info(github):
    """Split "repo_owner/repo_name[:tag_name]" into (repo_owner, repo_name, branch); raise ValueError if malformed."""
    repo_info, sep, branch = github.partition(':')
    if not sep:
        branch = MASTER_BRANCH
    owner_and_name = repo_info.split('/')
    if len(owner_and_name) != 2 or not all(owner_and_name) or not branch or ':' in branch:
        raise ValueError(
            'Invalid github argument {!r}: expected "repo_owner/repo_name[:tag_name]"'.format(github))
    repo_owner, repo_name = owner_and_name
    return repo_owner, repo_name, branch


def _get_cache_or_reload(repo_owner, repo_name, branch, force_reload):
    """Return the local directory holding the repo code, downloading and unzipping it unless cached."""
    # '/' in branch names (e.g. 'feature/x') would otherwise point into missing subdirectories.
    normalized_branch = branch.replace('/', '_')
    # Include the owner so same-named repos from different owners get separate caches.
    repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_branch]))

    if not force_reload and os.path.exists(repo_dir):
        sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    cached_file = os.path.join(hub_dir, normalized_branch + '.zip')
    _remove_if_exists(cached_file)
    _download_url_to_file(_git_archive_link(repo_owner + '/' + repo_name, branch), cached_file)

    # Github uses '{repo_name}-{branch_name}' as folder name which is not importable,
    # and renames folder repo-v1.x.x to repo-1.x.x, so the folder name is read from
    # the archive and the folder is then renamed to repo_dir.
    cached_zipfile = zipfile.ZipFile(cached_file)
    try:
        extracted_repo_name = cached_zipfile.infolist()[0].filename
        extracted_repo = os.path.join(hub_dir, extracted_repo_name)
        _remove_if_exists(extracted_repo)
        cached_zipfile.extractall(hub_dir)
    finally:
        # Close before deleting the archive (an open handle blocks removal on Windows).
        cached_zipfile.close()

    _remove_if_exists(cached_file)
    _remove_if_exists(repo_dir)
    shutil.move(extracted_repo, repo_dir)  # rename the repo
    return repo_dir


def load(github, model, force_reload=False, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    The repo is downloaded as a zip archive (or taken from the cache under the
    hub directory), its ``hubconf.py`` is imported, every name in its optional
    ``dependencies`` list is checked for importability, and the entrypoint
    ``model`` is called with ``*args`` / ``**kwargs``.

    Args:
        github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model: Required, a string of entrypoint name defined in repo's hubconf.py
        force_reload: Optional, whether to discard the existing cache and force a fresh download.
            Default is `False`.
        *args: Optional, the corresponding args for callable `model`.
        **kwargs: Optional, the corresponding kwargs for callable `model`.

    Returns:
        the return value of the entrypoint -- a single model with corresponding pretrained weights.

    Raises:
        ValueError: if `model` is not a string or `github` is malformed.
        RuntimeError: if dependencies listed in hubconf.py are missing, or `model`
            is absent from hubconf or not callable.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    _setup_hub_dir()
    repo_owner, repo_name, branch = _parse_repo_info(github)
    repo_dir = _get_cache_or_reload(repo_owner, repo_name, branch, force_reload)

    # Make Python interpreter aware of the repo: move repo_dir to the front of
    # sys.path without piling up duplicates across calls.
    while repo_dir in sys.path:
        sys.path.remove(repo_dir)
    sys.path.insert(0, repo_dir)
    # import_module returns whatever is cached in sys.modules, so a hubconf
    # loaded by an earlier call (possibly for another repo) would be reused.
    # Drop it and reset the finder caches so this repo's hubconf.py is imported.
    sys.modules.pop('hubconf', None)
    if hasattr(importlib, 'invalidate_caches'):  # Python 3.3+
        importlib.invalidate_caches()

    dependencies = _load_attr_from_module('hubconf', 'dependencies')

    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if missing_deps:
            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))

    func = _load_attr_from_module('hubconf', model)
    if func is None:
        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))

    # Check if func is callable
    if not callable(func):
        raise RuntimeError('{} is not callable'.format(func))

    # Call the function
    return func(*args, **kwargs)