3 #include <torch/csrc/autograd/function.h> 4 #include <torch/csrc/autograd/variable.h> 5 #include <torch/csrc/WindowsTorchApiMacro.h> 8 #include <ATen/cuda/CUDAContext.h> 16 struct TORCH_API
// Scatter: a struct that publicly derives from Function and overrides apply().
// NOTE(review): this listing is extraction-damaged. The embedded source line
// numbers jump (18, 19, 23, 25, 27, 31), so the constructor's opening line,
// several declarations, and the closing `};` are not visible in this view.
Scatter :
public Function {
// Parameter-list fragment, presumably of the Scatter constructor (its opening
// line is missing here; confirm against the real header):
//   devices           - vector of at::Device values, taken by value.
//   chunk_sizes       - optional vector of int64_t; defaults to c10::nullopt.
//   unsqueeze_scalars - boolean flag; defaults to false.
// Any parameters declared on the skipped lines (20-22) are not shown.
18 std::vector<at::Device> devices,
19 const c10::optional<std::vector<int64_t>>& chunk_sizes = c10::nullopt,
23 bool unsqueeze_scalars =
false);
// Declared with `override`: takes `inputs` by rvalue reference and returns a
// variable_list. The implementation is not part of this view.
25 variable_list apply(variable_list&& inputs)
override;
// Data members; presumably they hold the same-named constructor arguments
// (TODO confirm in the constructor definition).
27 std::vector<at::Device> devices_;
31 bool unsqueeze_scalars_;
// Gather: a struct that publicly derives from Function and overrides apply().
// NOTE(review): source lines 35-36 and everything after 37 (including the
// closing `};`) are missing from this listing, so any constructor, destructor,
// or data members are not visible here.
34 struct TORCH_API
Gather :
public Function {
// Declared with `override`: takes `inputs` by rvalue reference and returns a
// variable_list. The implementation is not part of this view.
37 variable_list apply(variable_list&& inputs)
override;
Represents a compute device on which a tensor is located.