Caffe2 - C++ API
A deep learning, cross-platform ML framework
comm.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <c10/util/Optional.h>
6 
7 #include <cstddef>
8 #include <vector>
9 
10 namespace torch { namespace cuda {
11 
// Alias for a vector of vectors of at::Tensor. In this header it is used as
// the return type of broadcast_coalesced.
// NOTE(review): which dimension is "per device" and which is "per tensor" is
// not stated here — confirm against the implementation.
12 using tensor_list2d = std::vector<std::vector<at::Tensor>>;
13 
// Declaration only; the definition is not part of this header.
// Takes a single tensor (by const reference) and an integer list `devices`,
// and returns a vector of tensors by value.
// NOTE(review): presumably `devices` holds CUDA device indices and the result
// contains one copy of `tensor` per listed device — confirm against the
// implementation.
14 std::vector<at::Tensor> broadcast(const at::Tensor& tensor, at::IntArrayRef devices);
// Declaration only; the definition is not part of this header.
// Takes a list of tensors, an integer list `devices`, and a `buffer_size`
// (size_t, no default), and returns a tensor_list2d
// (std::vector<std::vector<at::Tensor>>).
// NOTE(review): presumably the tensors are packed into buffers of at most
// `buffer_size` before being sent to each device; the unit of `buffer_size`
// (bytes vs. elements) is not visible here — confirm against the
// implementation.
15 tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices,
16  size_t buffer_size);
17 
// Declaration only; the definition is not part of this header.
// Takes a tensor and an integer list `devices` and returns a vector of
// tensors by value.
//
// Optional arguments and their defaults as declared here:
//   chunk_sizes - optional vector of int64_t; defaults to c10::nullopt.
//   dim         - int64_t; defaults to 0.
//   streams     - optional vector of optional CUDA streams; defaults to
//                 c10::nullopt.
//
// NOTE(review): presumably `tensor` is split along `dim` into one chunk per
// entry of `devices`, with `chunk_sizes` giving explicit sizes and `streams`
// selecting the stream used for each copy. None of that is established by
// this declaration — confirm against the implementation, including what
// happens when either optional is left empty.
18 std::vector<at::Tensor> scatter(
19  const at::Tensor& tensor,
20  at::IntArrayRef devices,
21  const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt,
22  int64_t dim = 0,
23  const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
24  c10::nullopt);
25 
// Declaration only; the definition is not part of this header.
// Takes a list of tensors, a dimension `dim` (int64_t), and an optional
// 32-bit `destination_index`, and returns a single tensor by value. Neither
// `dim` nor `destination_index` has a default in this declaration, so callers
// must pass both explicitly.
// NOTE(review): presumably the inputs are concatenated along `dim` onto the
// device named by `destination_index`; the behavior when `destination_index`
// is empty is not visible here — confirm against the implementation.
26 at::Tensor gather(
27  at::TensorList tensors,
28  int64_t dim,
29  c10::optional<int32_t> destination_index);
30 }}
Definition: jit_type.h:17