1 #include <torch/nn/modules/linear.h> 3 #include <torch/types.h> 4 #include <torch/utils.h> 11 LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {}
13 LinearImpl::LinearImpl(LinearOptions
options) : options(options) {
24 const auto stdv = 1.0 / std::sqrt(
weight.size(1));
27 p.uniform_(-stdv, stdv);
32 stream << std::boolalpha <<
"torch::nn::Linear(in=" <<
options.in_
33 <<
", out=" <<
options.out_ <<
", with_bias=" <<
options.with_bias_
38 AT_ASSERT(!
options.with_bias_ || bias.defined());
39 return torch::linear(input,
weight, bias);
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
EmbeddingOptions options
The Options used to configure this Embedding module.
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and, if recurse is true, also recursively those of every submodule.
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
void pretty_print(std::ostream &stream) const override
Pretty prints the Linear module into the given stream.
Tensor weight
The embedding table.
Tensor forward(const Tensor &input)
Transforms the input tensor by multiplying with the weight and optionally adding the bias, when with_bias is enabled in the options.