1 #include <torch/nn/modules/conv.h> 3 #include <torch/expanding_array.h> 4 #include <torch/types.h> 5 #include <torch/utils.h> 15 template <
size_t D,
typename Derived>
16 ConvImpl<D, Derived>::ConvImpl(ConvOptions<D> options)
17 : options(
std::move(options)) {
21 template <
size_t D,
typename Derived>
24 for (
auto pad : *
options.output_padding_) {
26 pad == 0,
"Only transposed convolutions support output padding!");
30 std::vector<int64_t> weights_size;
32 weights_size.push_back(
options.input_channels_);
33 weights_size.push_back(
options.output_channels_ /
options.groups_);
35 weights_size.push_back(
options.output_channels_);
42 AT_ASSERT(weights_size.size() == 2 +
options.kernel_size_->size());
47 "bias", torch::empty(
options.output_channels_));
50 const auto number_of_features = std::accumulate(
54 std::multiplies<int64_t>{});
55 const auto stdv = 1.0 / std::sqrt(number_of_features);
58 p.uniform_(-stdv, stdv);
62 template <
size_t D,
typename Derived>
64 stream <<
"torch::nn::Conv" <<
D <<
"d" 65 <<
"(input_channels=" <<
options.input_channels_
66 <<
", output_channels=" <<
options.output_channels_
67 <<
", kernel_size=" <<
options.kernel_size_
68 <<
", stride=" <<
options.stride_ <<
")";
73 return torch::conv_transpose1d(
95 return torch::conv_transpose2d(
105 return torch::conv2d(
117 return torch::conv_transpose3d(
127 return torch::conv3d(
Tensor bias
The learned bias.
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and if recurse is true, also recursively of every submodule...
void pretty_print(std::ostream &stream) const override
Pretty prints the Conv{1,2,3}d module into the given stream.
BatchNormOptions options
The options with which this module was constructed.
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly paramet...
Tensor weight
The learned weight.
Options for a D-dimensional convolution module.
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly paramet...