#include <gtest/gtest.h>

#include <torch/nn/module.h>
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/embedding.h>
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

class TestModel : public torch::nn::Module {
 public:
  TestModel()
      : l1(register_module("l1", Linear(10, 3))),
        l2(register_module("l2", Linear(3, 5))),
        l3(register_module("l3", Linear(5, 100))) {}

  Linear l1, l2, l3;
};
class NestedModel : public torch::nn::Module {
 public:
  NestedModel()
      : param_(register_parameter("param", torch::empty({3, 2, 21}))),
        l1(register_module("l1", Linear(5, 20))),
        t(register_module("test", std::make_shared<TestModel>())) {}

  torch::Tensor param_;
  Linear l1;
  std::shared_ptr<TestModel> t;
};

struct ModulesTest : torch::test::SeedingFixture {};
TEST_F(ModulesTest, Conv1d) {
  Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
  auto x = torch::randn({2, 3, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(s.ndimension(), 0);
  for (auto i = 0; i < 3; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
}
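// In each Conv test, every output dimension is expected to be 2: the batch is
// 2, out_channels is 2, and each spatial extent follows the usual
// floor((L_in - kernel) / stride) + 1 rule, e.g. (5 - 3) / 2 + 1 = 2 here.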
TEST_F(ModulesTest, Conv2dEven) {
  Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
  auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_EQ(s.ndimension(), 0);
  for (auto i = 0; i < 4; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
}
TEST_F(ModulesTest, Conv2dUneven) {
  Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride(2));
  auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_EQ(s.ndimension(), 0);
  for (auto i = 0; i < 4; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
}
TEST_F(ModulesTest, Conv3d) {
  Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
  auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 5);
  ASSERT_EQ(s.ndimension(), 0);
  for (auto i = 0; i < 5; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3 * 3);
}
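// A convolution weight holds out_channels * in_channels * prod(kernel_size)
// elements, and its gradient has one entry per weight: 2 * 3 * 3 = 18 for the
// Conv1d case up to 2 * 3 * 3 * 3 * 3 = 162 for Conv3d (spelled 3 * 2 * ...
// in the assertions above).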
TEST_F(ModulesTest, Linear) {
  Linear model(5, 2);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}
TEST_F(ModulesTest, SimpleContainer) {
  auto model = std::make_shared<SimpleContainer>();
  auto l1 = model->add(Linear(10, 3), "l1");
  auto l2 = model->add(Linear(3, 5), "l2");
  auto l3 = model->add(Linear(5, 100), "l3");

  auto x = torch::randn({1000, 10}, torch::requires_grad());
  x = l1(x).clamp_min(0);
  x = l2(x).clamp_min(0);
  x = l3(x).clamp_min(0);

  x.backward(torch::ones_like(x));
  ASSERT_EQ(x.ndimension(), 2);
  ASSERT_EQ(x.size(0), 1000);
  ASSERT_EQ(x.size(1), 100);
  ASSERT_EQ(x.min().item<float>(), 0);
}
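// clamp_min(0) acts as a ReLU: every negative pre-activation is clipped to
// exactly 0, so with 1000 random rows at least one activation is all but
// certain to be clipped, making the minimum of the final output exactly 0.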
TEST_F(ModulesTest, EmbeddingBasic) {
  const int64_t dict_size = 10;
  Embedding model(dict_size, 2);
  ASSERT_TRUE(model->named_parameters().contains("weight"));
  ASSERT_EQ(model->weight.ndimension(), 2);
  ASSERT_EQ(model->weight.size(0), dict_size);
  ASSERT_EQ(model->weight.size(1), 2);

  auto x = torch::full({10}, dict_size - 1, torch::kInt64);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * dict_size);
}
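// The integer indices receive no gradient; only the embedding table does.
// That is why the gradient has dict_size * 2 elements -- one per weight --
// even though only row dict_size - 1 was ever looked up.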
TEST_F(ModulesTest, EmbeddingList) {
  Embedding model(6, 4);
  auto x = torch::full({2, 3}, 5, torch::kInt64);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.size(0), 2);
  ASSERT_EQ(y.size(1), 3);
  ASSERT_EQ(y.size(2), 4);
}
TEST_F(ModulesTest, Dropout) {
  Dropout dropout(0.5);
  torch::Tensor x = torch::ones(100, torch::requires_grad());
  torch::Tensor y = dropout(x);

  y.backward(torch::ones_like(y));
  ASSERT_EQ(y.ndimension(), 1);
  ASSERT_EQ(y.size(0), 100);
  ASSERT_LT(y.sum().item<float>(), 130);
  ASSERT_GT(y.sum().item<float>(), 70);

  dropout->eval();
  y = dropout(x);
  ASSERT_EQ(y.sum().item<float>(), 100);
}
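// With rate 0.5, surviving activations are rescaled by 1 / (1 - 0.5) = 2, so
// the sum of 100 dropped-out ones has expectation 100; the 70..130 band just
// leaves generous room for sampling noise. In eval() mode dropout is the
// identity, so the sum is exactly 100.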
TEST_F(ModulesTest, Parameters) {
  auto model = std::make_shared<NestedModel>();
  auto parameters = model->named_parameters();
  ASSERT_EQ(parameters["param"].size(0), 3);
  ASSERT_EQ(parameters["param"].size(1), 2);
  ASSERT_EQ(parameters["param"].size(2), 21);
  ASSERT_EQ(parameters["l1.bias"].size(0), 20);
  ASSERT_EQ(parameters["l1.weight"].size(0), 20);
  ASSERT_EQ(parameters["l1.weight"].size(1), 5);
  ASSERT_EQ(parameters["test.l1.bias"].size(0), 3);
  ASSERT_EQ(parameters["test.l1.weight"].size(0), 3);
  ASSERT_EQ(parameters["test.l1.weight"].size(1), 10);
  ASSERT_EQ(parameters["test.l2.bias"].size(0), 5);
  ASSERT_EQ(parameters["test.l2.weight"].size(0), 5);
  ASSERT_EQ(parameters["test.l2.weight"].size(1), 3);
  ASSERT_EQ(parameters["test.l3.bias"].size(0), 100);
  ASSERT_EQ(parameters["test.l3.weight"].size(0), 100);
  ASSERT_EQ(parameters["test.l3.weight"].size(1), 5);
}
TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
  bool was_called = false;
  auto functional = Functional([&was_called](torch::Tensor input) {
    was_called = true;
    return input;
  });
  auto output = functional(torch::ones(5, torch::requires_grad()));
  ASSERT_TRUE(was_called);
  ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));

  was_called = false;
  output = functional(torch::ones(5, torch::requires_grad()));
  ASSERT_TRUE(was_called);
  ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
}
TEST_F(ModulesTest, FunctionalWithTorchFunction) {
  auto functional = Functional(torch::relu);
  ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
  ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
  ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0);
}
TEST_F(ModulesTest, FunctionalArgumentBinding) {
  auto functional =
      Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
  ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
}
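// Functional binds the trailing arguments, so the module invokes
// torch::elu(input, 1, 0, 1); with scale = 0 the result is 0 regardless of
// the input, which is exactly what the assertion checks.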
TEST_F(ModulesTest, BatchNormStateful) {
  BatchNorm bn(5);

  // Is stateful by default.
  ASSERT_TRUE(bn->options.stateful());

  ASSERT_TRUE(bn->running_mean.defined());
  ASSERT_EQ(bn->running_mean.dim(), 1);
  ASSERT_EQ(bn->running_mean.size(0), 5);

  ASSERT_TRUE(bn->running_var.defined());
  ASSERT_EQ(bn->running_var.dim(), 1);
  ASSERT_EQ(bn->running_var.size(0), 5);

  // Is affine by default.
  ASSERT_TRUE(bn->options.affine());

  ASSERT_TRUE(bn->weight.defined());
  ASSERT_EQ(bn->weight.dim(), 1);
  ASSERT_EQ(bn->weight.size(0), 5);

  ASSERT_TRUE(bn->bias.defined());
  ASSERT_EQ(bn->bias.dim(), 1);
  ASSERT_EQ(bn->bias.size(0), 5);
}
TEST_F(ModulesTest, BatchNormStateless) {
  BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));

  ASSERT_FALSE(bn->running_mean.defined());
  ASSERT_FALSE(bn->running_var.defined());
  ASSERT_FALSE(bn->weight.defined());
  ASSERT_FALSE(bn->bias.defined());

  ASSERT_THROWS_WITH(
      bn(torch::ones({2, 5})),
      "Calling BatchNorm::forward is only permitted "
      "when the 'stateful' option is true (was false). "
      "Use BatchNorm::pure_forward instead.");
}
TEST_F(ModulesTest, BatchNormPureForward) {
  BatchNorm bn(BatchNormOptions(5).affine(false));
  bn->eval();

  auto input = torch::randn({2, 5});
  auto mean = torch::randn(5);
  auto variance = torch::rand(5);
  auto output = bn->pure_forward(input, mean, variance);
  auto expected = (input - mean) / torch::sqrt(variance + bn->options.eps());
  ASSERT_TRUE(output.allclose(expected));
}
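// With affine disabled, batch norm reduces to plain normalization with the
// supplied statistics: (input - mean) / sqrt(variance + eps), which is
// exactly the expected tensor computed above.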
TEST_F(ModulesTest, Linear_CUDA) {
  Linear model(5, 2);
  model->to(torch::kCUDA);
  auto x =
      torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}
TEST_F(ModulesTest, Linear2_CUDA) {
  Linear model(5, 2);
  model->to(torch::kCUDA);
  model->to(torch::kCPU);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}
TEST_F(ModulesTest, PrettyPrintLinear) {
  ASSERT_EQ(
      c10::str(Linear(3, 4)),
      "torch::nn::Linear(in=3, out=4, with_bias=true)");
}
TEST_F(ModulesTest, PrettyPrintConv) {
  ASSERT_EQ(
      c10::str(Conv1d(3, 4, 5)),
      "torch::nn::Conv1d(input_channels=3, output_channels=4, kernel_size=5, stride=1)");
  ASSERT_EQ(
      c10::str(Conv2d(3, 4, 5)),
      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[1, 1])");
  ASSERT_EQ(
      c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[2, 2])");

  const auto options = Conv2dOptions(3, 4, {5, 6}).stride({1, 2});
  ASSERT_EQ(
      c10::str(Conv2d(options)),
      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])");
}
TEST_F(ModulesTest, PrettyPrintDropout) {
  ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
  ASSERT_EQ(
      c10::str(FeatureDropout(0.5)), "torch::nn::FeatureDropout(rate=0.5)");
}
TEST_F(ModulesTest, PrettyPrintFunctional) {
  ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
}
366 "torch::nn::BatchNorm(features=4, eps=0.5, momentum=0.1, affine=false, stateful=true)");
TEST_F(ModulesTest, PrettyPrintEmbedding) {
  ASSERT_EQ(
      c10::str(Embedding(10, 2)),
      "torch::nn::Embedding(count=10, dimension=2)");
}
TEST_F(ModulesTest, PrettyPrintNestedModel) {
  struct InnerTestModule : torch::nn::Module {
    InnerTestModule()
        : torch::nn::Module("InnerTestModule"),
          fc(register_module("fc", torch::nn::Linear(3, 4))),
          table(register_module("table", torch::nn::Embedding(10, 2))) {}

    torch::nn::Linear fc;
    torch::nn::Embedding table;
  };

  struct TestModule : torch::nn::Module {
    TestModule()
        : torch::nn::Module("TestModule"),
          fc(register_module("fc", torch::nn::Linear(4, 5))),
          table(register_module("table", torch::nn::Embedding(10, 2))),
          inner(register_module("inner", std::make_shared<InnerTestModule>())) {
    }

    torch::nn::Linear fc;
    torch::nn::Embedding table;
    std::shared_ptr<InnerTestModule> inner;
  };

  ASSERT_EQ(
      c10::str(TestModule{}),
      "TestModule(\n"
      "  (fc): torch::nn::Linear(in=4, out=5, with_bias=true)\n"
      "  (table): torch::nn::Embedding(count=10, dimension=2)\n"
      "  (inner): InnerTestModule(\n"
      "    (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n"
      "    (table): torch::nn::Embedding(count=10, dimension=2)\n"
      "  )\n"
      ")");
}