1 r"""This file is allowed to initialize CUDA context when imported.""" 5 from common_utils
import TEST_WITH_ROCM, TEST_NUMBA
# Default CUDA device used by tests. When TEST_CUDA is falsy, this is that
# same falsy value rather than a torch.device, so callers can still use
# CUDA_DEVICE in a plain truthiness check.
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else TEST_CUDA
# Only touch numba.cuda when the TEST_NUMBA flag (from common_utils) is
# truthy; otherwise report numba CUDA support as unavailable instead of
# unconditionally overwriting the probed result.
if TEST_NUMBA:
    import numba.cuda
    TEST_NUMBA_CUDA = numba.cuda.is_available()
else:
    TEST_NUMBA_CUDA = False

# Module-level flag recording whether initialize_cuda_context_rng() has
# already run, so the per-device initialization happens at most once.
__cuda_ctx_rng_initialized = False


def initialize_cuda_context_rng():
    """Initialize the CUDA context and RNG on every visible CUDA device.

    Draws one random number on each device index in
    ``range(torch.cuda.device_count())``, which is intended to force CUDA
    context creation and RNG initialization on that device. The work runs at
    most once: after the first successful pass, a module-level flag makes
    later calls no-ops.

    Raises:
        AssertionError: if ``TEST_CUDA`` is falsy, i.e. CUDA must be
            available when calling this function.
    """
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # One tiny randn per device is enough to trigger initialization; the
        # resulting tensor is discarded.
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device="cuda:{}".format(i))
        # Set only after every device succeeded, so a failure part-way
        # through leaves the flag False and a later call retries.
        __cuda_ctx_rng_initialized = True
def is_acceptable(tensor)