Caffe2 - Python API
A deep learning, cross-platform ML framework
__init__.py
1 r"""
2 The torch package contains data structures for multi-dimensional
3 tensors and mathematical operations over these are defined.
4 Additionally, it provides many utilities for efficient serializing of
5 Tensors and arbitrary types, and other useful utilities.
6 
7 It has a CUDA counterpart, that enables you to run your tensor computations
8 on an NVIDIA GPU with compute capability >= 3.0.
9 """

import os
import sys
import platform
from ._utils import _import_dotted_name
from ._utils_internal import get_file_path, prepare_multiprocessing_environment
from .version import __version__
from ._six import string_classes as _string_classes

__all__ = [
    'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
    'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed',
    'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
    'no_grad', 'enable_grad', 'rand', 'randn',
    'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
    'ShortStorage', 'CharStorage', 'ByteStorage',
    'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
    'ShortTensor', 'CharTensor', 'ByteTensor', 'Tensor',
]

################################################################################
# Load the extension module
################################################################################

# Loading the extension with the RTLD_GLOBAL option allows us to avoid linking
# extension modules against the _C shared object. Their missing THP symbols
# will be filled in automatically by the dynamic loader.
import os as _dl_flags

# NumPy, if available, *must* be imported before the call to setdlopenflags()
# below; otherwise C extension modules imported later risk segfaulting when
# they import numpy.
try:
    import numpy as _np
except ImportError:
    pass

if platform.system() == 'Windows':
    # first get nvToolsExt PATH
    def get_nvToolsExt_path():
        NVTOOLEXT_HOME = _dl_flags.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt')

        if _dl_flags.path.exists(NVTOOLEXT_HOME):
            return _dl_flags.path.join(NVTOOLEXT_HOME, 'bin', 'x64')
        else:
            return ''

    py_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(sys.executable), 'Library', 'bin')
    th_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(__file__), 'lib')

    dll_paths = [th_dll_path, py_dll_path, get_nvToolsExt_path(), _dl_flags.environ['PATH']]

    # then add the path to env
    _dl_flags.environ['PATH'] = ';'.join(dll_paths)

else:
    # first check if the os package has the required flags
    if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
        try:
            # next try if DLFCN exists
            import DLFCN as _dl_flags
        except ImportError:
            # as a last attempt, use compile-time constants
            import torch._dl as _dl_flags

    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)

del _dl_flags

from torch._C import *

__all__ += [name for name in dir(_C)
            if name[0] != '_' and
            not name.endswith('Base')]

if platform.system() != 'Windows':
    sys.setdlopenflags(old_flags)
    del old_flags
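
# A minimal sketch of the dlopen-flag pattern used above (illustrative only;
# `some_c_extension` is a hypothetical module name, not part of torch):
#
#     old = sys.getdlopenflags()
#     sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)  # export symbols globally
#     import some_c_extension                            # its unresolved symbols can now come from _C
#     sys.setdlopenflags(old)                            # restore the previous flags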

################################################################################
# Define basic utilities
################################################################################


def typename(o):
    if isinstance(o, torch.Tensor):
        return o.type()

    module = ''
    class_name = ''
    if hasattr(o, '__module__') and o.__module__ != 'builtins' \
            and o.__module__ != '__builtin__' and o.__module__ is not None:
        module = o.__module__ + '.'

    if hasattr(o, '__qualname__'):
        class_name = o.__qualname__
    elif hasattr(o, '__name__'):
        class_name = o.__name__
    else:
        class_name = o.__class__.__name__

    return module + class_name
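
# Example of what typename() returns (a sketch; output assumes default dtypes):
#
#     >>> torch.typename(torch.rand(2, 3))
#     'torch.FloatTensor'
#     >>> torch.typename(dict)
#     'dict'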


def is_tensor(obj):
    r"""Returns True if `obj` is a PyTorch tensor.

    Args:
        obj (Object): Object to test
    """
    return isinstance(obj, torch.Tensor)


def is_storage(obj):
    r"""Returns True if `obj` is a PyTorch storage object.

    Args:
        obj (Object): Object to test
    """
    return type(obj) in _storage_classes
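
# Illustrative checks for the two predicates above (a sketch, not exhaustive):
#
#     >>> torch.is_tensor(torch.zeros(3))
#     True
#     >>> torch.is_tensor([1, 2, 3])
#     False
#     >>> torch.is_storage(torch.FloatStorage(10))
#     True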


def set_default_tensor_type(t):
    r"""Sets the default ``torch.Tensor`` type to floating point tensor type
    :attr:`t`. This type will also be used as default floating point type for
    type inference in :func:`torch.tensor`.

    The default floating point tensor type is initially ``torch.FloatTensor``.

    Args:
        t (type or string): the floating point tensor type or its name

    Example::

        >>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_tensor_type(torch.DoubleTensor)
        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
        torch.float64

    """
    if isinstance(t, _string_classes):
        t = _import_dotted_name(t)
    _C._set_default_tensor_type(t)


def set_default_dtype(d):
    r"""Sets the default floating point dtype to :attr:`d`. This type will be
    used as default floating point type for type inference in
    :func:`torch.tensor`.

    The default floating point dtype is initially ``torch.float32``.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default

    Example::

        >>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_dtype(torch.float64)
        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
        torch.float64

    """
    _C._set_default_dtype(d)

# If you edit these imports, please update torch/__init__.py.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed
from .serialization import save, load
from ._tensor_str import set_printoptions

################################################################################
# Define Storage and Tensor classes
################################################################################

from .tensor import Tensor
from .storage import _StorageBase


class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass


class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass


class HalfStorage(_C.HalfStorageBase, _StorageBase):
    pass


class LongStorage(_C.LongStorageBase, _StorageBase):
    pass


class IntStorage(_C.IntStorageBase, _StorageBase):
    pass


class ShortStorage(_C.ShortStorageBase, _StorageBase):
    pass


class CharStorage(_C.CharStorageBase, _StorageBase):
    pass


class ByteStorage(_C.ByteStorageBase, _StorageBase):
    pass


class BoolStorage(_C.BoolStorageBase, _StorageBase):
    pass

_storage_classes = {
    DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
    CharStorage, ByteStorage, HalfStorage, BoolStorage
}
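
# A small sketch of how these storage classes relate to tensors
# (sizes and values here are illustrative):
#
#     >>> s = torch.FloatStorage(4)        # uninitialized storage holding 4 floats
#     >>> torch.is_storage(s)
#     True
#     >>> torch.FloatTensor(s).size()      # a 1-D tensor viewing the same storage
#     torch.Size([4])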

# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
_tensor_classes = set()


################################################################################
# Initialize extension
################################################################################

def manager_path():
    if platform.system() == 'Windows':
        return b""
    path = get_file_path('torch', 'bin', 'torch_shm_manager')
    prepare_multiprocessing_environment(get_file_path('torch'))
    if not os.path.exists(path):
        raise RuntimeError("Unable to find torch_shm_manager at " + path)
    return path.encode('utf-8')


# Shared memory manager needs to know the exact location of manager executable
_C._initExtension(manager_path())
del manager_path

for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    globals()[name] = getattr(_C._VariableFunctions, name)
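
# The loop above is what makes the ATen variable functions available at the
# top level of the package; a rough sketch of the effect (assuming a standard
# build where `add` and `matmul` are present on _C._VariableFunctions):
#
#     >>> torch.add(torch.ones(2), torch.ones(2))
#     tensor([2., 2.])
#     >>> torch.matmul(torch.eye(2), torch.ones(2, 2))
#     tensor([[1., 1.],
#             [1., 1.]])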

################################################################################
# Import interface functions defined in Python
################################################################################

# needs to be after the above ATen bindings so we can overwrite from Python side
from .functional import *


################################################################################
# Remove unnecessary members
################################################################################

del DoubleStorageBase
del FloatStorageBase
del LongStorageBase
del IntStorageBase
del ShortStorageBase
del CharStorageBase
del ByteStorageBase
del BoolStorageBase

################################################################################
# Import most common subpackages
################################################################################

import torch.cuda
import torch.autograd
from torch.autograd import no_grad, enable_grad, set_grad_enabled
import torch.nn
import torch.optim
import torch.sparse
import torch.onnx
import torch.jit
import torch.random
import torch.distributions
import torch.testing
import torch.backends.cuda
import torch.backends.mkl

_C._init_names(list(torch._storage_classes))

# attach docstrings to torch and tensor functions
from . import _torch_docs, _tensor_docs, _storage_docs
del _torch_docs, _tensor_docs, _storage_docs


def compiled_with_cxx11_abi():
    r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
    return _C._GLIBCXX_USE_CXX11_ABI
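
# Illustrative call; the returned value depends entirely on how the binary was built:
#
#     >>> torch.compiled_with_cxx11_abi()
#     False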


# Import the ops "namespace"
from torch._ops import ops
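
# torch.ops is the entry point for operators registered from C++; a hedged
# sketch, where the shared library and the `my_ops` namespace are hypothetical:
#
#     >>> torch.ops.load_library("libmy_ops.so")
#     >>> torch.ops.my_ops.my_add(torch.ones(2), torch.ones(2))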