import contextlib
import warnings

from torch._C import default_generator

def set_rng_state(new_state):
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)

def get_rng_state():
    r"""Returns the random number generator state as a `torch.ByteTensor`."""
    return default_generator.get_state()
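
# A minimal usage sketch (illustration only, not part of this module): saving
# the CPU RNG state and restoring it later makes subsequent draws repeat
# exactly.  Assumes the caller has imported `torch`.
#
#     state = torch.get_rng_state()
#     a = torch.randn(3)          # advances the RNG
#     torch.set_rng_state(state)  # rewind to the saved state
#     b = torch.randn(3)          # replays the same draw
#     assert torch.equal(a, b)
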

def manual_seed(seed):
    r"""Sets the seed for generating random numbers. Returns a
    `torch._C.Generator` object.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)
    import torch.cuda

    if not torch.cuda._in_bad_fork:
        torch.cuda.manual_seed_all(seed)

    return default_generator.manual_seed(seed)
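
# Usage sketch (illustration only): identical seeds yield identical sequences
# of draws, which is the usual way to make runs reproducible.  Assumes the
# caller has imported `torch`.
#
#     torch.manual_seed(42)
#     a = torch.randn(5)
#     torch.manual_seed(42)
#     b = torch.randn(5)
#     assert torch.equal(a, b)
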

def initial_seed():
    r"""Returns the initial seed for generating random numbers as a
    Python `long`."""
    return default_generator.initial_seed()
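
# Usage sketch (illustration only): `initial_seed` reports the seed that the
# default generator was last seeded with.
#
#     torch.manual_seed(42)
#     assert torch.initial_seed() == 42
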

_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of CUDA IDs): CUDA devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng`
            operates on all devices, but will emit a warning if your machine has
            a lot of devices, since this function will run very slowly in that
            case. If you explicitly specify devices, this warning will be
            suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
    """
    import torch.cuda
    global _fork_rng_warned_already

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = torch.cuda.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                ("CUDA reports that you have {num_devices} available devices, and you "
                 "have used {caller} without explicitly specifying which devices are being used. "
                 "For safety, we initialize *every* CUDA device by default, which "
                 "can be quite slow if you have a lot of GPUs. If you know that you are only "
                 "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                 "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                 "you are actually using. For example, if you are using CPU only, "
                 "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                 "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
                 "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                 "to `range(torch.cuda.device_count())`."
                 ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against the user passing us a generator; we need to traverse
        # this collection multiple times, but a generator would be exhausted
        # after the first pass.
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    gpu_rng_states = []
    for device in devices:
        with torch.cuda.device(device):
            gpu_rng_states.append(torch.cuda.get_rng_state())

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, gpu_rng_state in zip(devices, gpu_rng_states):
            with torch.cuda.device(device):
                torch.cuda.set_rng_state(gpu_rng_state)
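
# Usage sketch (illustration only): random draws made inside the `fork_rng`
# block leave the caller's RNG state untouched.  `devices=[]` restricts the
# fork to the CPU generator; the module is assumed importable as
# `torch.random`.
#
#     with torch.random.fork_rng(devices=[]):
#         _ = torch.randn(100)  # heavy RNG use, isolated from the caller
#     # On exit, the pre-fork CPU RNG state is restored, so draws made after
#     # the block are the same as if the block had never run.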