Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv.cpp
#include <torch/nn/modules/conv.h>

#include <torch/expanding_array.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <cmath>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>
12 
13 namespace torch {
14 namespace nn {
15 template <size_t D, typename Derived>
16 ConvImpl<D, Derived>::ConvImpl(ConvOptions<D> options)
17  : options(std::move(options)) {
18  reset();
19 }
20 
21 template <size_t D, typename Derived>
23  if (!options.transposed_) {
24  for (auto pad : *options.output_padding_) {
25  AT_CHECK(
26  pad == 0, "Only transposed convolutions support output padding!");
27  }
28  }
29 
30  std::vector<int64_t> weights_size;
31  if (options.transposed_) {
32  weights_size.push_back(options.input_channels_);
33  weights_size.push_back(options.output_channels_ / options.groups_);
34  } else {
35  weights_size.push_back(options.output_channels_);
36  weights_size.push_back(options.input_channels_ / options.groups_);
37  }
38  weights_size.insert(
39  weights_size.end(),
40  options.kernel_size_->begin(),
41  options.kernel_size_->end());
42  AT_ASSERT(weights_size.size() == 2 + options.kernel_size_->size());
43 
44  weight = this->register_parameter("weight", torch::empty(weights_size));
45  if (options.with_bias_) {
46  bias = this->register_parameter(
47  "bias", torch::empty(options.output_channels_));
48  }
49 
50  const auto number_of_features = std::accumulate(
51  options.kernel_size_->begin(),
52  options.kernel_size_->end(),
53  options.input_channels_,
54  std::multiplies<int64_t>{});
55  const auto stdv = 1.0 / std::sqrt(number_of_features);
56  NoGradGuard no_grad;
57  for (auto& p : this->parameters()) {
58  p.uniform_(-stdv, stdv);
59  }
60 }
61 
62 template <size_t D, typename Derived>
63 void ConvImpl<D, Derived>::pretty_print(std::ostream& stream) const {
64  stream << "torch::nn::Conv" << D << "d"
65  << "(input_channels=" << options.input_channels_
66  << ", output_channels=" << options.output_channels_
67  << ", kernel_size=" << options.kernel_size_
68  << ", stride=" << options.stride_ << ")";
69 }
70 
71 Tensor Conv1dImpl::forward(const Tensor& input) {
72  if (options.transposed_) {
73  return torch::conv_transpose1d(
74  input,
75  weight,
76  bias,
77  options.stride_,
78  options.padding_,
79  options.output_padding_,
80  options.groups_,
81  options.dilation_);
82  }
83  return torch::conv1d(
84  input,
85  weight,
86  bias,
87  options.stride_,
88  options.padding_,
89  options.dilation_,
90  options.groups_);
91 }
92 
93 Tensor Conv2dImpl::forward(const Tensor& input) {
94  if (options.transposed_) {
95  return torch::conv_transpose2d(
96  input,
97  weight,
98  bias,
99  options.stride_,
100  options.padding_,
101  options.output_padding_,
102  options.groups_,
103  options.dilation_);
104  }
105  return torch::conv2d(
106  input,
107  weight,
108  bias,
109  options.stride_,
110  options.padding_,
111  options.dilation_,
112  options.groups_);
113 }
114 
115 Tensor Conv3dImpl::forward(const Tensor& input) {
116  if (options.transposed_) {
117  return torch::conv_transpose3d(
118  input,
119  weight,
120  bias,
121  options.stride_,
122  options.padding_,
123  options.output_padding_,
124  options.groups_,
125  options.dilation_);
126  } else {
127  return torch::conv3d(
128  input,
129  weight,
130  bias,
131  options.stride_,
132  options.padding_,
133  options.dilation_,
134  options.groups_);
135  }
136 }
137 
138 template struct ConvOptions<1>;
139 template class ConvImpl<1, Conv1dImpl>;
140 
141 template struct ConvOptions<2>;
142 template class ConvImpl<2, Conv2dImpl>;
143 
144 template struct ConvOptions<3>;
145 template class ConvImpl<3, Conv3dImpl>;
146 
147 } // namespace nn
148 } // namespace torch
Tensor bias
The learned bias.
Definition: batchnorm.h:83
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
Definition: module.cpp:301
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and, if recurse is true, also recursively those of every submodule.
Definition: module.cpp:143
void pretty_print(std::ostream &stream) const override
Pretty prints the Conv{1,2,3}d module into the given stream.
Definition: conv.cpp:63
BatchNormOptions options
The options with which this module was constructed.
Definition: batchnorm.h:75
Definition: jit_type.h:17
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Definition: batchnorm.cpp:21
Definition: static.cpp:70
Tensor weight
The learned weight.
Definition: batchnorm.h:79
Options for a D-dimensional convolution module.
Definition: conv.h:16
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Definition: conv.cpp:22