Caffe2 - C++ API
A deep learning, cross platform ML framework
nccl.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <THC/THC.h>
6 #include <c10/util/Optional.h>
7 
8 #include <nccl.h>
9 
10 #include <cstddef>
11 #include <vector>
12 
13 namespace torch {
14 namespace cuda {
15 namespace nccl {
16 
17 // NOTE: this is exposed only so that python_nccl.cpp can use some of these helpers.
18 // Don't use them outside of these files.
19 namespace detail {
20 
// Reports the NCCL failure described by `status`. Only the declaration is
// visible in this header; presumably it throws, as the name says -- confirm
// against the definition.
21 void throw_nccl_error(ncclResult_t status);
22 
23 static inline void NCCL_CHECK(ncclResult_t status) {
24  if (status != ncclSuccess) {
25  throw_nccl_error(status);
26  }
27 }
28 
// Scope guard around an NCCL group: construction calls ncclGroupStart() and
// destruction calls ncclGroupEnd(). Both calls are compiled in only when
// NCCL_MAJOR is defined and >= 2; otherwise the guard does nothing.
//
// NOTE(review): copying this guard would issue a second ncclGroupEnd();
// treat it as non-copyable.
struct AutoNcclGroup {
  // Opens the group; a failing status is routed through NCCL_CHECK.
  AutoNcclGroup() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclGroupStart());
#endif
  }
  // Closes the group. Destructors are implicitly noexcept, so without
  // noexcept(false) an exception leaving NCCL_CHECK here (assuming
  // throw_nccl_error throws, as its name indicates) would call
  // std::terminate instead of reaching the caller. If the guard is destroyed
  // while another exception is already unwinding the stack, a throw from here
  // still terminates the program.
  ~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclGroupEnd());
#endif
  }
};
41 
// Internal helpers; only their signatures are visible in this header, so the
// notes below are inferred from names and types and should be confirmed
// against the definitions.

// Returns a view of NCCL communicator handles for `inputs`.
// NOTE(review): at::ArrayRef is a non-owning reference to an array, so the
// storage behind the returned view must outlive it -- confirm where the
// implementation keeps the communicators.
42 at::ArrayRef<ncclComm_t> _get_communicators(at::TensorList inputs);
// Presumably validates `inputs` against `outputs`; the two multipliers look
// like scale factors applied to the expected sizes -- TODO confirm.
43 void _check_inputs(
44  at::TensorList inputs,
45  at::TensorList outputs,
46  int input_multiplier,
47  int output_multiplier);
// Maps tensor `t` to an ncclDataType_t value (presumably from the tensor's
// element type -- confirm).
48 ncclDataType_t _get_data_type(const at::Tensor& t);
49 
50 } // namespace detail
51 
// Aliases used as parameter types by broadcast() and reduce() in this header.
// comm_list: a vector of NCCL communicator handles.
52 using comm_list = std::vector<ncclComm_t>;
// stream_list: a vector of optional CUDA streams.
// NOTE(review): the meaning of an empty optional is not visible here;
// presumably it selects a default stream -- confirm.
53 using stream_list = std::vector<c10::optional<at::cuda::CUDAStream>>;
54 
// Returns a version number as an unsigned 64-bit integer (presumably the NCCL
// library version; the encoding is not visible here -- confirm).
// NOTE(review): std::uint64_t is declared in <cstdint>, which this header
// does not include directly.
55 std::uint64_t version();
56 
// Returns whether the NCCL path can be used for `tensors`. The exact criteria
// live in the definition, which is not visible here -- confirm.
57 bool is_available(at::TensorList tensors);
58 
// Broadcast over `tensors`.
// `streams` and `user_comms` both default to empty lists; presumably the
// implementation then falls back to its own streams and communicators --
// TODO confirm. Which tensor acts as the source is not visible here.
59 void broadcast(
60  at::TensorList tensors,
61  const stream_list& streams = {},
62  const comm_list& user_comms = {});
63 
// Returns a size_t limit (presumably the largest element count accepted by a
// single NCCL call -- confirm against the definition).
64 size_t get_max_count();
65 
// Reduce with separate input and output tensor lists.
// `root` defaults to 0 and `op` defaults to ncclSum; `streams` and
// `user_comms` default to empty lists.
// NOTE(review): `op` is passed as int32_t rather than as the NCCL enum type,
// and how `outputs` is indexed relative to `root` is not visible here --
// confirm against the definition.
66 void reduce(
67  const std::vector<at::Tensor>& inputs,
68  std::vector<at::Tensor>& outputs,
69  int32_t root = 0,
70  int32_t op = ncclSum,
71  const stream_list& streams = {},
72  const comm_list& user_comms = {});
73 
// Reduce overload that takes a single, mutable tensor list (presumably the
// result is written back into `inputs` -- TODO confirm).
// Defaults match the two-list overload: root 0, op ncclSum, empty `streams`
// and `user_comms`.
74 void reduce(
75  std::vector<at::Tensor>& inputs,
76  int32_t root = 0,
77  int32_t op = ncclSum,
78  const stream_list& streams = {},
79  const comm_list& user_comms = {});
80 
81 } // namespace nccl
82 } // namespace cuda
83 } // namespace torch
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41