#include <torch/nn/init.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <c10/util/Exception.h>

#include <algorithm>
#include <cmath>
#include <tuple>

namespace torch {
namespace nn {
namespace init {
namespace {

// Computes the fan-in and fan-out of a weight tensor; used by the Kaiming and
// Xavier initializations below.
struct Fan {
  explicit Fan(Tensor& tensor) {
    const auto dimensions = tensor.ndimension();
    TORCH_CHECK(
        dimensions >= 2,
        "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");

    if (dimensions == 2) {
      in = tensor.size(1);
      out = tensor.size(0);
    } else {
      in = tensor.size(1) * tensor[0][0].numel();
      out = tensor.size(0) * tensor[0][0].numel();
    }
  }

  int64_t in;
  int64_t out;
};
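// Illustrative note (not part of the original source): for a conv weight of
// shape {64, 32, 3, 3}, tensor[0][0].numel() == 9, so Fan reports
// fan.in == 32 * 9 == 288 and fan.out == 64 * 9 == 576; for a 2-D linear
// weight {out_features, in_features} it is simply fan.in == in_features and
// fan.out == out_features.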
double calculate_kaiming_std(
    Tensor tensor,
    double a,
    FanMode mode,
    Nonlinearity nonlinearity) {
  NoGradGuard guard;
  Fan fan(tensor);
  const auto gain = calculate_gain(nonlinearity, a);
  double std = 0.0;

  if (mode == FanMode::FanIn) {
    std = gain / std::sqrt(fan.in);
  } else {
    std = gain / std::sqrt(fan.out);
  }
  return std;
}
} // namespace
double calculate_gain(Nonlinearity nonlinearity, double param) {
  if (nonlinearity == Nonlinearity::Tanh) {
    return 5.0 / 3.0;
  } else if (nonlinearity == Nonlinearity::ReLU) {
    return std::sqrt(2.0);
  } else if (nonlinearity == Nonlinearity::LeakyReLU) {
    return std::sqrt(2.0 / (1 + std::pow(param, 2)));
  }

  return 1.0;
}
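// Hedged usage sketch (not part of the original source), assuming the public
// torch::nn::init API declared in the header included above:
//
//   const auto g = torch::nn::init::calculate_gain(
//       torch::nn::init::Nonlinearity::LeakyReLU, /*param=*/0.2);
//   // g == sqrt(2 / (1 + 0.2^2)) ~= 1.3868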
Tensor constant_(Tensor tensor, Scalar value) {
  NoGradGuard guard;
  return tensor.fill_(value);
}
Tensor dirac_(Tensor tensor) {
  NoGradGuard guard;

  TORCH_CHECK(
      tensor.ndimension() >= 3 && tensor.ndimension() <= 5,
      "Only tensors with 3, 4, or 5 dimensions are supported");

  const auto sizes = tensor.sizes();
  const auto min_dim = std::min(sizes[0], sizes[1]);

  tensor.zero_();
  for (int64_t d = 0; d < min_dim; ++d) {
    switch (tensor.ndimension()) {
      case 3: // Temporal convolution
        tensor[d][d][sizes[2] / 2] = 1;
        break;
      case 4: // Spatial convolution
        tensor[d][d][sizes[2] / 2][sizes[3] / 2] = 1;
        break;
      case 5: // Volumetric convolution
        tensor[d][d][sizes[2] / 2][sizes[3] / 2][sizes[4] / 2] = 1;
        break;
    }
  }

  return tensor;
}
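// Illustrative note (not part of the original source): for a {16, 16, 3, 3}
// spatial convolution weight, dirac_ sets tensor[d][d][1][1] = 1 for
// d = 0..15, so each output channel initially passes its matching input
// channel through unchanged (an identity-like convolution).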
Tensor eye_(Tensor matrix) {
  NoGradGuard guard;
  TORCH_CHECK(
      matrix.ndimension() == 2, "Only tensors with 2 dimensions are supported");
  return torch::eye_out(matrix, matrix.size(0), matrix.size(1));
}
Tensor normal_(Tensor tensor, double mean, double std) {
  NoGradGuard guard;
  return tensor.normal_(mean, std);
}
Tensor ones_(Tensor tensor) {
  NoGradGuard guard;
  return tensor.fill_(1);
}
Tensor orthogonal_(Tensor tensor, double gain) {
  NoGradGuard guard;

  TORCH_CHECK(
      tensor.ndimension() >= 2,
      "Only tensors with 2 or more dimensions are supported");

  // Flatten any trailing dimensions into the column dimension.
  const auto rows = tensor.size(0);
  const auto columns = tensor.numel() / rows;
  auto flattened = torch::randn({rows, columns});

  if (rows < columns) {
    flattened.t_();
  }

  // Compute the QR factorization.
  Tensor q, r;
  std::tie(q, r) = torch::qr(flattened);
  // Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  auto d = torch::diag(r, 0);
  auto ph = d.sign();
  q *= ph;

  if (rows < columns) {
    q.t_();
  }

  tensor.view_as(q).copy_(q);
  tensor.mul_(gain);

  return tensor;
}
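// Hedged usage sketch (not part of the original source): a typical use is
// initializing a square recurrent weight so repeated multiplication neither
// explodes nor vanishes; the shape and gain below are arbitrary:
//
//   auto w_hh = torch::empty({256, 256});
//   torch::nn::init::orthogonal_(w_hh, /*gain=*/1.0);
//   // w_hh.t().matmul(w_hh) is now approximately the identity matrix.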
Tensor sparse_(Tensor tensor, double sparsity, double std) {
  NoGradGuard guard;

  TORCH_CHECK(
      tensor.ndimension() == 2, "Only tensors with 2 dimensions are supported");

  const auto rows = tensor.size(0);
  const auto columns = tensor.size(1);
  const int64_t num_zeros = std::ceil(sparsity * rows);
  tensor.normal_(0, std);
  for (int64_t column = 0; column < columns; ++column) {
    auto row_indices = torch::randperm(rows, tensor.options().dtype(kLong));
    auto zero_indices =
        row_indices.slice(/*dim=*/0, /*start=*/0, /*end=*/num_zeros);
    tensor.index_put_(
        {zero_indices, torch::tensor(column, tensor.options().dtype(kLong))},
        torch::zeros(num_zeros, tensor.options()));
  }

  return tensor;
}
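// Illustrative note (not part of the original source): with a {100, 50}
// weight and sparsity == 0.1, num_zeros == ceil(0.1 * 100) == 10, so each of
// the 50 columns keeps 90 normally-distributed entries and has 10 randomly
// chosen rows zeroed out.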
Tensor uniform_(Tensor tensor, double low, double high) {
  NoGradGuard guard;
  return tensor.uniform_(low, high);
}
Tensor kaiming_uniform_(
    Tensor tensor,
    double a,
    FanMode mode,
    Nonlinearity nonlinearity) {
  NoGradGuard guard;
  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  // Calculate uniform bounds from standard deviation.
  const auto bound = std::sqrt(3.0) * std;
  return tensor.uniform_(-bound, bound);
}
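// Illustrative note (not part of the original source): for a ReLU layer with
// fan_in == 512 and mode == FanMode::FanIn, gain == sqrt(2), so
// std == sqrt(2) / sqrt(512) == 1/16 == 0.0625 and the uniform bound is
// sqrt(3) * 0.0625 ~= 0.108, i.e. weights are drawn from U(-0.108, 0.108).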
Tensor kaiming_normal_(
    Tensor tensor,
    double a,
    FanMode mode,
    Nonlinearity nonlinearity) {
  NoGradGuard guard;
  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  return tensor.normal_(0, std);
}
Tensor xavier_normal_(Tensor tensor, double gain) {
  NoGradGuard guard;
  Fan fan(tensor);
  const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
  return tensor.normal_(0, std);
}
Tensor xavier_uniform_(Tensor tensor, double gain) {
  NoGradGuard guard;
  Fan fan(tensor);
  const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
  // Calculate uniform bounds from standard deviation.
  const auto a = std::sqrt(3.0) * std;
  return tensor.uniform_(-a, a);
}
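// Illustrative note (not part of the original source): for a linear weight
// with fan_in == 300, fan_out == 100 and gain == 1, std == sqrt(2/400)
// ~= 0.0707, so xavier_uniform_ samples roughly from U(-0.1225, 0.1225)
// while xavier_normal_ samples from N(0, 0.0707^2).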
Tensor zeros_(Tensor tensor) {
  NoGradGuard guard;
  return tensor.zero_();
}

} // namespace init
} // namespace nn
} // namespace torch
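// Hedged usage sketch (not part of the original source): how these
// initializers might be applied to a module's parameters, assuming the usual
// torch::nn::Linear API; the layer sizes are arbitrary:
//
//   torch::nn::Linear linear(512, 256);
//   torch::nn::init::kaiming_uniform_(
//       linear->weight,
//       /*a=*/0.0,
//       torch::nn::init::FanMode::FanIn,
//       torch::nn::init::Nonlinearity::ReLU);
//   torch::nn::init::zeros_(linear->bias);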