Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding.cpp
1 #include <torch/nn/modules/embedding.h>
2 
3 #include <torch/types.h>
4 #include <torch/utils.h>
5 
6 #include <cstddef>
7 #include <ostream>
8 #include <utility>
9 #include <vector>
10 
11 namespace torch {
12 namespace nn {
13 
14 EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension)
15  : count_(count), dimension_(dimension) {}
16 
17 EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) : options(options) {
18  reset();
19 }
20 
// NOTE(review): this listing omits the definition line (embedding.cpp:21).
// The Doxygen cross-reference names it `void reset() override`; presumably it
// reads `void EmbeddingImpl::reset() {`. Confirm against the real source file.
// Registers `weight` as a {count_, dimension_} parameter, then fills it in place
// via normal_(0, 1).
22  weight = register_parameter(
23  "weight", torch::empty({options.count_, options.dimension_}));
// Scoped guard: keeps the in-place initialization below out of autograd.
24  NoGradGuard guard;
25  weight.normal_(0, 1);
26 }
27 
28 void EmbeddingImpl::pretty_print(std::ostream& stream) const {
29  stream << "torch::nn::Embedding(count=" << options.count_
30  << ", dimension=" << options.dimension_ << ")";
31 }
32 
// NOTE(review): this listing omits the definition line (embedding.cpp:33).
// The cross-reference shows `Tensor forward(const Tensor &indices)`, but this
// body refers to the argument as `input`, so the definition presumably reads
// `Tensor EmbeddingImpl::forward(const Tensor& input) {`. Confirm.
// Delegates the row lookup in `weight` to torch::embedding, using `input`
// as the indices.
34  return torch::embedding(weight, /*indices=*/input);
35 }
36 } // namespace nn
37 } // namespace torch
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
Definition: module.cpp:301
void pretty_print(std::ostream &stream) const override
Pretty prints the Embedding module into the given stream.
Definition: embedding.cpp:28
Tensor forward(const Tensor &indices)
Performs a lookup on the embedding table stored in weight using the indices supplied and returns the result.
Definition: embedding.cpp:33
DropoutOptions options
The options used to configure this Dropout module.
Definition: dropout.h:31
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
Definition: embedding.cpp:21
Definition: jit_type.h:17