Caffe2 - Python API
A deep learning, cross platform ML framework
random.py
import torch
import contextlib
import warnings

from torch._C import default_generator


def set_rng_state(new_state):
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)

def get_rng_state():
    r"""Returns the random number generator state as a `torch.ByteTensor`."""
    return default_generator.get_state()

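The two functions above support checkpoint-style reproducibility: snapshot the default generator, run some code, then rewind and replay. A minimal sketch, assuming only the CPU generator is in play:

import torch

state = torch.get_rng_state()   # snapshot of the default generator, a torch.ByteTensor
a = torch.rand(3)               # advances the generator
torch.set_rng_state(state)      # rewind to the snapshot
b = torch.rand(3)               # replays the exact same draws
assert torch.equal(a, b)
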
def manual_seed(seed):
    r"""Sets the seed for generating random numbers. Returns a
    `torch._C.Generator` object.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)
    import torch.cuda

    if not torch.cuda._in_bad_fork:
        torch.cuda.manual_seed_all(seed)

    return default_generator.manual_seed(seed)

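Because `manual_seed` also seeds every CUDA device (unless the process is a badly-forked child), re-seeding restarts the whole random stream. A small sketch of that reproducibility contract:

import torch

torch.manual_seed(0)
x = torch.rand(2)
torch.manual_seed(0)    # same seed, so the stream restarts from the beginning
y = torch.rand(2)
assert torch.equal(x, y)
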
def initial_seed():
    r"""Returns the initial seed for generating random numbers as a
    Python `long`.
    """
    return default_generator.initial_seed()

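A quick sketch of how `initial_seed` relates to `manual_seed`: it reports whatever seed the default generator was last seeded with.

import torch

torch.manual_seed(1234)
assert torch.initial_seed() == 1234   # recovers the seed set above
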
_fork_rng_warned_already = False

@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Arguments:
        devices (iterable of CUDA IDs): CUDA devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng`
            operates on all devices, but will emit a warning if your machine has
            a lot of devices, since this function will run very slowly in that
            case. If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
    """

    import torch.cuda
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = torch.cuda.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                ("CUDA reports that you have {num_devices} available devices, and you "
                 "have used {caller} without explicitly specifying which devices are being used. "
                 "For safety, we initialize *every* CUDA device by default, which "
                 "can be quite slow if you have a lot of GPUs. If you know that you are only "
                 "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                 "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                 "you are actually using. For example, if you are using CPU only, "
                 "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                 "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
                 "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                 "to `range(torch.cuda.device_count())`."
                 ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    gpu_rng_states = []
    for device in devices:
        with torch.cuda.device(device):
            gpu_rng_states.append(torch.cuda.get_rng_state())

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, gpu_rng_state in zip(devices, gpu_rng_states):
            with torch.cuda.device(device):
                torch.cuda.set_rng_state(gpu_rng_state)
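
A usage sketch, assuming a CPU-only run (hence `devices=[]` to skip CUDA state entirely) and that the module is reached as `torch.random`: random draws made inside the context are invisible to code after it.

import torch

torch.manual_seed(0)
before = torch.rand(2)

torch.manual_seed(0)
with torch.random.fork_rng(devices=[]):   # fork the CPU RNG state only
    _ = torch.rand(100)                   # consume draws inside the fork
after = torch.rand(2)                     # outer stream resumes, untouched

assert torch.equal(before, after)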