Caffe2 - Python API
A deep learning, cross platform ML framework
serialization.py
1 import difflib
2 import inspect
3 import os
4 import io
5 import shutil
6 import struct
7 import sys
8 import torch
9 import tarfile
10 import zipfile
11 import tempfile
12 import warnings
13 from contextlib import closing, contextmanager
14 from ._utils import _import_dotted_name
15 from ._six import string_classes as _string_classes
16 if sys.version_info[0] == 2:
17  import cPickle as pickle
18 else:
19  import pickle
20  import pathlib
21 
# Default pickle protocol; 2 is the highest protocol readable from Python 2.
DEFAULT_PROTOCOL = 2

# Native ('=') byte sizes of C long/int/short, recorded in the file's
# sys_info header (see _save) so a reader can inspect the writer's platform.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

# Sentinel pickled first in every file; load() rejects streams without it.
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
# Version of this (pickle-based, non-tar) serialization format.
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','
31 
32 
class SourceChangeWarning(Warning):
    """Warning issued when a loaded container's current source code differs
    from the source that was captured at save time."""
35 
36 
@contextmanager
def mkdtemp():
    """Context manager yielding a fresh temporary directory path.

    The directory is removed on exit. The removal is placed in a
    ``finally`` clause so the directory is also cleaned up when the
    managed block raises (the original version leaked it on error).
    """
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)
42 
43 
44 _package_registry = []
45 
46 
47 def register_package(priority, tagger, deserializer):
48  queue_elem = (priority, tagger, deserializer)
49  _package_registry.append(queue_elem)
50  _package_registry.sort()
51 
52 
53 def _cpu_tag(obj):
54  if type(obj).__module__ == 'torch':
55  return 'cpu'
56 
57 
58 def _cuda_tag(obj):
59  if type(obj).__module__ == 'torch.cuda':
60  return 'cuda:' + str(obj.get_device())
61 
62 
63 def _cpu_deserialize(obj, location):
64  if location == 'cpu':
65  return obj
66 
67 
def validate_cuda_device(location):
    """Parse and sanity-check a ``'cuda[:N]'`` location tag.

    Args:
        location: a string or ``torch.device`` naming a CUDA device.

    Returns:
        The device index as an int (``0`` when no index is given).

    Raises:
        ValueError: if ``location`` is neither a string nor a torch.device.
        RuntimeError: if CUDA is unavailable or the index is out of range.
    """
    if isinstance(location, torch.device):
        location = str(location)
    if not isinstance(location, _string_classes):
        raise ValueError("location should be a string or torch.device")

    # Everything after the 'cuda:' prefix is the device index; bare 'cuda'
    # means device 0, and negative indices are clamped to 0.
    index_str = location[5:]
    device = 0 if index_str == '' else max(int(index_str), 0)

    if not torch.cuda.is_available():
        raise RuntimeError('Attempting to deserialize object on a CUDA '
                           'device but torch.cuda.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=\'cpu\' '
                           'to map your storages to the CPU.')
    if device >= torch.cuda.device_count():
        raise RuntimeError('Attempting to deserialize object on CUDA device '
                           '{} but torch.cuda.device_count() is {}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.'.format(
                               device, torch.cuda.device_count()))
    return device
91 
92 
93 def _cuda_deserialize(obj, location):
94  if location.startswith('cuda'):
95  device = validate_cuda_device(location)
96  if getattr(obj, "_torch_load_uninitialized", False):
97  storage_type = getattr(torch.cuda, type(obj).__name__)
98  with torch.cuda.device(device):
99  return storage_type(obj.size())
100  else:
101  return obj.cuda(device)
102 
103 
# Built-in location handlers: CPU at priority 10 (checked first), CUDA at 20.
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
106 
107 
def location_tag(storage):
    """Determine the location tag (``'cpu'``, ``'cuda:N'``, ...) for a
    storage by consulting the registered taggers in priority order; the
    first truthy answer wins.

    Raises:
        RuntimeError: if no registered tagger recognizes the storage.
    """
    for entry in _package_registry:
        tag = entry[1](storage)
        if tag:
            return tag
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))
115 
116 
def default_restore_location(storage, location):
    """Restore ``storage`` to the device described by ``location`` using the
    registered deserializers; the first non-None result is returned.

    Raises:
        RuntimeError: if no registered deserializer handles the tag.
    """
    for entry in _package_registry:
        restored = entry[2](storage, location)
        if restored is not None:
            return restored
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " +
                       location + ")")
