3 #include <torch/csrc/WindowsTorchApiMacro.h> 4 #include <torch/types.h> 10 enum class Nonlinearity {
/// Fan-mode selector used by the Kaiming initializers declared in this header
/// (their `mode` parameter defaults to `FanMode::FanIn`).
///
/// NOTE(review): by name, `FanIn` presumably scales by the number of input
/// units and `FanOut` by the number of output units — confirm against the
/// implementation. Enumerator order (FanIn = 0, FanOut = 1) is part of the ABI;
/// do not reorder.
enum class FanMode { FanIn, FanOut };
27 TORCH_API
double calculate_gain(Nonlinearity nonlinearity,
double param = 0.01);
44 TORCH_API
Tensor normal_(
Tensor tensor,
double mean = 0,
double std = 1);
56 TORCH_API
Tensor orthogonal_(
Tensor tensor,
double gain = 1.0);
65 TORCH_API
Tensor sparse_(
Tensor tensor,
double sparsity,
double std = 0.01);
70 TORCH_API
Tensor uniform_(
Tensor tensor,
double low = 0,
double high = 1);
/// Kaiming-style (normal) initializer declaration; returns a `Tensor`.
///
/// NOTE(review): this declaration appears damaged in this copy of the file.
/// Every other initializer here takes `Tensor tensor` as its first parameter,
/// but nothing appears between `kaiming_normal_(` and `mode`. The leading
/// parameter(s) look like they were lost, and the numeric tokens before
/// `TORCH_API`, `FanMode` and `Nonlinearity` look like fused line numbers.
/// Restore this declaration from the canonical header rather than relying on it
/// as written.
///
/// \param mode Fan mode to use; defaults to `FanMode::FanIn`.
/// \param nonlinearity Nonlinearity the initialization targets; defaults to
///   `Nonlinearity::LeakyReLU`. NOTE(review): presumably forwarded to
///   `calculate_gain` — confirm.
77 TORCH_API
Tensor kaiming_normal_(
80 FanMode mode = FanMode::FanIn,
81 Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
/// Kaiming-style (uniform) initializer declaration; returns a `Tensor`.
///
/// NOTE(review): this declaration appears damaged in this copy of the file.
/// Every other initializer here takes `Tensor tensor` as its first parameter,
/// but nothing appears between `kaiming_uniform_(` and `mode`. The leading
/// parameter(s) look like they were lost, and the numeric tokens before
/// `TORCH_API`, `FanMode` and `Nonlinearity` look like fused line numbers.
/// Restore this declaration from the canonical header rather than relying on it
/// as written.
///
/// \param mode Fan mode to use; defaults to `FanMode::FanIn`.
/// \param nonlinearity Nonlinearity the initialization targets; defaults to
///   `Nonlinearity::LeakyReLU`. NOTE(review): presumably forwarded to
///   `calculate_gain` — confirm.
88 TORCH_API
Tensor kaiming_uniform_(
91 FanMode mode = FanMode::FanIn,
92 Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
98 TORCH_API
Tensor xavier_normal_(
Tensor tensor,
double gain = 1.0);
105 TORCH_API
Tensor xavier_uniform_(
Tensor tensor,
double gain = 1.0);