Caffe2 - C++ API
A deep learning, cross platform ML framework
dropout.cpp
1 #include <torch/nn/modules/dropout.h>
2 
3 #include <torch/types.h>
4 
5 #include <c10/util/Exception.h>
6 
7 #include <cstddef>
8 #include <ostream>
9 #include <vector>
10 
11 namespace torch {
12 namespace nn {
13 namespace detail {
14 template <typename Derived>
15 DropoutImplBase<Derived>::DropoutImplBase(DropoutOptions options_)
16  : options(options_) {
17  AT_CHECK(options.rate_ >= 0, "Dropout rate must not be less than zero");
18  AT_CHECK(options.rate_ <= 1, "Dropout rate must not be greater than one");
19 }
20 
21 template <typename Derived>
23 
24 template class DropoutImplBase<DropoutImpl>;
26 } // namespace detail
27 
28 DropoutOptions::DropoutOptions(double rate) : rate_(rate) {}
29 
31  return torch::dropout(input, options.rate_, this->is_training());
32 }
33 
34 void DropoutImpl::pretty_print(std::ostream& stream) const {
35  stream << "torch::nn::Dropout(rate=" << options.rate_ << ")";
36 }
37 
39  return torch::feature_dropout(input, options.rate_, this->is_training());
40 }
41 
42 void FeatureDropoutImpl::pretty_print(std::ostream& stream) const {
43  stream << "torch::nn::FeatureDropout(rate=" << options.rate_ << ")";
44 }
45 } // namespace nn
46 } // namespace torch
ConvOptions< D > options
The options with which this Module was constructed.
Definition: conv.h:95
void pretty_print(std::ostream &stream) const override
Pretty prints the Dropout module into the given stream.
Definition: dropout.cpp:34
Tensor forward(const Tensor &input)
During training, applies a noise mask to the input tensor.
Definition: dropout.cpp:30
Definition: jit_type.h:17
void pretty_print(std::ostream &stream) const override
Pretty prints the FeatureDropout module into the given stream.
Definition: dropout.cpp:42
Tensor forward(const Tensor &input)
During training, applies a noise mask to the input tensor that drops entire feature maps (channels) rather than individual elements.
Definition: dropout.cpp:38