125 
126 
def normalize_storage_type(storage_type):
    """Map a storage class to the same-named class on the top-level
    ``torch`` module (e.g. a device-specific storage to its canonical one)."""
    class_name = storage_type.__name__
    return getattr(torch, class_name)
129 
130 
def storage_to_tensor_type(storage):
    """Return the Tensor class matching a storage instance: same module,
    class name with 'Storage' replaced by 'Tensor'."""
    storage_cls = type(storage)
    owning_module = _import_dotted_name(storage_cls.__module__)
    tensor_name = storage_cls.__name__.replace('Storage', 'Tensor')
    return getattr(owning_module, tensor_name)
135 
136 
137 def _with_file_like(f, mode, body):
138  """
139  Executes a body function with a file object for f, opening
140  it in 'mode' if it is a string filename.
141  """
142  new_fd = False
143  if isinstance(f, str) or \
144  (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
145  (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
146  new_fd = True
147  f = open(f, mode)
148  try:
149  return body(f)
150  finally:
151  if new_fd:
152  f.close()
153 
154 
155 def _is_compressed_file(f):
156  compress_modules = ['gzip']
157  try:
158  return f.__module__ in compress_modules
159  except AttributeError:
160  return False
161 
162 
163 def _should_read_directly(f):
164  """
165  Checks if f is a file that should be read directly. It should be read
166  directly if it is backed by a real file (has a fileno) and is not a
167  a compressed file (e.g. gzip)
168  """
169  if _is_compressed_file(f):
170  return False
171  try:
172  return f.fileno() >= 0
173  except io.UnsupportedOperation:
174  return False
175  except AttributeError:
176  return False
177 
178 
179 def _check_seekable(f):
180 
181  def raise_err_msg(patterns, e):
182  for p in patterns:
183  if p in str(e):
184  msg = (str(e) + ". You can only torch.load from a file that is seekable." +
185  " Please pre-load the data into a buffer like io.BytesIO and" +
186  " try to load from it instead.")
187  raise type(e)(msg)
188  raise e
189 
190  try:
191  f.seek(f.tell())
192  return True
193  except (io.UnsupportedOperation, AttributeError) as e:
194  raise_err_msg(["seek", "tell"], e)
195 
196 
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Save an object to a disk file.

    See also: :ref:`recommend-saving-models`

    Args:
        obj: the object to serialize
        f: a file-like object (must implement write and flush) or a string
            naming the destination file
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: overrides the default pickle protocol

    .. warning::
        On Python 2, ``StringIO.StringIO`` is NOT a supported file-like
        object, because its write method does not return the number of
        bytes written. Use something like ``io.BytesIO`` instead.

    Example:
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    def write_obj(file_handle):
        return _save(obj, file_handle, pickle_module, pickle_protocol)

    return _with_file_like(f, "wb", write_obj)
225 
226 
def _save(obj, f, pickle_module, pickle_protocol):
    """Implementation of torch.save for an already-open binary file object:
    writes the pickled header, the pickled object graph, then the raw
    storage payloads."""
    # Reject Py2 StringIO up front: its write() returns None instead of the
    # byte count, which would silently corrupt the output stream.
    if sys.version_info[0] == 2:
        import StringIO
        if isinstance(f, StringIO.StringIO):
            msg = ('torch.save received unsupported StringIO.StringIO file object, whose '
                   'write method does not return the number of bytes written. '
                   'Please use something like io.BytesIO for torch.save instead.')
            raise RuntimeError(msg)

    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            # nn.Module subclasses: record their source once so load() can
            # warn if the class definition has changed since saving.
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # NOTE(review): this self-comparison is always False, so the view
            # branch below is dead code -- presumably kept only so the emitted
            # tuple layout matches the old view-supporting format; confirm
            # before simplifying.
            is_view = obj._cdata != obj._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage',
                    storage_type,
                    obj_key,
                    location,
                    obj.size(),
                    view_metadata)

        return None

    # Platform fingerprint stored in the header; readers may use it to
    # detect endianness / integer-size mismatches.
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    # Header is three separate pickle dumps (magic, version, sys_info) so
    # load() can validate the stream before unpickling the payload.
    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    # Raw storage bytes follow the pickle stream, in sorted key order;
    # load() reads them back in exactly this order.
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f))
304 
305 
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    """Loads an object saved with :func:`torch.save` from a file.

    :meth:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the `map_location` argument.

    If `map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to map_location. The builtin location tags are `'cpu'` for
    CPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors.
    `map_location` should return either None or a storage. If `map_location` returns
    a storage, it will be used as the final deserialized object, already moved to
    the right device. Otherwise, :math:`torch.load` will fall back to the default
    behavior, as if `map_location` wasn't specified.

    If `map_location` is a string, it should be a device tag, where all tensors
    should be loaded.

    Otherwise, if `map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using `register_package`.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location: a function, torch.device, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the pickle_module used to serialize file)
        pickle_load_args: optional keyword arguments passed over to
            ``pickle_module.load`` and ``pickle_module.Unpickler``, e.g.,
            ``encoding=...``.

    .. note::
        When you call :meth:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call `torch.load(.., map_location='cpu')`
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        In Python 3, when loading files saved by Python 2, you may encounter
        ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``. This is
        caused by the difference of handling in byte strings in Python2 and
        Python 3. You may use extra ``encoding`` keyword argument to specify how
        these objects should be loaded, e.g., ``encoding='latin1'`` decodes them
        to strings using ``latin1`` encoding, and ``encoding='bytes'`` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt', 'rb') as f:
                buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
    """
    # The str/unicode/Path detection and open/close-on-filename logic used to
    # be duplicated inline here; delegate to the shared _with_file_like helper
    # (as save() already does) so that behavior lives in one place.
    return _with_file_like(
        f, 'rb',
        lambda opened: _load(opened, map_location, pickle_module, **pickle_load_args))
