#include <gtest/gtest.h>

#include <c10/util/tempfile.h>

#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/sgd.h>
#include <torch/serialize.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

#include <sstream>

using namespace torch::nn;

namespace {
// Small feed-forward network for the XOR tests. Only the sigmoid activations
// are visible in the excerpt; the Linear layer sizes (2 -> 8 -> 1) are an
// assumption.
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}

// Round-trips a tensor through an in-memory stream.
torch::Tensor save_and_load(torch::Tensor input) {
  std::stringstream stream;
  torch::save(input, stream);
  torch::Tensor tensor;
  torch::load(tensor, stream);
  return tensor;
}
} // namespace
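// torch::save/torch::load also accept a file path (used by BasicToFile and the
// model/optimizer tests below). A minimal sketch, assuming only the overloads
// already used in this file:
//
//   auto tmp = c10::make_tempfile();
//   torch::save(torch::ones({2, 2}), tmp.name); // write archive to disk
//   torch::Tensor restored;
//   torch::load(restored, tmp.name);            // read it back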
TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  auto tempfile = c10::make_tempfile();
  torch::save(x, tempfile.name);

  torch::Tensor y;
  torch::load(y, tempfile.name);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  // The resize itself is omitted from the excerpt; shrinking to {5, 5} is an
  // assumption, the point being that a resized tensor still round-trips.
  x.resize_({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  // The exact slice is omitted from the excerpt; these bounds are an
  // assumption, any view of x exercises the same path.
  x = x.slice(0, 1, 5);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  // Slicing along dim 1 makes x non-contiguous; the exact bounds are an
  // assumption, since the excerpt omits this line.
  x = x.slice(1, 1, 4);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, XOR) {
  // We should be able to save and load a trained XOR model without losing
  // its accuracy.
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  // The optimizer construction is omitted from the excerpt; plain SGD with
  // momentum is an assumption.
  auto optimizer = torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}
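// The Optim test below checks that optimizer state (e.g. SGD momentum buffers)
// is serialized along with the optimizer: model1 takes two steps with one
// optimizer, model2 takes two steps with two fresh optimizers (state lost
// between steps), and model3 takes one step, has its optimizer saved and
// reloaded into a second instance, then takes its second step. Afterwards
// model3's parameters should match model1's and differ from model2's.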
TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 start from identical parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }

  // The optimizer declarations are omitted from the excerpt; SGD with momentum
  // (so the optimizers carry state worth serializing) is an assumption.
  auto options = torch::optim::SGDOptions(1e-1).momentum(0.9);
  auto optim1 = torch::optim::SGD(model1->parameters(), options);
  auto optim2 = torch::optim::SGD(model2->parameters(), options);
  auto optim2_2 = torch::optim::SGD(model2->parameters(), options);
  auto optim3 = torch::optim::SGD(model3->parameters(), options);
  auto optim3_2 = torch::optim::SGD(model3->parameters(), options);

  auto x = torch::ones({10, 5});

  // One optimizer step on the given model, reconstructed around the visible
  // forward/sum line.
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Two steps of model 1 with a single optimizer.
  step(optim1, model1);
  step(optim1, model1);
  // Two steps of model 2 with two fresh optimizers (momentum state is lost).
  step(optim2, model2);
  step(optim2_2, model2);
  // Two steps of model 3, saving and restoring the optimizer in between.
  step(optim3, model3);
  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Models 1 and 3 should end up identical; model 2 should differ.
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}
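// XOR_CUDA repeats the XOR round-trip across devices: train and save on CPU,
// reload and check the loss, move the reloaded model to CUDA and evaluate on
// CUDA inputs, then save the CUDA-resident model and load it into a third
// model to confirm a second round-trip still produces a working network.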
TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  auto getLoss = [](Sequential model,
                    uint32_t batch_size,
                    bool is_cuda = false) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    if (is_cuda) {
      inputs = inputs.cuda();
      labels = labels.cuda();
    }
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  // As in the XOR test above, the optimizer construction is omitted from the
  // excerpt; SGD with momentum is an assumption.
  auto optimizer = torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);

  model2->to(torch::kCUDA);
  loss = getLoss(model2, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);

  auto tempfile2 = c10::make_tempfile();
  torch::save(model2, tempfile2.name);
  torch::load(model3, tempfile2.name);

  loss = getLoss(model3, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);
}
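// The final test covers modules that only contain other modules: the buffer
// lives three levels down ("a.c.foo"), and the intermediate modules own no
// parameters or buffers of their own, yet the hierarchy must still serialize
// and reload correctly.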
TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  // The struct scaffolding around the visible register_* calls is
  // reconstructed from the "a.c.foo" buffer path checked at the end.
  struct C : torch::nn::Module {
    C() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  struct B : torch::nn::Module {};
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("c", std::make_shared<C>());
    }
  };
  struct M : torch::nn::Module {
    M() {
      register_module("a", std::make_shared<A>());
    }
  };

  auto out = std::make_shared<M>();
  std::stringstream ss;
  torch::save(out, ss);
  auto in = std::make_shared<M>();
  torch::load(in, ss);

  const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
  ASSERT_EQ(output, 5);
}