3 #include "test/cpp/jit/test_utils.h" 4 #include "torch/csrc/jit/argument_spec.h" 10 int device(
const autograd::Variable& v) {
11 return v.type().is_cuda() ? v.get_device() : -1;
16 std::equal(lhs.begin(), lhs.end(), rhs.begin());
19 bool isEqual(
const CompleteArgumentInfo& ti,
const autograd::Variable& v) {
21 return ti.defined() == v.defined();
22 return ti.device() == device(v) && ti.requires_grad() == v.requires_grad() &&
23 ti.type() == v.scalar_type() && isEqual(ti.sizes(), v.sizes()) &&
24 isEqual(ti.strides(), v.strides());
28 return autograd::make_variable(at::rand(sizes, t.options()), requires_grad);
30 autograd::Variable undef() {
31 return autograd::Variable();
34 void testArgumentSpec() {
35 auto& CF = at::CPU(at::kFloat);
36 auto& CD = at::CPU(at::kDouble);
37 auto& GF = at::CUDA(at::kFloat);
38 auto& GD = at::CUDA(at::kDouble);
40 auto list = createStack({var(CF, {1},
true),
41 var(CD, {1, 2},
false),
43 var(GD, {4, 5, 6},
false),
47 list[1].toTensor().transpose_(0, 1);
50 auto list2 = createStack({var(CF, {1},
true),
51 var(CD, {1, 2},
false),
53 var(GD, {4, 5, 6},
false),
55 list2[1].toTensor().transpose_(0, 1);
57 CompleteArgumentSpec a(
true, list);
58 CompleteArgumentSpec b(
true, list);
59 ASSERT_EQ(a.hashCode(), b.hashCode());
62 CompleteArgumentSpec d(
true, list2);
64 ASSERT_EQ(d.hashCode(), a.hashCode());
66 for (
size_t i = 0; i < list.size(); ++i) {
67 ASSERT_TRUE(isEqual(a.at(i), list[i].toTensor()));
69 CompleteArgumentSpec no_grad(
false, list);
70 ASSERT_TRUE(no_grad != a);
72 std::unordered_set<CompleteArgumentSpec> spec;
73 spec.insert(std::move(a));
74 ASSERT_TRUE(spec.count(b) > 0);
75 ASSERT_EQ(spec.count(no_grad), 0);
76 spec.insert(std::move(no_grad));
77 ASSERT_EQ(spec.count(CompleteArgumentSpec(
true, list)), 1);
79 list2[1].toTensor().transpose_(0, 1);
80 CompleteArgumentSpec c(
true, list2);
82 ASSERT_EQ(spec.count(c), 0);
84 Stack stack = {var(CF, {1, 2},
true), 3, var(CF, {1, 2},
true)};
85 CompleteArgumentSpec with_const(
true, stack);
86 ASSERT_EQ(with_const.at(2).sizes().size(), 2);
constexpr size_t size() const
size - Get the array size.