1 #include <torch/nn/modules/embedding.h> 3 #include <torch/types.h> 4 #include <torch/utils.h> 14 EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension)
15 : count_(count), dimension_(dimension) {}
17 EmbeddingImpl::EmbeddingImpl(EmbeddingOptions
options) : options(options) {
29 stream <<
"torch::nn::Embedding(count=" <<
options.count_
30 <<
", dimension=" <<
options.dimension_ <<
")";
34 return torch::embedding(weight, input);
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
void pretty_print(std::ostream &stream) const override
Pretty prints the Embedding module into the given stream.
Tensor forward(const Tensor &indices)
Performs a lookup on the embedding table stored in weight using the indices supplied and returns the result.
EmbeddingOptions options
The options used to configure this Embedding module.
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.