Caffe2 - C++ API
A deep-learning, cross-platform ML framework
weakref_test.cpp
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 
5 #include <iostream>
6 #include <chrono>
7 #include <sstream>
8 
9 using at::Tensor;
10 using at::WeakTensor;
11 
12 // Weak pointer tests
13 // gets invalidated
14 TEST(TestWeakPointer, WeakPointerGetsInvalidated) {
15  Tensor a = at::ones({2, 2});
16  WeakTensor b = a;
17  a.reset();
18  ASSERT_FALSE(b.lock().defined());
19 }
20 
21 // can successfully lock
22 TEST(TestWeakPointer, WeakPointerLock) {
23  Tensor a = at::ones({2, 2});
24  WeakTensor b = a;
25  auto c = b.lock();
26  ASSERT_TRUE(c.defined());
27 
28  a.reset();
29  ASSERT_TRUE(b.lock().defined());
30  c.reset();
31  ASSERT_FALSE(b.lock().defined());
32 }
33 
34 // updates refcounts correctly
35 TEST(TestWeakPointer, WeakUpdatesRefcountsTest) {
36  Tensor a = at::ones({2, 2});
37  ASSERT_EQ(a.use_count(), 1);
38  ASSERT_EQ(a.weak_use_count(), 1);
39  {
40  WeakTensor b = a;
41  ASSERT_EQ(a.use_count(), 1);
42  ASSERT_EQ(a.weak_use_count(), 2);
43  }
44  ASSERT_EQ(a.use_count(), 1);
45  ASSERT_EQ(a.weak_use_count(), 1);
46  {
47  WeakTensor b = a;
48  ASSERT_EQ(a.use_count(), 1);
49  auto locked = b.lock();
50  ASSERT_TRUE(locked.defined());
51  ASSERT_EQ(a.use_count(), 2);
52  }
53  ASSERT_EQ(a.use_count(), 1);
54  ASSERT_EQ(a.weak_use_count(), 1);
55  {
56  WeakTensor b = a;
57  ASSERT_EQ(a.use_count(), 1);
58  ASSERT_EQ(a.weak_use_count(), 2);
59  a.reset();
60  ASSERT_EQ(b.use_count(), 0);
61  ASSERT_EQ(b.weak_use_count(), 1);
62  }
63 }