Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding.h
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/pimpl.h>
5 #include <torch/types.h>
6 
7 #include <cstddef>
8 #include <vector>
9 
10 namespace torch {
11 namespace nn {
12 
14 struct TORCH_API EmbeddingOptions {
15  EmbeddingOptions(int64_t count, int64_t dimension);
17  TORCH_ARG(int64_t, count);
19  TORCH_ARG(int64_t, dimension);
20 };
21 
23 class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
24  public:
25  EmbeddingImpl(int64_t count, int64_t dimension)
26  : EmbeddingImpl(EmbeddingOptions(count, dimension)) {}
27  explicit EmbeddingImpl(EmbeddingOptions options);
28 
29  void reset() override;
30 
32  void pretty_print(std::ostream& stream) const override;
33 
36  Tensor forward(const Tensor& indices);
37 
40  EmbeddingOptions options;
41 
43  Tensor weight;
44 };
45 
50 TORCH_MODULE(Embedding);
51 
52 } // namespace nn
53 } // namespace torch
EmbeddingOptions options
The Options used to configure this Embedding module.
Definition: embedding.h:40
Options for the Embedding module.
Definition: embedding.h:14
Performs a lookup in a fixed size embedding table.
Definition: embedding.h:23
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
Definition: cloneable.h:23
Definition: jit_type.h:17
Tensor weight
The embedding table.
Definition: embedding.h:43