Caffe2 - C++ API
A deep learning, cross-platform ML framework
comm.h
1 #pragma once
2 
3 #include <torch/csrc/autograd/function.h>
4 #include <torch/csrc/autograd/variable.h>
5 #include <torch/csrc/WindowsTorchApiMacro.h>
6 
7 #include <ATen/ATen.h>
8 #include <ATen/cuda/CUDAContext.h>
9 
10 #include <cstddef>
11 #include <vector>
12 
13 namespace torch {
14 namespace autograd {
15 
16 struct TORCH_API Scatter : public Function {
17  explicit Scatter(
18  std::vector<at::Device> devices,
19  const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt,
20  int64_t dim = 0,
21  const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams =
22  c10::nullopt,
23  bool unsqueeze_scalars = false);
24 
25  variable_list apply(variable_list&& inputs) override;
26 
27  std::vector<at::Device> devices_;
29  int64_t dim_;
31  bool unsqueeze_scalars_;
32 };
33 
34 struct TORCH_API Gather : public Function {
35  explicit Gather(const at::Device& destination_device, int64_t dim = 0);
36 
37  variable_list apply(variable_list&& inputs) override;
38 
39  at::Device destination_device_;
40  int64_t dim_;
41 };
42 
43 } // namespace autograd
44 } // namespace torch
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Definition: jit_type.h:17