389 
390 
def _load(f, map_location, pickle_module, **pickle_load_args):
    """Implementation of torch.load for an already-open, seekable binary
    file object. Handles both the legacy tar-based format and the current
    pickle-based format."""
    deserialized_objects = {}

    # Normalize map_location into a restore_location(storage, location)
    # callable; the four accepted forms (None / dict / string or device /
    # callable) are described in torch.load's docstring.
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
    else:
        # Callable: let it try first; fall back to the default when it
        # declines by returning None.
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    def _check_container_source(container_type, source_file, original_source):
        # Compare the class's current source with the source captured at save
        # time; emit a SourceChangeWarning (and optionally a .patch file) on
        # mismatch. Never raises.
        try:
            current_source = inspect.getsource(container_type)
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                # Write a reverse patch (current -> saved) the user can apply
                # with `patch -p0` to restore the saved version.
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            # A different patch file already exists; don't
                            # clobber it.
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = ("source code of class '{}' has changed. {}"
                   .format(torch.typename(container_type), msg))
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Loader for the old tar-based format: an uncompressed tar archive
        # containing 'storages', 'tensors', and 'pickle' members.
        deserialized_objects = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            # Storages are extracted to disk first so they can be read back
            # with a real file descriptor. NOTE: 'f' is deliberately rebound
            # to the extracted member file inside each 'with' below.
            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    # Tensor metadata is raw little-endian binary, not pickle.
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    size = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
                    stride = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = tensor_type().set_(storage, storage_offset, size, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}

    def maybe_decode_ascii(bytes_str):
        # When using encoding='bytes' in Py3, some **internal** keys stored as
        # strings in Py2 are loaded as bytes. This function decodes them with
        # ascii encoding, one that Py3 uses by default.
        #
        # NOTE: This should only be used on internal keys (e.g., `typename` and
        # `location` in `persistent_load` below!
        if isinstance(bytes_str, bytes):
            return bytes_str.decode('ascii')
        return bytes_str

    def persistent_load(saved_id):
        # Counterpart of _save's persistent_id: resolves ('module', ...) and
        # ('storage', ...) tuples emitted at save time.
        assert isinstance(saved_id, tuple)
        typename = maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            location = maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                # Placeholder storage; the real bytes are filled in later by
                # _set_from_file (see the loop at the end of _load).
                obj = data_type(size)
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if zipfile.is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    # Current format: validate the three-part header written by _save.
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    # sys_info header is read to advance the stream but otherwise unused here.
    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Fill the placeholder storages from the raw bytes that follow the
    # pickle stream, in the same sorted-key order _save wrote them.
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        offset = None

    return result
Cross-references (extraction residue; definitions live elsewhere in the package):
- torch.cuda.is_available() — torch/cuda/__init__.py:45
- torch.cuda.device_count() — torch/cuda/__init__.py:341
- torch.typename(o) — torch/__init__.py:94 (basic utilities)
- torch.is_storage(obj) — torch/__init__.py:123