13 from contextlib
import closing, contextmanager
14 from ._utils
import _import_dotted_name
15 from ._six
import string_classes
as _string_classes
16 if sys.version_info[0] == 2:
17 import cPickle
as pickle
# Native ('=') byte sizes of C short/int/long, recorded in the save-file's
# sys_info dict (see _save) so a reader can detect layout differences.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

# Sentinel written first by _save and verified by _load ("Invalid magic
# number; corrupt file?") to detect non-torch / corrupt files.
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
# Save-format version; _load rejects files whose version differs.
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','
# NOTE(review): this assignment looks like the interior of a temp-directory
# context-manager helper whose `def` line is outside this view (it is used
# as `mkdtemp()` in _load's legacy path) — confirm against the full file.
path = tempfile.mkdtemp()
44 _package_registry = []
47 def register_package(priority, tagger, deserializer):
48 queue_elem = (priority, tagger, deserializer)
49 _package_registry.append(queue_elem)
50 _package_registry.sort()
54 if type(obj).__module__ ==
'torch':
59 if type(obj).__module__ ==
'torch.cuda':
60 return 'cuda:' + str(obj.get_device())
63 def _cpu_deserialize(obj, location):
def validate_cuda_device(location):
    """Validate a 'cuda[:N]' location tag and return the device index.

    Args:
        location: a string like 'cuda' / 'cuda:1', or a torch.device.

    Returns:
        int: the validated device index (>= 0).

    Raises:
        ValueError: if `location` is neither a string nor a torch.device.
        RuntimeError: if CUDA is unavailable or the index is out of range.
    """
    if isinstance(location, torch.device):
        location = str(location)
    if not isinstance(location, _string_classes):
        raise ValueError("location should be a string or torch.device")
    # location[5:] strips the 'cuda:' prefix; bare 'cuda' means device 0.
    if location[5:] == '':
        device = 0
    else:
        device = max(int(location[5:]), 0)

    if not torch.cuda.is_available():
        raise RuntimeError('Attempting to deserialize object on a CUDA '
                           'device but torch.cuda.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=\'cpu\' '
                           'to map your storages to the CPU.')
    if device >= torch.cuda.device_count():
        raise RuntimeError('Attempting to deserialize object on CUDA device '
                           '{} but torch.cuda.device_count() is {}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.'.format(
                               device, torch.cuda.device_count()))
    return device
def _cuda_deserialize(obj, location):
    """Restore a storage onto the CUDA device named by `location`.

    Returns None for non-'cuda' locations so the registry scan continues.
    """
    if location.startswith('cuda'):
        device = validate_cuda_device(location)
        if getattr(obj, "_torch_load_uninitialized", False):
            # The storage was created lazily by _load's persistent_load;
            # allocate the real CUDA storage of matching type and size.
            storage_type = getattr(torch.cuda, type(obj).__name__)
            # NOTE(review): device context reconstructed from the standard
            # implementation — confirm against the full file.
            with torch.cuda.device(device):
                return storage_type(obj.size())
        else:
            return obj.cuda(device)
# CPU registers with a lower priority (10) than CUDA (20), so the sorted
# registry scan in location_tag() consults the CPU tagger first.
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
def location_tag(storage):
    """Return the location tag for `storage` by polling registered taggers
    in priority order.

    Raises:
        RuntimeError: if no registered tagger recognizes the storage.
    """
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))
def default_restore_location(storage, location):
    """Restore `storage` via the first registered deserializer that accepts
    `location` (returns non-None).

    Raises:
        RuntimeError: if no registered deserializer handles the location.
    """
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " +
                       location + ")")
def normalize_storage_type(storage_type):
    """Map a storage class to the same-named class in the top-level torch
    namespace (e.g. a cuda storage class to its canonical torch.* twin)."""
    class_name = storage_type.__name__
    return getattr(torch, class_name)
def storage_to_tensor_type(storage):
    """Return the tensor class paired with *storage*'s class, found by
    swapping 'Storage' for 'Tensor' in the class name within its module."""
    storage_cls = type(storage)
    owning_module = _import_dotted_name(storage_cls.__module__)
    tensor_name = storage_cls.__name__.replace('Storage', 'Tensor')
    return getattr(owning_module, tensor_name)
137 def _with_file_like(f, mode, body):
139 Executes a body function with a file object for f, opening 140 it in 'mode' if it is a string filename. 143 if isinstance(f, str)
or \
144 (sys.version_info[0] == 2
and isinstance(f, unicode))
or \
145 (sys.version_info[0] == 3
and isinstance(f, pathlib.Path)):
155 def _is_compressed_file(f):
156 compress_modules = [
'gzip']
158 return f.__module__
in compress_modules
159 except AttributeError:
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not
    a compressed file (e.g. gzip).
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        # In-memory buffers (e.g. io.BytesIO) raise here.
        return False
    except AttributeError:
        # Arbitrary file-likes without fileno().
        return False
