Caffe2 - C++ API
A deep learning, cross platform ML framework
linear.cpp
#include <torch/nn/modules/linear.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <cmath>
#include <cstdint>
#include <ostream>

9 namespace torch {
10 namespace nn {
11 LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {}
12 
13 LinearImpl::LinearImpl(LinearOptions options) : options(options) {
14  reset();
15 }
16 
18  weight =
19  register_parameter("weight", torch::empty({options.out_, options.in_}));
20  if (options.with_bias_) {
21  bias = register_parameter("bias", torch::empty(options.out_));
22  }
23 
24  const auto stdv = 1.0 / std::sqrt(weight.size(1));
25  NoGradGuard no_grad;
26  for (auto& p : this->parameters()) {
27  p.uniform_(-stdv, stdv);
28  }
29 }
30 
31 void LinearImpl::pretty_print(std::ostream& stream) const {
32  stream << std::boolalpha << "torch::nn::Linear(in=" << options.in_
33  << ", out=" << options.out_ << ", with_bias=" << options.with_bias_
34  << ")";
35 }
36 
38  AT_ASSERT(!options.with_bias_ || bias.defined());
39  return torch::linear(input, weight, bias);
40 }
41 } // namespace nn
42 } // namespace torch
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Definition: linear.cpp:17
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
Definition: module.cpp:301
EmbeddingOptions options
The Options used to configure this Embedding module.
Definition: embedding.h:40
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and, if recurse is true, also recursively those of every submodule.
Definition: module.cpp:143
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Definition: embedding.cpp:21
Definition: jit_type.h:17
void pretty_print(std::ostream &stream) const override
Pretty prints the Linear module into the given stream.
Definition: linear.cpp:31
Tensor weight
The embedding table.
Definition: embedding.h:43
Tensor forward(const Tensor &input)
Transforms the input tensor by multiplying with the weight and optionally adding the bias (when with_bias is set in the options).
Definition: linear.cpp:37