Caffe2 - C++ API
A deep learning, cross-platform ML framework
init.cpp
#include <torch/nn/init.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <tuple>
namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {
  explicit Fan(Tensor& tensor) {
    const auto dimensions = tensor.ndimension();
    AT_CHECK(
        dimensions >= 2,
        "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");

    if (dimensions == 2) {
      in = tensor.size(1);
      out = tensor.size(0);
    } else {
      in = tensor.size(1) * tensor[0][0].numel();
      out = tensor.size(0) * tensor[0][0].numel();
    }
  }

  int64_t in;
  int64_t out;
};

double calculate_kaiming_std(
    Tensor tensor,
    double a,
    FanMode mode,
    Nonlinearity nonlinearity) {
  NoGradGuard guard;
  Fan fan(tensor);
  const auto gain = calculate_gain(nonlinearity, a);
  double std = 0.0;
  if (mode == FanMode::FanIn) {
    std = gain / std::sqrt(fan.in);
  } else {
    std = gain / std::sqrt(fan.out);
  }
  return std;
}
} // namespace

double calculate_gain(Nonlinearity nonlinearity, double param) {
  if (nonlinearity == Nonlinearity::Tanh) {
    return 5.0 / 3.0;
  } else if (nonlinearity == Nonlinearity::ReLU) {
    return std::sqrt(2.0);
  } else if (nonlinearity == Nonlinearity::LeakyReLU) {
    return std::sqrt(2.0 / (1 + pow(param, 2)));
  }

  return 1.0;
}

Tensor constant_(Tensor tensor, Scalar value) {
  NoGradGuard guard;
  return tensor.fill_(value);
}

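// dirac_ writes a Dirac delta into each of the first min(sizes[0], sizes[1])
// channels: the center element of the kernel at position [d][d] is set to 1,
// so a convolution initialized this way passes its inputs through wherever
// the channel counts allow it.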
Tensor dirac_(Tensor tensor) {
  NoGradGuard guard;

  AT_CHECK(
      tensor.ndimension() >= 3 && tensor.ndimension() <= 5,
      "Only tensors with 3, 4, or 5 dimensions are supported");

  const auto sizes = tensor.sizes();
  const auto min_dim = std::min(sizes[0], sizes[1]);

  tensor.zero_();
  for (int64_t d = 0; d < min_dim; ++d) {
    switch (tensor.ndimension()) {
      case 3: // Temporal convolution
        tensor[d][d][sizes[2] / 2] = 1;
        break;
      case 4: // Spatial convolution
        tensor[d][d][sizes[2] / 2][sizes[3] / 2] = 1;
        break;
      case 5: // Volumetric convolution
        tensor[d][d][sizes[2] / 2][sizes[3] / 2][sizes[4] / 2] = 1;
        break;
    }
  }

  return tensor;
}

Tensor eye_(Tensor matrix) {
  NoGradGuard guard;
  AT_CHECK(
      matrix.ndimension() == 2, "Only tensors with 2 dimensions are supported");
  return torch::eye_out(matrix, matrix.size(0), matrix.size(1));
}

Tensor normal_(Tensor tensor, double mean, double std) {
  NoGradGuard guard;
  return tensor.normal_(mean, std);
}

Tensor ones_(Tensor tensor) {
  NoGradGuard guard;
  return tensor.fill_(1);
}

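// orthogonal_ fills the tensor with a (semi-)orthogonal matrix obtained from
// the QR decomposition of a random Gaussian matrix, then scales it by `gain`.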
Tensor orthogonal_(Tensor tensor, double gain) {
  NoGradGuard guard;

  AT_CHECK(
      tensor.ndimension() >= 2,
      "Only tensors with 2 or more dimensions are supported");

  const auto rows = tensor.size(0);
  // Flatten all trailing dimensions into the columns so that tensors with
  // more than two dimensions can be filled, as the check above allows.
  const auto columns = tensor.numel() / rows;
  auto flattened = torch::randn({rows, columns});

  if (rows < columns) {
    flattened.t_();
  }

  // Compute the qr factorization
  Tensor q, r;
  std::tie(q, r) = torch::qr(flattened);
  // Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  auto d = torch::diag(r, 0);
  auto ph = d.sign();
  q *= ph;

  if (rows < columns) {
    q.t_();
  }

  tensor.view_as(q).copy_(q);
  tensor.mul_(gain);

  return tensor;
}

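// sparse_ draws every entry from a normal distribution with the given
// standard deviation and then zeroes out ceil(sparsity * rows) randomly
// chosen entries in each column.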
Tensor sparse_(Tensor tensor, double sparsity, double std) {
  NoGradGuard guard;

  AT_CHECK(
      tensor.ndimension() == 2, "Only tensors with 2 dimensions are supported");

  const auto rows = tensor.size(0);
  const auto columns = tensor.size(1);
  const int64_t num_zeros = std::ceil(sparsity * rows);
  tensor.normal_(0, std);
  for (int64_t column = 0; column < columns; ++column) {
    auto row_indices = torch::randperm(rows, tensor.options().dtype(kLong));
    auto zero_indices =
        row_indices.slice(/*dim=*/0, /*start=*/0, /*end=*/num_zeros);
    tensor.index_put_(
        {zero_indices, torch::tensor(column, tensor.options().dtype(kLong))},
        torch::zeros(num_zeros, tensor.options()));
  }

  return tensor;
}

Tensor uniform_(Tensor tensor, double low, double high) {
  NoGradGuard guard;
  return tensor.uniform_(low, high);
}

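// Kaiming (He) initialization: the standard deviation is gain / sqrt(fan),
// as computed by calculate_kaiming_std above. The uniform variant samples
// from U(-bound, bound) with bound = sqrt(3) * std, which has the same
// standard deviation as the normal variant.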
Tensor kaiming_uniform_(
    Tensor tensor,
    double a,
    FanMode mode,
    Nonlinearity nonlinearity) {
  NoGradGuard guard;
  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  // Calculate uniform bounds from standard deviation
  const auto bound = std::sqrt(3.0) * std;
  return tensor.uniform_(-bound, bound);
}

Tensor kaiming_normal_(
    Tensor tensor,
    double a,
    FanMode mode,
    Nonlinearity nonlinearity) {
  NoGradGuard guard;

  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  return tensor.normal_(0, std);
}

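// Xavier (Glorot) initialization: the standard deviation is
// gain * sqrt(2 / (fan_in + fan_out)); the uniform variant again uses the
// bound sqrt(3) * std.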
Tensor xavier_normal_(Tensor tensor, double gain) {
  NoGradGuard guard;

  Fan fan(tensor);
  const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
  return tensor.normal_(0, std);
}

Tensor xavier_uniform_(Tensor tensor, double gain) {
  NoGradGuard guard;
  Fan fan(tensor);
  const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
  // Calculate uniform bounds from standard deviation
  const auto a = std::sqrt(3.0) * std;
  return tensor.uniform_(-a, a);
}

Tensor zeros_(Tensor tensor) {
  NoGradGuard guard;
  return tensor.zero_();
}

} // namespace init
} // namespace nn
} // namespace torch
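For reference, a minimal usage sketch (not part of init.cpp) showing how these initializers can be called from user code. It assumes the FanMode and Nonlinearity enums and the function signatures declared in torch/nn/init.h, as used above; the tensor shapes are arbitrary examples.

#include <torch/nn/init.h>
#include <torch/torch.h>

int main() {
  // A 2-D weight matrix; Fan requires at least two dimensions.
  auto weight = torch::empty({128, 64});

  // Kaiming normal: std = sqrt(2) / sqrt(fan_in) for ReLU.
  torch::nn::init::kaiming_normal_(
      weight,
      /*a=*/0.0,
      torch::nn::init::FanMode::FanIn,
      torch::nn::init::Nonlinearity::ReLU);

  // Xavier uniform: bound = gain * sqrt(6 / (fan_in + fan_out)).
  auto other = torch::empty({64, 64});
  torch::nn::init::xavier_uniform_(other, /*gain=*/1.0);

  // Simple fills for a bias vector.
  auto bias = torch::empty({128});
  torch::nn::init::zeros_(bias);
  return 0;
}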