Caffe2 - C++ API
A deep learning, cross-platform ML framework
comm.cpp
#include <torch/csrc/autograd/functions/comm.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>
#include <ATen/core/functional.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <memory>
#include <vector>

namespace torch {
namespace autograd {
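// Scatter splits a single input variable into chunks and distributes them
// across the given devices. It is also installed as the grad_fn of Gather
// below, so gradients are scattered back to their source devices.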
Scatter::Scatter(
    std::vector<at::Device> devices,
    const c10::optional<std::vector<int64_t>>& chunk_sizes,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(chunk_sizes),
      dim_(dim),
      streams_(streams),
      unsqueeze_scalars_(unsqueeze_scalars) {}

variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& input = inputs.front();

  std::shared_ptr<Function> grad_fn;
  if (compute_requires_grad(input)) {
    // Record a Gather back onto the input's device as this node's grad_fn,
    // so that backward() reassembles the scattered gradient chunks.
    grad_fn =
        std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
    grad_fn->set_next_edges(collect_next_edges(input));
  }

  auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
    return device.index();
  });
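  // torch::cuda::scatter chunks `input` along `dim_` (using `chunk_sizes_`
  // when provided) and copies each chunk to its target device, optionally
  // on the caller-supplied CUDA streams.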
  auto tensors = torch::cuda::scatter(
      std::move(input), device_indices, chunk_sizes_, dim_, streams_);

  std::vector<Variable> variables;
  variables.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (unsqueeze_scalars_) {
      // Squeeze the 1-element chunks back down to the 0-dim scalars that
      // the forward Gather originally unsqueezed.
      AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
      variables.push_back(tensor[0]);
    } else {
      variables.push_back(std::move(tensor));
    }
  }

  set_history(variables, grad_fn);

  return variables;
}

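// Gather is Scatter's inverse: it concatenates variables from several CUDA
// devices onto a single destination device along `dim`, and installs a
// Scatter as its grad_fn so gradients flow back to each source device.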
Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}

variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    AT_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.type());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    AT_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  // Compute the grad_fn *before* moving the variables out of `inputs`:
  // moved-from tensors are undefined, so querying their device or size
  // below would fail, and compute_requires_grad() would silently report
  // false and drop the autograd history.
  std::shared_ptr<Function> grad_fn;
  if (compute_requires_grad(inputs)) {
    std::vector<at::Device> source_devices;
    std::vector<int64_t> input_sizes;
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/c10::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }

  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      // Promote 0-dim scalars to 1-element vectors so that they can be
      // concatenated along dimension 0.
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // This is special logic for torch::cuda::gather: a destination index of
  // -1 tells it to gather onto the CPU.
  const auto destination_index =
      destination_device_.is_cpu() ? -1 : destination_device_.index();
  auto variable = torch::cuda::gather(tensors, dim_, destination_index);
  set_history(variable, grad_fn);
  return {variable};
}

} // namespace autograd
} // namespace torch
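
A minimal usage sketch (not part of the file above): it assumes a libtorch build with CUDA and at least two visible GPUs, and the tensor shape and device indices are illustrative only. It mirrors the pattern used by torch's data-parallel helpers, which construct these nodes on the stack and call apply() directly.

#include <torch/csrc/autograd/functions/comm.h>
#include <torch/torch.h>

#include <iostream>
#include <utility>

int main() {
  // Hypothetical example: split a CPU tensor into two chunks along dim 0,
  // one per CUDA device. Requires at least two CUDA devices.
  torch::autograd::Scatter scatter(
      {at::Device(at::kCUDA, 0), at::Device(at::kCUDA, 1)},
      /*chunk_sizes=*/c10::nullopt,
      /*dim=*/0,
      /*streams=*/c10::nullopt,
      /*unsqueeze_scalars=*/false);
  auto chunks = scatter.apply({torch::randn({4, 3}, torch::requires_grad())});

  // Reassemble the chunks on GPU 0; gradients flow back through both nodes.
  torch::autograd::Gather gather(at::Device(at::kCUDA, 0), /*dim=*/0);
  auto outputs = gather.apply(std::move(chunks));
  std::cout << outputs.front().sizes() << '\n'; // [4, 3]
}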