r"""
The torch package contains data structures for multi-dimensional
tensors and defines mathematical operations over these tensors.
Additionally, it provides many utilities for efficient serialization of
Tensors and arbitrary types, and other useful utilities.

It has a CUDA counterpart that enables you to run your tensor computations
on an NVIDIA GPU with compute capability >= 3.0.
"""

import os
import sys
import platform

from ._utils import _import_dotted_name
from ._utils_internal import get_file_path, prepare_multiprocessing_environment
from .version import __version__
from ._six import string_classes as _string_classes
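# Illustrative usage of the package described above (a minimal sketch, not part
# of this module; it only assumes the public torch API):
#
#     >>> import torch
#     >>> x = torch.randn(2, 3)            # a 2x3 tensor of standard-normal values
#     >>> y = x @ x.t()                    # matrix multiplication on the CPU
#     >>> if torch.cuda.is_available():    # the CUDA counterpart runs on an NVIDIA GPU
#     ...     y_gpu = x.cuda() @ x.cuda().t()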
__all__ = [
    'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
    'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed',
    'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
    'no_grad', 'enable_grad', 'rand', 'randn',
    'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
    'ShortStorage', 'CharStorage', 'ByteStorage',
    'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
    'ShortTensor', 'CharTensor', 'ByteTensor', 'Tensor',
]
# `os` is imported under an alias so that, on platforms where it lacks the
# dlopen flags we need, it can be swapped for a fallback module below.
import os as _dl_flags
if platform.system() == 'Windows':
    # first get the nvToolsExt path
    def get_nvToolsExt_path():
        NVTOOLEXT_HOME = _dl_flags.getenv('NVTOOLSEXT_PATH',
                                          'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt')
        if _dl_flags.path.exists(NVTOOLEXT_HOME):
            return _dl_flags.path.join(NVTOOLEXT_HOME, 'bin', 'x64')
        # fall back to an empty entry so the PATH join below still gets a string
        return ''

    py_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(sys.executable), 'Library', 'bin')
    th_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(__file__), 'lib')

    dll_paths = [th_dll_path, py_dll_path, get_nvToolsExt_path(), _dl_flags.environ['PATH']]

    # prepend the DLL directories so the C extension can locate its dependencies
    _dl_flags.environ['PATH'] = ';'.join(dll_paths)
else:
    # first check whether the os module exposes the required dlopen flags
    if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
        try:
            # next, try the DLFCN module
            import DLFCN as _dl_flags
        except ImportError:
            # as a last resort, use the compile-time constants shipped with torch
            import torch._dl as _dl_flags

    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
from torch._C import *

__all__ += [name for name in dir(_C)
            if name[0] != '_' and
            not name.endswith('Base')]
if platform.system() != 'Windows':
    sys.setdlopenflags(old_flags)
def typename(o):
    if isinstance(o, torch.Tensor):
        return o.type()

    module = ''
    class_name = ''
    if hasattr(o, '__module__') and o.__module__ != 'builtins' \
            and o.__module__ != '__builtin__' and o.__module__ is not None:
        module = o.__module__ + '.'

    if hasattr(o, '__qualname__'):
        class_name = o.__qualname__
    elif hasattr(o, '__name__'):
        class_name = o.__name__
    else:
        class_name = o.__class__.__name__

    return module + class_name
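# Illustrative behavior of typename() (a sketch, not part of the module; assumes
# `import torch` has completed):
#
#     >>> torch.typename(torch.randn(2))   # Tensor instances report their type string, e.g. 'torch.FloatTensor'
#     >>> torch.typename(dict)             # builtins are reported without a module prefix
#     'dict'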
def is_tensor(obj):
    r"""Returns True if `obj` is a PyTorch tensor.

    Args:
        obj (Object): Object to test
    """
    return isinstance(obj, torch.Tensor)
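# Example (a minimal sketch, not part of the module):
#
#     >>> torch.is_tensor(torch.tensor([1, 2, 3]))
#     True
#     >>> torch.is_tensor([1, 2, 3])    # plain Python containers are not tensors
#     False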
def is_storage(obj):
    r"""Returns True if `obj` is a PyTorch storage object.

    Args:
        obj (Object): Object to test
    """
    return type(obj) in _storage_classes
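# Example (a minimal sketch, not part of the module):
#
#     >>> torch.is_storage(torch.FloatStorage(10))   # a raw storage of 10 floats
#     True
#     >>> torch.is_storage(torch.randn(10))          # tensors are not storages
#     False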
def set_default_tensor_type(t):
    r"""Sets the default ``torch.Tensor`` type to floating point tensor type
    :attr:`t`. This type will also be used as the default floating point type for
    type inference in :func:`torch.tensor`.

    The default floating point tensor type is initially ``torch.FloatTensor``.

    Args:
        t (type or string): the floating point tensor type or its name

    Example::

        >>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_tensor_type(torch.DoubleTensor)
        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
        torch.float64

    """
    if isinstance(t, _string_classes):
        t = _import_dotted_name(t)
    _C._set_default_tensor_type(t)
def set_default_dtype(d):
    r"""Sets the default floating point dtype to :attr:`d`. This type will be
    used as the default floating point type for type inference in
    :func:`torch.tensor`.

    The default floating point dtype is initially ``torch.float32``.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default

    Example::

        >>> torch.tensor([1.2, 3]).dtype           # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_dtype(torch.float64)
        >>> torch.tensor([1.2, 3]).dtype           # a new floating point tensor
        torch.float64

    """
    _C._set_default_dtype(d)
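# A further illustration (a sketch, not part of the module): the default dtype
# only affects floating point type inference; integer literals keep their own default.
#
#     >>> torch.set_default_dtype(torch.float64)
#     >>> torch.tensor([1.2, 3]).dtype    # floating point literals now infer float64
#     torch.float64
#     >>> torch.tensor([1, 3]).dtype      # integer literals are unaffected
#     torch.int64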
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed
from .serialization import save, load
from ._tensor_str import set_printoptions
from .tensor import Tensor
from .storage import _StorageBase
class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass
_storage_classes = {
    DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
    CharStorage, ByteStorage, HalfStorage, BoolStorage
}
_tensor_classes = set()
def manager_path():
    if platform.system() == 'Windows':
        return b""
    path = get_file_path('torch', 'bin', 'torch_shm_manager')
    prepare_multiprocessing_environment(get_file_path('torch'))
    if not os.path.exists(path):
        raise RuntimeError("Unable to find torch_shm_manager at " + path)
    return path.encode('utf-8')


# The shared memory manager needs to know the exact location of the manager executable
_C._initExtension(manager_path())
for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    globals()[name] = getattr(_C._VariableFunctions, name)
# Import interface functions defined in Python; this must come after the ATen
# bindings above so that the Python definitions can override them.
from .functional import *
del DoubleStorageBase
_C._init_names(list(torch._storage_classes))
# attach docstrings to torch and tensor functions
from . import _torch_docs, _tensor_docs, _storage_docs
del _torch_docs, _tensor_docs, _storage_docs
def compiled_with_cxx11_abi():
    r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
    return _C._GLIBCXX_USE_CXX11_ABI
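# Example (a sketch, not part of the module): useful when building C++ extensions
# that must be compiled with the same ABI flag as libtorch.
#
#     >>> torch.compiled_with_cxx11_abi()    # True or False, depending on the build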