Caffe2 - C++ API
A deep learning, cross-platform ML framework
init.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/types.h>
5 
6 namespace torch {
7 namespace nn {
8 namespace init {
9 
10 enum class Nonlinearity {
11  Linear,
12  Conv1D,
13  Conv2D,
14  Conv3D,
15  ConvTranspose1D,
16  ConvTranspose2D,
17  ConvTranspose3D,
18  Sigmoid,
19  Tanh,
20  ReLU,
21  LeakyReLU
22 };
23 
24 enum class FanMode { FanIn, FanOut };
25 
27 TORCH_API double calculate_gain(Nonlinearity nonlinearity, double param = 0.01);
28 
31 TORCH_API Tensor constant_(Tensor tensor, Scalar value);
32 
35 TORCH_API Tensor dirac_(Tensor tensor);
36 
39 TORCH_API Tensor eye_(Tensor matrix);
40 
44 TORCH_API Tensor normal_(Tensor tensor, double mean = 0, double std = 1);
45 
48 TORCH_API Tensor ones_(Tensor tensor);
49 
56 TORCH_API Tensor orthogonal_(Tensor tensor, double gain = 1.0);
57 
65 TORCH_API Tensor sparse_(Tensor tensor, double sparsity, double std = 0.01);
66 
70 TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
71 
77 TORCH_API Tensor kaiming_normal_(
78  Tensor tensor,
79  double a = 0,
80  FanMode mode = FanMode::FanIn,
81  Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
82 
88 TORCH_API Tensor kaiming_uniform_(
89  Tensor tensor,
90  double a = 0,
91  FanMode mode = FanMode::FanIn,
92  Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
93 
98 TORCH_API Tensor xavier_normal_(Tensor tensor, double gain = 1.0);
99 
105 TORCH_API Tensor xavier_uniform_(Tensor tensor, double gain = 1.0);
106 
109 TORCH_API Tensor zeros_(Tensor tensor);
110 
111 } // namespace init
112 } // namespace nn
113 } // namespace torch
Definition: jit_type.h:17