2 #include <gtest/gtest.h> 9 void TestEmptyTensor(
Type&
T) {
10 auto empty = randn({0}, T);
11 ASSERT_ANY_THROW(empty.expand({3}));
15 void TestOut2Basic(
Type& T) {
16 auto a = randn({3, 1}, T);
17 auto b = randn({5}, T);
18 std::vector<int64_t> expanded_sizes = {3, 5};
20 (a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
24 void TestOut2WithScalar(
Type& T) {
25 auto aScalar = ones({1}, T);
27 auto b = randn({3, 5}, T);
29 (aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
33 void TestOut2OldFallback(
Type& T) {
34 auto a = randn({3, 5}, T);
35 auto b = randn({5, 3}, T);
36 ASSERT_ANY_THROW(a + b);
40 void TestOut2MismatchedSizes(
Type& T) {
41 auto a = randn({3, 5}, T);
42 auto b = randn({7, 5}, T);
43 ASSERT_ANY_THROW(a + b);
47 void TestOut3Basic(
Type& T) {
48 auto a = randn({3, 1, 1}, T);
49 auto b = randn({1, 2, 1}, T);
50 auto c = randn({1, 1, 5}, T);
51 std::vector<int64_t> expanded_sizes = {3, 2, 5};
52 ASSERT_TRUE((a + b + c).equal(
53 a.expand(expanded_sizes) + b.expand(expanded_sizes) +
54 c.expand(expanded_sizes)));
58 void TestOut3WithScalar(
Type& T) {
59 auto aTensorScalar = ones({1}, T);
61 auto b = randn({3, 2, 1}, T);
62 auto c = randn({1, 2, 5}, T);
63 std::vector<int64_t> expanded_sizes = {3, 2, 5};
64 ASSERT_TRUE(aTensorScalar.addcmul(b, c).equal(
65 aTensorScalar.expand(expanded_sizes)
66 .addcmul(b.expand(expanded_sizes), c.expand(expanded_sizes))));
70 void TestOut3OldFallback(
Type& T) {
71 auto a = randn({3, 2, 5}, T);
72 auto b = randn({2, 3, 5}, T);
73 auto c = randn({5, 3, 2}, T);
74 ASSERT_ANY_THROW(a.addcmul(b, c));
78 void TestOut3MismatchedSizes(
Type& T) {
79 auto a = randn({3, 2, 5}, T);
80 auto b = randn({2, 3, 5}, T);
81 auto c = randn({5, 5, 5}, T);
82 ASSERT_ANY_THROW(a.addcmul(b, c));
86 void TestIn2Basic(
Type& T) {
87 auto a = randn({3, 5}, T);
88 auto b = randn({3, 1}, T);
89 ASSERT_TRUE((a + b).equal(a + b.expand({3, 5})));
93 void TestIn2WithScalar(
Type& T) {
94 auto a = randn({3, 5}, T);
95 auto bScalar = ones({1}, T);
97 ASSERT_TRUE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
101 void TestIn2ExpandError(
Type& T) {
102 auto a = randn({1, 5}, T);
103 auto b = randn({3, 1}, T);
104 ASSERT_ANY_THROW(a.add_(b));
108 void TestIn3Basic(
Type& T) {
109 auto a = randn({3, 5, 2}, T);
110 auto b = randn({3, 1, 2}, T);
111 auto c = randn({1, 5, 1}, T);
112 auto aClone = a.clone();
113 ASSERT_TRUE(a.addcmul_(b, c).equal(
114 aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
118 void TestIn3WithScalar(
Type& T) {
119 auto a = randn({3, 5, 2}, T);
120 auto b = randn({3, 1, 2}, T);
121 auto c = randn({1, 5, 1}, T);
122 auto aClone = a.clone();
123 auto bScalar = ones({1}, T);
125 ASSERT_TRUE(a.addcmul_(bScalar, c)
126 .equal(aClone.addcmul_(
127 bScalar.expand(a.sizes()), c.expand(a.sizes()))));
131 void TestIn3ExpandError(
Type& T) {
132 auto a = randn({1, 3, 5}, T);
133 auto b = randn({4, 1, 1}, T);
134 auto c = randn({1, 3, 1}, T);
135 ASSERT_ANY_THROW(a.addcmul_(b, c));
139 void TestExplicitDimBasic(
Type& T) {
140 auto a = randn({1}, T);
141 auto b = randn({5, 3}, T);
142 auto c = randn({3, 7}, T);
143 ASSERT_TRUE(a.addmm(b, c).equal(a.expand({5, 7}).addmm(b, c)));
147 void TestExplicitDimWithScalar(
Type& T) {
148 auto a = randn({1}, T);
149 auto b = randn({5, 3}, T);
150 auto c = randn({3, 7}, T);
151 Tensor aScalar = ones({1}, T);
153 ASSERT_TRUE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
157 void TestExplicitDimWithMismatchedSizes(
Type& T) {
158 auto b = randn({5, 3}, T);
159 auto c = randn({3, 7}, T);
160 auto a = randn({3, 3}, T);
161 ASSERT_ANY_THROW(a.addmm(b, c));
164 TEST(BroadcastTest, Broadcast) {
166 Type& T = CPU(kFloat);
171 TestOut2WithScalar(T);
172 TestOut2OldFallback(T);
173 TestOut2MismatchedSizes(T);
176 TestOut3WithScalar(T);
177 TestOut3OldFallback(T);
178 TestOut3MismatchedSizes(T);
181 TestIn2WithScalar(T);
182 TestIn2ExpandError(T);
185 TestIn3WithScalar(T);
186 TestIn3ExpandError(T);
188 TestExplicitDimBasic(T);
189 TestExplicitDimWithScalar(T);
190 TestExplicitDimWithMismatchedSizes(T);
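// Reference notes (documentation tooltips captured with this file):
//   TensorImpl* TensorImpl::maybe_zero_dim(bool condition_when_zero_dim)
//     If condition_when_zero_dim is true and the tensor is a 1-dim, 1-size
//     tensor, reshape the tensor into a 0-dim tensor. [remainder truncated]
//   (unrelated) Flush-To-Zero and Denormals-Are-Zero mode.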