1 #include <torch/nn/modules/dropout.h> 3 #include <torch/types.h> 5 #include <c10/util/Exception.h> 14 template <
typename Derived>
15 DropoutImplBase<Derived>::DropoutImplBase(DropoutOptions options_)
17 AT_CHECK(
options.rate_ >= 0,
"Dropout rate must not be less than zero");
18 AT_CHECK(
options.rate_ <= 1,
"Dropout rate must not be greater than one");
21 template <
typename Derived>
28 DropoutOptions::DropoutOptions(
double rate) : rate_(rate) {}
31 return torch::dropout(input,
options.rate_, this->is_training());
35 stream <<
"torch::nn::Dropout(rate=" <<
options.rate_ <<
")";
39 return torch::feature_dropout(input,
options.rate_, this->is_training());
43 stream <<
"torch::nn::FeatureDropout(rate=" <<
options.rate_ <<
")";
DropoutOptions options
The options with which this Module was constructed.
void pretty_print(std::ostream &stream) const override
Pretty prints the Dropout module into the given stream.
Tensor forward(const Tensor &input)
During training, applies a noise mask to the input tensor.
void pretty_print(std::ostream &stream) const override
Pretty prints the FeatureDropout module into the given stream.
Tensor forward(const Tensor &input)
During training, applies a noise mask to the input tensor.