Caffe2 - C++ API
A deep-learning, cross-platform ML framework
conv.h
1 #pragma once
2 
3 #include <torch/expanding_array.h>
4 #include <torch/nn/cloneable.h>
5 #include <torch/nn/pimpl.h>
6 #include <torch/types.h>
7 
8 #include <cstddef>
9 #include <vector>
10 
11 namespace torch {
12 namespace nn {
13 
15 template <size_t D>
16 struct ConvOptions {
18  int64_t input_channels,
19  int64_t output_channels,
20  ExpandingArray<D> kernel_size) :
21  input_channels_(input_channels),
22  output_channels_(output_channels),
23  kernel_size_(std::move(kernel_size)) {}
24 
27  TORCH_ARG(int64_t, input_channels);
28 
31  TORCH_ARG(int64_t, output_channels);
32 
37  TORCH_ARG(ExpandingArray<D>, kernel_size);
38 
43  TORCH_ARG(ExpandingArray<D>, stride) = 1;
44 
49  TORCH_ARG(ExpandingArray<D>, padding) = 0;
50 
55  TORCH_ARG(ExpandingArray<D>, dilation) = 1;
56 
61  TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
62 
66  TORCH_ARG(bool, transposed) = false;
67 
70  TORCH_ARG(bool, with_bias) = true;
71 
74  TORCH_ARG(int64_t, groups) = 1;
75 };
76 
78 template <size_t D, typename Derived>
79 class ConvImpl : public torch::nn::Cloneable<Derived> {
80  public:
81  ConvImpl(
82  int64_t input_channels,
83  int64_t output_channels,
84  ExpandingArray<D> kernel_size)
85  : ConvImpl(ConvOptions<D>(input_channels, output_channels, kernel_size)) {
86  }
87  explicit ConvImpl(ConvOptions<D> options);
88 
89  void reset() override;
90 
92  void pretty_print(std::ostream& stream) const override;
93 
96 
99 
102 };
103 
104 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
105 
109 class TORCH_API Conv1dImpl : public ConvImpl<1, Conv1dImpl> {
110  public:
112  Tensor forward(const Tensor& input);
113 };
114 
117 
122 TORCH_MODULE(Conv1d);
123 
124 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
125 
129 class TORCH_API Conv2dImpl : public ConvImpl<2, Conv2dImpl> {
130  public:
132  Tensor forward(const Tensor& input);
133 };
134 
137 
142 TORCH_MODULE(Conv2d);
143 
144 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
145 
149 class TORCH_API Conv3dImpl : public ConvImpl<3, Conv3dImpl> {
150  public:
152  Tensor forward(const Tensor& input);
153 };
154 
157 
162 TORCH_MODULE(Conv3d);
163 
164 } // namespace nn
165 } // namespace torch
ConvOptions< D > options
The options with which this Module was constructed.
Definition: conv.h:95
Tensor weight
The learned kernel (or "weight").
Definition: conv.h:98
A utility class that accepts either a container of D-many values, or a single value, which is internally repeated D times.
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
Definition: cloneable.h:23
Applies convolution over a 1-D input.
Definition: conv.h:109
Definition: jit_type.h:17
Applies convolution over a 3-D input.
Definition: conv.h:149
Tensor bias
The learned bias. Only defined if the with_bias option was true.
Definition: conv.h:101
Applies convolution over a 2-D input.
Definition: conv.h:129
Base class for all (dimension-specialized) convolution modules.
Definition: conv.h:79
TORCH_ARG(int64_t, input_channels)
The number of channels the input volumes will have.
Options for a D-dimensional convolution module.
Definition: conv.h:16