1 #include <torch/nn/modules/dropout.h>     3 #include <torch/types.h>     5 #include <c10/util/Exception.h>    14 template <
typename Derived>
    15 DropoutImplBase<Derived>::DropoutImplBase(DropoutOptions options_)
    17   AT_CHECK(
options.rate_ >= 0, 
"Dropout rate must not be less than zero");
    18   AT_CHECK(
options.rate_ <= 1, 
"Dropout rate must not be greater than one");
    21 template <
typename Derived>
    28 DropoutOptions::DropoutOptions(
double rate) : rate_(rate) {}
    31   return torch::dropout(input, 
options.rate_, this->is_training());
    35   stream << 
"torch::nn::Dropout(rate=" << 
options.rate_ << 
")";
    39   return torch::feature_dropout(input, 
options.rate_, this->is_training());
    43   stream << 
"torch::nn::FeatureDropout(rate=" << 
options.rate_ << 
")";
 
DropoutOptions options
The options with which this Module was constructed. 
 
void pretty_print(std::ostream &stream) const override
Pretty prints the Dropout module into the given stream. 
 
Tensor forward(const Tensor &input)
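During training (when `is_training()` is true), applies a dropout noise mask to the input tensor via `torch::dropout`, using `options.rate_` as the dropout rate. 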
 
void pretty_print(std::ostream &stream) const override
Pretty prints the FeatureDropout module into the given stream. 
 
Tensor forward(const Tensor &input)
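During training (when `is_training()` is true), applies a feature-dropout noise mask to the input tensor via `torch::feature_dropout`, using `options.rate_` as the dropout rate.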