Caffe2 - C++ API
A deep learning, cross-platform ML framework
linear.h
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/module.h>
5 #include <torch/nn/pimpl.h>
6 #include <torch/types.h>
7 
8 #include <cstddef>
9 #include <vector>
10 
11 namespace torch {
12 namespace nn {
14 struct TORCH_API LinearOptions {  // Options for the Linear module.
15  LinearOptions(int64_t in, int64_t out);  // Builds options from the two required sizes; definition not shown here.
17  TORCH_ARG(int64_t, in);  // Input size. NOTE(review): presumably the number of input features -- confirm.
19  TORCH_ARG(int64_t, out);  // Output size. NOTE(review): presumably the number of output features -- confirm.
21  TORCH_ARG(bool, with_bias) = true;  // Toggles the module's optional bias; defaults to true.
22 };
23 
24 /// Applies a linear transformation with optional bias.
25 class TORCH_API LinearImpl : public Cloneable<LinearImpl> {
26  public:
27  LinearImpl(int64_t in, int64_t out) : LinearImpl(LinearOptions(in, out)) {}  // Convenience: delegates via LinearOptions(in, out).
28  explicit LinearImpl(LinearOptions options);
29 
30  void reset() override;  // NOTE(review): presumably (re)initializes `weight` and `bias` -- confirm in linear.cpp.
31 
32  /// Writes a printable description of this module to `stream`.
33  void pretty_print(std::ostream& stream) const override;
34 
35  /// Applies this module's linear transformation to `input`. NOTE(review): presumably
36  /// computes input x weight^T, adding `bias` only when `with_bias` is set -- confirm.
37  Tensor forward(const Tensor& input);
38 
39  /// The options used to configure this module.
40  LinearOptions options;
41 
42  /// The learned weight.
43  Tensor weight;
44 
45  /// The learned bias.
46  /// NOTE(review): presumably left undefined when `options.with_bias()` is false -- confirm.
47  Tensor bias;
48 };
49 
50 /// Declares the `Linear` module through the TORCH_MODULE macro.
51 /// NOTE(review): presumably this generates the holder type that wraps
52 /// `LinearImpl` (macro likely defined in torch/nn/pimpl.h) -- confirm the
53 /// expansion against that header.
54 TORCH_MODULE(Linear);
55 
56 } // namespace nn
57 } // namespace torch
Options for the Linear module.
Definition: linear.h:14
Applies a linear transformation with optional bias.
Definition: linear.h:25
Tensor bias
The learned bias.
Definition: linear.h:47
Tensor weight
The learned weight.
Definition: linear.h:43
LinearOptions options
The options used to configure this module.
Definition: linear.h:40
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses.
Definition: cloneable.h:23
Definition: jit_type.h:17