179 def _check_seekable(f):
181 def raise_err_msg(patterns, e):
184 msg = (str(e) +
". You can only torch.load from a file that is seekable." +
185 " Please pre-load the data into a buffer like io.BytesIO and" +
186 " try to load from it instead.")
193 except (io.UnsupportedOperation, AttributeError)
as e:
194 raise_err_msg([
"seek",
"tell"], e)
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.

    See also: :ref:`recommend-saving-models`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string
            containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. warning::
        If you are using Python 2, torch.save does NOT support StringIO.StringIO
        as a valid file-like object. This is because the write method should
        return the number of bytes written; StringIO.write() does not do this.

        Please use something like io.BytesIO instead.

    Example:
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
def _save(obj, f, pickle_module, pickle_protocol):
    """Core of torch.save: write the legacy serialized format to the
    already-open, binary, writable file-like object `f`.

    On-disk layout: magic number, protocol version, sys_info dict, the
    pickled object graph (storages externalized through persistent_id),
    the sorted list of storage keys, then each storage's raw bytes.
    """
    if sys.version_info[0] == 2:
        import StringIO
        if isinstance(f, StringIO.StringIO):
            msg = ('torch.save received unsupported StringIO.StringIO file object, whose '
                   'write method does not return the number of bytes written. '
                   'Please use something like io.BytesIO for torch.save instead.')
            raise RuntimeError(msg)

    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary
        # protocol; see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:  # saving the source is optional; ignore errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            # Offset is always 0; kept only for backward compatibility with
            # the old format, which supported storage views.
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # NOTE(review): self-comparison is always False — presumably a
            # deliberate way to disable view metadata while keeping the
            # record slot for old readers; confirm upstream.
            is_view = obj._cdata != obj._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage',
                    storage_type,
                    obj_key,
                    location,
                    obj.size(),
                    view_metadata)
        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    # Storage payloads follow the pickle stream in sorted-key order so the
    # reader can pair them with the keys list dumped here.
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f))
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    """Loads an object saved with :func:`torch.save` from a file.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an
    exception is raised. However, storages can be dynamically remapped to an
    alternative set of devices using the `map_location` argument.

    If `map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to map_location. The builtin location tags are `'cpu'`
    for CPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors.
    `map_location` should return either None or a storage. If it returns a
    storage, it is used as the final deserialized object, already moved to
    the right device. Otherwise, :func:`torch.load` falls back to the default
    behavior, as if `map_location` wasn't specified.

    If `map_location` is a string or :class:`torch.device`, it names the
    device where all tensors should be loaded. If it is a dict, it remaps
    location tags appearing in the file (keys) to target locations (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using `register_package`.

    Args:
        f: a file-like object (has to implement read, readline, tell, and
            seek), or a string containing a file name
        map_location: a function, torch.device, string or a dict specifying
            how to remap storage locations
        pickle_module: module used for unpickling metadata and objects (has
            to match the pickle_module used to serialize the file)
        pickle_load_args: optional keyword arguments passed over to
            ``pickle_module.load`` and ``pickle_module.Unpickler``, e.g.
            ``encoding=...``.

    .. note::
        When you call :func:`torch.load` on a file which contains GPU
        tensors, those tensors will be loaded to GPU by default. You can
        call ``torch.load(.., map_location='cpu')`` and then
        :meth:`load_state_dict` to avoid GPU RAM surge when loading a model
        checkpoint.

    .. note::
        In Python 3, when loading files saved by Python 2, you may encounter
        ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``. This
        is caused by the difference in handling of byte strings between
        Python 2 and Python 3. Use the extra ``encoding`` keyword argument,
        e.g. ``encoding='latin1'`` or ``encoding='bytes'``.

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
    """
    new_fd = False
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        new_fd = True
        f = open(f, 'rb')
    try:
        return _load(f, map_location, pickle_module, **pickle_load_args)
    finally:
        # Only close files this function itself opened.
        if new_fd:
            f.close()
def _load(f, map_location, pickle_module, **pickle_load_args):
    """Core of torch.load: read either the legacy tar-based format or the
    current pickle-based format from the open file-like object `f`."""
    deserialized_objects = {}

    # Normalize map_location into a restore_location(storage, location)
    # callable; the fall-through default is the registry-based restorer.
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    def _check_container_source(container_type, source_file, original_source):
        # Warn (and optionally write a reverse patch) when a container
        # class's current source differs from the source saved with it.
        try:
            current_source = inspect.getsource(container_type)
        except Exception:  # source retrieval is best-effort
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = ("source code of class '{}' has changed. {}"
                   .format(torch.typename(container_type), msg))
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Old tar-based format: members 'storages', 'tensors' and 'pickle'.
        deserialized_objects = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved.
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    ndim, = struct.unpack('<i', f.read(4))
                    # NOTE(review): the standard implementation skips 4
                    # padding bytes here — confirm against the full file.
                    f.read(4)
                    size = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
                    stride = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = tensor_type().set_(storage, storage_offset, size, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}

    def maybe_decode_ascii(bytes_str):
        # When using encoding='bytes' in Py3, internal keys that were saved
        # as Py2 strings arrive as bytes; decode them as ascii (the Py3
        # default for these identifiers).
        if isinstance(bytes_str, bytes):
            return bytes_str.decode('ascii')
        return bytes_str

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved.
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            location = maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                # Flag lets _cuda_deserialize allocate directly on device.
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # Only attempt the legacy tar loader from offset zero.
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if zipfile.is_zipfile(f):
                # .zip is produced by torch.jit.save and would fail to
                # unpickle below; fail with a pointed message instead.
                raise RuntimeError(
                    "{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
            # Not a tar file: rewind and fall through to the current format.
            f.seek(0)

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        if offset is not None:
            offset = f.tell()

    return result
def typename(o)
Define basic utilities.