Caffe2 - C++ API
A deep learning, cross platform ML framework
dropout.h
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/pimpl.h>
5 #include <torch/types.h>
6 
7 #include <cstddef>
8 #include <vector>
9 
10 namespace torch {
11 namespace nn {
12 
14 struct TORCH_API DropoutOptions {
15  /* implicit */ DropoutOptions(double rate = 0.5);
19  TORCH_ARG(double, rate);
20 };
21 
22 namespace detail {
23 template <typename Derived>
24 class DropoutImplBase : public torch::nn::Cloneable<Derived> {
25  public:
26  explicit DropoutImplBase(DropoutOptions options_ = DropoutOptions());
27 
28  void reset() override;
29 
32 };
33 } // namespace detail
34 
39 class TORCH_API DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
40  public:
42 
45  Tensor forward(const Tensor& input);
46 
48  void pretty_print(std::ostream& stream) const override;
49 };
50 
60 class TORCH_API FeatureDropoutImpl
61  : public detail::DropoutImplBase<FeatureDropoutImpl> {
62  public:
64 
67  Tensor forward(const Tensor& input);
68 
70  void pretty_print(std::ostream& stream) const override;
71 };
72 
77 TORCH_MODULE(Dropout);
78 
83 TORCH_MODULE(FeatureDropout);
84 } // namespace nn
85 } // namespace torch
DropoutOptions options
The options used to configure this Dropout module.
Definition: dropout.h:31
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses.
Definition: cloneable.h:23
Definition: jit_type.h:17
Applies Dropout during training.
Definition: dropout.h:39
Options for Dropout and FeatureDropout.
Definition: dropout.h:14
Applies spatial Dropout to inputs with 2-D or 3-D features.
Definition: dropout.h:60