7 from multiprocessing.reduction
import ForkingPickler
15 import multiprocessing.resource_sharer
21 r"""A weak reference to a Storage. 23 The cdata member is a Python number containing the integer representation of 24 the Storage pointer.""" 26 def __init__(self, storage):
27 self.
cdata = storage._weak_ref()
33 return torch.Storage._expired(self.
cdata)
40 """dictionary from multiprocessing handles to StorageWeakRef""" 46 self.
lock = threading.Lock()
48 def __setitem__(self, key, storage_ref):
49 dict.__setitem__(self, key, storage_ref)
50 if len(self) > self.
limit:
53 def free_dead_references(self):
58 for key, storage_ref
in list(self.items()):
59 if storage_ref.expired():
63 self.
limit = max(128, live * 2)
def rebuild_event(device, handle):
    """Recreate a CUDA event from the pieces emitted by ``reduce_event``.

    Args:
        device: the ``device`` attribute of the original event.
        handle: the value returned by the original event's ``ipc_handle()``.

    Returns:
        The event produced by ``torch.cuda.Event.from_ipc_handle``.
    """
    event_type = torch.cuda.Event
    return event_type.from_ipc_handle(device, handle)


def reduce_event(event):
    """Pickle reducer for CUDA events.

    Returns a ``(callable, args)`` pair in the ``__reduce__`` style:
    unpickling calls ``rebuild_event(event.device, event.ipc_handle())``.
    """
    # Query the IPC handle before reading ``device``, matching the
    # original order of accesses on the event object.
    ipc = event.ipc_handle()
    rebuild_args = (event.device, ipc)
    return rebuild_event, rebuild_args
79 def rebuild_tensor(cls, storage, metadata):
80 storage_offset, size, stride, requires_grad = metadata
84 t.requires_grad = requires_grad
88 def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
89 storage_cls, storage_device, storage_handle, storage_size_bytes, storage_offset_bytes,
92 if storage_handle
is None or storage_size_bytes == 0:
93 storage = storage_cls(0)
95 storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes))
98 storage = storage_cls._new_shared_cuda(
102 storage_offset_bytes)
103 shared_cache[(storage_handle, storage_offset_bytes)] =
StorageWeakRef(storage)
108 t.requires_grad = requires_grad
112 def reduce_tensor(tensor):
113 storage = tensor.storage()
115 if tensor.requires_grad
and not tensor.is_leaf:
116 raise RuntimeError(
"Cowardly refusing to serialize non-leaf tensor which requires_grad, " 117 "since autograd does not support crossing process boundaries. " 118 "If you just want to transfer the data, call detach() on the tensor " 119 "before serializing (e.g., putting it on the queue).")
214 (device, handle, storage_size_bytes, storage_offset_bytes) = storage._share_cuda_()
215 tensor_offset = tensor.storage_offset()
221 return (rebuild_cuda_tensor,
230 storage_offset_bytes,
231 tensor.requires_grad))
234 metadata = (tensor.storage_offset(), tensor.size(), tensor.stride(), tensor.requires_grad)
235 return (rebuild_tensor, (type(tensor), storage, metadata))
243 return (stat.st_ino, stat.st_dev)
246 def storage_from_cache(cls, key):
247 storage_ref = shared_cache.get(key)
248 if storage_ref
is None:
250 return cls._new_with_weak_ptr(storage_ref.cdata)
253 def rebuild_storage_fd(cls, df, size):
254 if sys.version_info[0] == 2:
255 fd = multiprocessing.reduction.rebuild_handle(df)
259 storage = storage_from_cache(cls, fd_id(fd))
260 if storage
is not None:
262 storage = cls._new_shared_fd(fd, size)
269 def rebuild_storage_filename(cls, manager, handle, size):
270 storage = storage_from_cache(cls, handle)
271 if storage
is not None:
272 return storage._shared_decref()
273 storage = cls._new_shared_filename(manager, handle, size)
275 return storage._shared_decref()
278 def rebuild_storage_empty(cls):
282 def reduce_storage(storage):
283 from .
import get_sharing_strategy
285 raise RuntimeError(
"Cannot pickle CUDA storage; try pickling a CUDA tensor instead")
287 metadata = storage._share_filename_()
288 cache_key = metadata[1]
289 rebuild = rebuild_storage_filename
290 storage._shared_incref()
291 elif storage.size() == 0:
294 return (rebuild_storage_empty, (type(storage),))
296 fd, size = storage._share_fd_()
297 if sys.version_info[0] == 2:
298 df = multiprocessing.reduction.reduce_handle(fd)
300 df = multiprocessing.reduction.DupFd(fd)
301 cache_key = fd_id(fd)
302 metadata = (df, size)
303 rebuild = rebuild_storage_fd
306 return (rebuild, (type(storage),) + metadata)
def init_reductions():
    """Register this module's reducers with ``ForkingPickler``.

    Registration order:
      1. ``torch.cuda.Event`` -> ``reduce_event``
      2. every class in ``torch._storage_classes`` -> ``reduce_storage``
      3. every class in ``torch._tensor_classes`` -> ``reduce_tensor``
      4. ``torch.Tensor`` -> ``reduce_tensor``
    """
    ForkingPickler.register(torch.cuda.Event, reduce_event)

    # Storage types first, then the concrete tensor types, each with its reducer.
    class_groups = (
        (torch._storage_classes, reduce_storage),
        (torch._tensor_classes, reduce_tensor),
    )
    for classes, reducer in class_groups:
        for klass in classes:
            ForkingPickler.register(klass, reducer)

    # The base tensor type is registered separately from _tensor_classes.
    ForkingPickler.register(torch.Tensor, reduce_tensor)
def _rebuild_tensor(storage, storage_offset, size, stride)
def get_sharing_strategy()
def warn_if_has_hooks(tensor)
def free_dead_references(self)