3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
14 from six.moves.urllib.request
import urlretrieve
# Download the Caffe2 files 'predict_net.pb', 'init_net.pb' and
# 'value_info.json' for `model` into model_dir. model_dir must not exist yet
# (asserted) and is created here. Each file's URL comes from
# getURLFromName(model, f) and is fetched with downloadFromURLToFile.
# When an Exception is caught, the reason is printed and model_dir is removed
# with deleteDirectory.
# NOTE(review): this listing omits source lines (the assignment of model_dir,
# the `try:` headers paired with the `except`, the rest of the first
# downloadFromURLToFile call) -- confirm the exact control flow against the
# complete file.
21 def _download(self, model):
23 assert not os.path.exists(model_dir)
24 os.makedirs(model_dir)
25 for f
in [
'predict_net.pb',
'init_net.pb',
'value_info.json']:
26 url = getURLFromName(model, f)
27 dest = os.path.join(model_dir, f)
30 downloadFromURLToFile(url, dest,
36 downloadFromURLToFile(url, dest)
37 except Exception
as e:
38 print(
"Abort: {reason}".format(reason=e))
39 print(
"Cleaning up...")
40 deleteDirectory(model_dir)
43 def _caffe2_model_dir(self, model):
44 caffe2_home = os.path.expanduser(
'~/.caffe2')
45 models_dir = os.path.join(caffe2_home,
'models')
46 return os.path.join(models_dir, model)
48 def _onnx_model_dir(self, model):
49 onnx_home = os.path.expanduser(
'~/.onnx')
50 models_dir = os.path.join(onnx_home,
'models')
51 model_dir = os.path.join(models_dir, model)
52 return model_dir, os.path.dirname(model_dir)
# Fetch the ONNX archive for `model`: build the S3 URL
# '.../download.onnx/models/<model>.tar.gz', download it with urlretrieve into
# a NamedTemporaryFile created with delete=False, and extract the tarball into
# models_dir. A caught Exception is reported with a 'Failed to prepare data'
# message, and the temporary file is removed with os.remove.
# NOTE(review): this listing omits source lines (where model_dir/models_dir
# are assigned, what happens when model_dir already exists, and the
# `try:`/cleanup headers) -- confirm against the complete file.
56 def _prepare_model_data(self, model):
58 if os.path.exists(model_dir):
60 os.makedirs(model_dir)
61 url =
'https://s3.amazonaws.com/download.onnx/models/{}.tar.gz'.format(model)
65 download_file = tempfile.NamedTemporaryFile(delete=
False)
68 print(
'Start downloading model {} from {}'.format(model, url))
69 urlretrieve(url, download_file.name)
71 with tarfile.open(download_file.name)
as t:
72 t.extractall(models_dir)
73 except Exception
as e:
74 print(
'Failed to prepare data for model {}: {}'.format(model, e))
77 os.remove(download_file.name)
# Download step: prints a progress line for `model`, resolves its Caffe2 and
# ONNX directories through sc, and calls sc._prepare_model_data(model) when
# the ONNX model directory does not exist yet.
# NOTE(review): the lines that define `sc` and `model` (presumably a loop over
# model names) and the statement guarded by the caffe2_model_dir check are not
# part of this listing -- confirm against the complete file.
96 def download_models():
99 print(
'update-caffe2-models.py: downloading', model)
100 caffe2_model_dir = sc._caffe2_model_dir(model)
101 onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
102 if not os.path.exists(caffe2_model_dir):
104 if not os.path.exists(onnx_model_dir):
105 sc._prepare_model_data(model)
# Generate step: prints a progress line for `model`, resolves its Caffe2 and
# ONNX directories, runs `echo <model>`, reads the text of value_info.json
# from the Caffe2 directory, and invokes the 'convert-caffe2-to-onnx' command
# with the model name, init_net.pb, the value-info text, the output path
# <onnx_model_dir>/model.pb and predict_net.pb. A further check_call is run
# with cwd=onnx_models_dir.
# NOTE(review): the lines that define `sc` and `model`, the closing of the
# first argument list, and the arguments of the final check_call are not part
# of this listing -- confirm against the complete file.
107 def generate_models():
110 print(
'update-caffe2-models.py: generating', model)
111 caffe2_model_dir = sc._caffe2_model_dir(model)
112 onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
113 subprocess.check_call([
'echo', model])
114 with open(os.path.join(caffe2_model_dir,
'value_info.json'),
'r') as f: 115 value_info = f.read() 116 subprocess.check_call([ 117 'convert-caffe2-to-onnx',
118 '--caffe2-net-name', model,
119 '--caffe2-init-net', os.path.join(caffe2_model_dir,
'init_net.pb'),
120 '--value-info', value_info,
121 '-o', os.path.join(onnx_model_dir,
'model.pb'),
122 os.path.join(caffe2_model_dir,
'predict_net.pb')
124 subprocess.check_call([
129 ], cwd=onnx_models_dir)
134 print(
'update-caffe2-models.py: uploading', model)
135 onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
136 subprocess.check_call([
141 "s3://download.onnx/models/{}.tar.gz".format(model),
142 '--acl',
'public-read' 143 ], cwd=onnx_models_dir)
148 onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
149 os.remove(os.path.join(os.path.dirname(onnx_model_dir), model +
'.tar.gz'))
# Script entry point: runs `aws sts get-caller-identity`, has a message telling
# the user to run `aws configure` manually (presumably printed when that check
# fails -- TODO confirm), then selects an action from sys.argv[1]:
# 'download', 'generate', 'upload' or 'cleanup'.
# NOTE(review): the body of every branch and any try/except around the AWS
# check are not part of this listing. The 'generate' test is a plain `if`
# following the 'download' `if`, while 'upload' and 'cleanup' use `elif`.
151 if __name__ ==
'__main__':
153 subprocess.check_call([
'aws',
'sts',
'get-caller-identity'])
155 print(
'update-caffe2-models.py: please run `aws configure` manually to set up credentials')
157 if sys.argv[1] ==
'download':
159 if sys.argv[1] ==
'generate':
161 elif sys.argv[1] ==
'upload':
163 elif sys.argv[1] ==
'cleanup':
def _caffe2_model_dir(self, model)
def _onnx_model_dir(self, model)