Caffe2 - C++ API
A deep learning, cross-platform ML framework
functional.h
1 #pragma once
2 
3 #include <torch/csrc/utils/variadic.h>
4 #include <torch/nn/cloneable.h>
5 #include <torch/nn/pimpl.h>
6 #include <torch/types.h>
7 #include <torch/csrc/WindowsTorchApiMacro.h>
8 
9 #include <functional>
10 #include <utility>
11 
12 namespace torch {
13 namespace nn {
14 
58 class TORCH_API FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
59  public:
60  using Function = std::function<Tensor(Tensor)>;
61 
63  explicit FunctionalImpl(Function function);
64 
65  template <
66  typename SomeFunction,
67  typename... Args,
68  typename = torch::enable_if_t<(sizeof...(Args) > 0)>>
69  explicit FunctionalImpl(SomeFunction original_function, Args&&... args)
70  : function_(std::bind(
71  original_function,
72  /*input=*/std::placeholders::_1,
73  std::forward<Args>(args)...)) {
74  // std::bind is normally evil, but (1) gcc is broken w.r.t. handling
75  // parameter pack expansion in lambdas and (2) moving parameter packs into
76  // a lambda only works with C++14, so std::bind is the more move-aware
77  // solution here.
78  }
79 
80  void reset() override;
81 
83  void pretty_print(std::ostream& stream) const override;
84 
86  Tensor forward(Tensor input);
87 
89  Tensor operator()(Tensor input);
90 
91  private:
92  Function function_;
93 };
94 
99 TORCH_MODULE(Functional);
100 
101 } // namespace nn
102 } // namespace torch
The `clone()` method in the base `Module` class does not have knowledge of the concrete runtime type of its subclasses; `Cloneable<Derived>` uses the CRTP to provide a `clone()` that does.
Definition: cloneable.h:23
Definition: jit_type.h:17