def _get_device_index(device, optional=False):
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    is a CUDA device. Note that for a CUDA device without a specified index,
    i.e., ``torch.device('cuda')``, this will return the current default CUDA
    device if :attr:`optional` is ``True``.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default CUDA
    device if :attr:`optional` is ``True``.

    A string is first converted with ``torch.device(device)`` and then handled
    like a torch.device object.

    Raises:
        ValueError: if :attr:`device` is a torch.device whose type is not
            ``'cuda'``, or if no index can be determined and :attr:`optional`
            is ``False``.
    """
    # Accept device strings such as 'cuda:1' by normalising them first.
    if isinstance(device, torch._six.string_classes):
        device = torch.device(device)

    if isinstance(device, torch.device):
        if device.type != 'cuda':
            raise ValueError(
                'Expected a cuda device, but got: {}'.format(device))
        device_idx = device.index
    else:
        # Integers (and None) are passed through unchanged.
        device_idx = device

    if device_idx is None:
        if optional:
            # default cuda device index
            return torch.cuda.current_device()
        # The placeholder is required, otherwise the offending value is
        # silently dropped from the message.
        raise ValueError(
            'Expected a cuda device with a specified index '
            'or an integer, but got: {}'.format(device))
    return device_idx