Caffe2 - C++ API
A deep learning, cross platform ML framework
tensor.cpp
1 #include <torch/csrc/autograd/functions/tensor.h>
2 
3 #include <torch/csrc/autograd/function.h>
4 #include <torch/csrc/autograd/functions/basic_ops.h>
5 #include <torch/csrc/autograd/functions/utils.h>
6 #include <torch/csrc/autograd/generated/Functions.h>
7 #include <torch/csrc/autograd/variable.h>
8 
9 #include <ATen/ATen.h>
10 
11 #include <cstddef>
12 #include <memory>
13 #include <stdexcept>
14 #include <utility>
15 
16 namespace torch { namespace autograd {
17 
18 auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
19  check_input_variables("CopyBackwards", grads, 1);
20  auto& grad = grads[0];
21  variable_list grad_inputs(2);
22  if (should_compute_output(0)) {
23  grad_inputs[0] = at::zeros_like(grad);
24  }
25  if (should_compute_output(1)) {
26  at::DeviceGuard device_guard(src_device);
27  // TODO: What if !grad.is_cuda(), but src_device is CUDA?
28  // This code is kind of weirdly asymmetric.
29  if (grad.is_cuda() && grad.device() != src_device) {
30  grad_inputs[1] = src_type->copy(grad);
31  } else {
32  grad_inputs[1] = grad.toType(*src_type);
33  }
34  }
35  return grad_inputs;
36 }
37 
38 CopySlices::CopySlices(
39  const Variable& base_var,
40  at::TensorGeometry view_,
41  std::shared_ptr<Function> fn_)
42  : Function(),
43  base(base_var),
44  view(std::move(view_)),
45  fn(std::move(fn_)) {
46  // Take the next_edges of fn as our own, except for index 0 which goes
47  // to base instead of the view.
48  add_input_metadata(base_var);
49  const auto num_outputs = fn->num_outputs();
50  next_edges_.reserve(num_outputs);
51  add_next_edge(base_var.gradient_edge());
52  for (size_t i = 1; i < num_outputs; i++) {
53  add_next_edge(fn->next_edge(i));
54  }
55 }
56 
57 auto CopySlices::apply(variable_list&& inputs) -> variable_list {
58  check_input_variables("CopySlices", inputs, 1);
59  auto& grad = inputs[0];
60 
61  if (!fn) {
62  throw std::runtime_error(ERR_BACKWARD_TWICE);
63  }
64 
65  auto result = at::empty_strided(base.sizes(), base.strides(), grad.options());
66  result.copy_(grad);
67 
68  auto offset = view.storage_offset() - base.storage_offset();
69  auto grad_slice = result.as_strided(view.sizes(), view.strides(), offset);
70 
71  // TODO: We clone grad_slice because we modify it below and "fn" might save
72  // it for the backward of res. We might be able to avoid the clone() if
73  // double-backprop is disabled.
74  auto res = (*fn)({ grad_slice.clone() });
75 
76  variable_list grad_inputs(num_outputs());
77  for (size_t i = 0; i < res.size(); i++) {
78  if (should_compute_output(i)) {
79  AT_ASSERT(res[i].defined());
80  if (i == 0) {
81  grad_slice.copy_(res[i]);
82  grad_inputs[i] = std::move(result); // NOLINT(bugprone-use-after-move)
83  } else {
84  grad_inputs[i] = std::move(res[i]);
85  }
86  }
87  }
88 
89  return grad_inputs;
90 }
91 
92 void CopySlices::release_variables() {
93  fn = nullptr;
94 }
95 
96 }} // namespace torch::autograd
Definition: jit_type.h:17
RAII guard that sets a certain default device in its constructor, and changes it back to the device that was previously active upon destruction.
Definition: DeviceGuard.h:19