Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_ivalue.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include "ATen/core/ivalue.h"
5 #include "test/cpp/jit/test_base.h"
6 #include "test/cpp/jit/test_utils.h"
7 
8 namespace torch {
9 namespace jit {
10 namespace {
11 
12 using Var = SymbolicVariable;
13 
14 using namespace torch::autograd;
15 
16 void testIValue() {
17  Shared<IntList> foo = IntList::create({3, 4, 5});
18  ASSERT_EQ(foo.use_count(), 1);
19  IValue bar{foo};
20  ASSERT_EQ(foo.use_count(), 2);
21  auto baz = bar;
22  ASSERT_EQ(foo.use_count(), 3);
23  auto foo2 = std::move(bar);
24  ASSERT_EQ(foo.use_count(), 3);
25  ASSERT_TRUE(foo2.isIntList());
26  ASSERT_TRUE(bar.isNone());
27  foo2 = IValue(4.0);
28  ASSERT_TRUE(foo2.isDouble());
29  ASSERT_EQ(foo2.toDouble(), 4.0);
30  ASSERT_EQ(foo.use_count(), 2);
31  ASSERT_TRUE(ArrayRef<int64_t>(baz.toIntList()->elements()).equals({3, 4, 5}));
32 
33  auto move_it = std::move(baz).toIntList();
34  ASSERT_EQ(foo.use_count(), 2);
35  ASSERT_TRUE(baz.isNone());
36  IValue i(4);
37  ASSERT_TRUE(i.isInt());
38  ASSERT_EQ(i.toInt(), 4);
39  IValue dlist(DoubleList::create({3.5}));
40  ASSERT_TRUE(dlist.isDoubleList());
41  ASSERT_TRUE(ArrayRef<double>(std::move(dlist).toDoubleList()->elements())
42  .equals({3.5}));
43  ASSERT_TRUE(dlist.isNone());
44  dlist = IValue(DoubleList::create({3.4}));
45  ASSERT_TRUE(ArrayRef<double>(dlist.toDoubleList()->elements()).equals({3.4}));
46  IValue the_list(Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
47  ASSERT_EQ(foo.use_count(), 3);
48  ASSERT_TRUE(the_list.isTuple());
49  auto first = std::move(the_list).toTuple()->elements().at(1);
50  ASSERT_EQ(first.toInt(), 4);
51  at::Tensor tv = at::rand({3, 4});
52  IValue ten(tv);
53  ASSERT_EQ(tv.use_count(), 2);
54  auto ten2 = ten;
55  ASSERT_EQ(tv.use_count(), 3);
56  ASSERT_TRUE(ten2.toTensor().equal(ten.toTensor()));
57  std::move(ten2).toTensor();
58  ASSERT_EQ(tv.use_count(), 2);
59 }
60 
61 } // namespace
62 } // namespace jit
63 } // namespace torch
Definition: jit_type.h:17