#include <torch/csrc/autograd/functions/comm.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>
#include <ATen/core/functional.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>

#include <memory>
#include <vector>

namespace torch {
namespace autograd {

Scatter::Scatter(
    std::vector<at::Device> devices,
    const c10::optional<std::vector<int64_t>>& chunk_sizes,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(chunk_sizes),
      dim_(dim),
      streams_(streams),
      unsqueeze_scalars_(unsqueeze_scalars) {}
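// Scatter::apply takes a single variable, splits it into chunks along `dim_`,
// and copies one chunk to each device in `devices_`, optionally using the
// per-device `streams_`. If the input requires grad, the backward function is
// a Gather back onto the device the input came from.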
variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& input = inputs.front();

  std::shared_ptr<Function> grad_fn;
  if (compute_requires_grad(input)) {
    grad_fn =
        std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
    grad_fn->set_next_edges(collect_next_edges(input));
  }

  auto device_indices =
      fmap(devices_, [](const at::Device& device) -> int64_t {
        return device.index();
      });
  auto tensors = torch::cuda::scatter(
      std::move(input), device_indices, chunk_sizes_, dim_, streams_);

  std::vector<Variable> variables;
  variables.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (unsqueeze_scalars_) {
      // `unsqueeze_scalars_` is set when this Scatter is the gradient of a
      // Gather over 0-dim scalars; each chunk is then a 1-element vector,
      // so squeeze it back down to a scalar.
      AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
      variables.push_back(tensor[0]);
    } else {
      variables.push_back(std::move(tensor));
    }
  }

  set_history(variables, grad_fn);
  return variables;
}
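// Usage sketch (illustrative only; the two-device setup, `input`, and dim 0
// are assumptions, not taken from this file):
//
//   Scatter scatter(
//       {at::Device(at::kCUDA, 0), at::Device(at::kCUDA, 1)},
//       /*chunk_sizes=*/c10::nullopt,
//       /*dim=*/0,
//       /*streams=*/c10::nullopt,
//       /*unsqueeze_scalars=*/false);
//   auto chunks = scatter.apply({input});  // one chunk per device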
Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}
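// Gather::apply takes one CUDA variable per source device, concatenates them
// along `dim_`, and moves the result to `destination_device_`. If any input
// requires grad, the backward function is a Scatter that routes each gradient
// slice back to the device and size of the corresponding input.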
variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    AT_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.type());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    AT_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  // Build the grad_fn before any variable is moved out of `inputs` below;
  // afterwards the moved-from variables are undefined and carry no autograd
  // metadata.
  std::shared_ptr<Function> grad_fn;
  if (compute_requires_grad(inputs)) {
    std::vector<at::Device> source_devices;
    std::vector<int64_t> input_sizes;
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        /*chunk_sizes=*/std::move(input_sizes),
        dim_,
        /*streams=*/c10::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }

  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // torch::cuda::gather uses -1 as the destination index to mean the CPU.
  const auto destination_index =
      destination_device_.is_cpu() ? -1 : destination_device_.index();
  auto variable = torch::cuda::gather(tensors, dim_, destination_index);
  set_history(variable, grad_fn);
  return variable_list{variable};
}

} // namespace autograd
} // namespace torch
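// Usage sketch (illustrative only; the per-device inputs and device ids are
// assumptions, not taken from this file): gathering one tensor per device
// onto CUDA device 0 along dim 0.
//
//   torch::autograd::Gather gather(at::Device(at::kCUDA, 0), /*dim=*/0);
//   torch::autograd::variable_list parts = {chunk_on_gpu0, chunk_on_gpu1};
//   auto gathered = gather.apply(std::move(parts));  // one concatenated variable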