Caffe2 - C++ API
A deep learning, cross platform ML framework
test_argument_spec.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_utils.h"
4 #include "torch/csrc/jit/argument_spec.h"
5 
6 namespace torch {
7 namespace jit {
8 namespace test {
9 
10 int device(const autograd::Variable& v) {
11  return v.type().is_cuda() ? v.get_device() : -1;
12 }
13 
14 bool isEqual(at::IntArrayRef lhs, at::IntArrayRef rhs) {
15  return lhs.size() == rhs.size() &&
16  std::equal(lhs.begin(), lhs.end(), rhs.begin());
17 }
18 
19 bool isEqual(const CompleteArgumentInfo& ti, const autograd::Variable& v) {
20  if (!ti.defined())
21  return ti.defined() == v.defined();
22  return ti.device() == device(v) && ti.requires_grad() == v.requires_grad() &&
23  ti.type() == v.scalar_type() && isEqual(ti.sizes(), v.sizes()) &&
24  isEqual(ti.strides(), v.strides());
25 }
26 
27 autograd::Variable var(at::Type& t, at::IntArrayRef sizes, bool requires_grad) {
28  return autograd::make_variable(at::rand(sizes, t.options()), requires_grad);
29 }
30 autograd::Variable undef() {
31  return autograd::Variable();
32 }
33 
34 void testArgumentSpec() {
35  auto& CF = at::CPU(at::kFloat);
36  auto& CD = at::CPU(at::kDouble);
37  auto& GF = at::CUDA(at::kFloat);
38  auto& GD = at::CUDA(at::kDouble);
39 
40  auto list = createStack({var(CF, {1}, true),
41  var(CD, {1, 2}, false),
42  var(GF, {}, true),
43  var(GD, {4, 5, 6}, false),
44  undef()});
45 
46  // make sure we have some non-standard strides
47  list[1].toTensor().transpose_(0, 1);
48 
49  // same list but different backing values
50  auto list2 = createStack({var(CF, {1}, true),
51  var(CD, {1, 2}, false),
52  var(GF, {}, true),
53  var(GD, {4, 5, 6}, false),
54  undef()});
55  list2[1].toTensor().transpose_(0, 1);
56 
57  CompleteArgumentSpec a(true, list);
58  CompleteArgumentSpec b(true, list);
59  ASSERT_EQ(a.hashCode(), b.hashCode());
60 
61  ASSERT_EQ(a, b);
62  CompleteArgumentSpec d(true, list2);
63  ASSERT_EQ(d, a);
64  ASSERT_EQ(d.hashCode(), a.hashCode());
65 
66  for (size_t i = 0; i < list.size(); ++i) {
67  ASSERT_TRUE(isEqual(a.at(i), list[i].toTensor()));
68  }
69  CompleteArgumentSpec no_grad(/*with_grad=*/false, list);
70  ASSERT_TRUE(no_grad != a);
71 
72  std::unordered_set<CompleteArgumentSpec> spec;
73  spec.insert(std::move(a));
74  ASSERT_TRUE(spec.count(b) > 0);
75  ASSERT_EQ(spec.count(no_grad), 0);
76  spec.insert(std::move(no_grad));
77  ASSERT_EQ(spec.count(CompleteArgumentSpec(true, list)), 1);
78 
79  list2[1].toTensor().transpose_(0, 1);
80  CompleteArgumentSpec c(true, list2); // same as list, except for one stride
81  ASSERT_FALSE(c == a);
82  ASSERT_EQ(spec.count(c), 0);
83 
84  Stack stack = {var(CF, {1, 2}, true), 3, var(CF, {1, 2}, true)};
85  CompleteArgumentSpec with_const(true, stack);
86  ASSERT_EQ(with_const.at(2).sizes().size(), 2);
87 }
88 
89 } // namespace test
90 } // namespace jit
91 } // namespace torch
Definition: module.cpp:17
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17