Caffe2 - C++ API
A deep learning, cross-platform ML framework
broadcast_test.cpp
1 
2 #include <gtest/gtest.h>
3 
4 #include <ATen/ATen.h>
5 
6 using namespace at;
7 
8 // can't expand empty tensor
9 void TestEmptyTensor(Type& T) {
10  auto empty = randn({0}, T);
11  ASSERT_ANY_THROW(empty.expand({3}));
12 }
13 
14 // out-place function with 2 args
15 void TestOut2Basic(Type& T) {
16  auto a = randn({3, 1}, T);
17  auto b = randn({5}, T);
18  std::vector<int64_t> expanded_sizes = {3, 5};
19  ASSERT_TRUE(
20  (a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
21 }
22 
23 // with scalar
24 void TestOut2WithScalar(Type& T) {
25  auto aScalar = ones({1}, T);
26  aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
27  auto b = randn({3, 5}, T);
28  ASSERT_TRUE(
29  (aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
30 }
31 
32 // old fallback behavior yields error
33 void TestOut2OldFallback(Type& T) {
34  auto a = randn({3, 5}, T);
35  auto b = randn({5, 3}, T);
36  ASSERT_ANY_THROW(a + b);
37 }
38 
39 // with mismatched sizes
40 void TestOut2MismatchedSizes(Type& T) {
41  auto a = randn({3, 5}, T);
42  auto b = randn({7, 5}, T);
43  ASSERT_ANY_THROW(a + b);
44 }
45 
46 // out-place function with 3 args
47 void TestOut3Basic(Type& T) {
48  auto a = randn({3, 1, 1}, T);
49  auto b = randn({1, 2, 1}, T);
50  auto c = randn({1, 1, 5}, T);
51  std::vector<int64_t> expanded_sizes = {3, 2, 5};
52  ASSERT_TRUE((a + b + c).equal(
53  a.expand(expanded_sizes) + b.expand(expanded_sizes) +
54  c.expand(expanded_sizes)));
55 }
56 
57 // with scalar
58 void TestOut3WithScalar(Type& T) {
59  auto aTensorScalar = ones({1}, T);
60  aTensorScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
61  auto b = randn({3, 2, 1}, T);
62  auto c = randn({1, 2, 5}, T);
63  std::vector<int64_t> expanded_sizes = {3, 2, 5};
64  ASSERT_TRUE(aTensorScalar.addcmul(b, c).equal(
65  aTensorScalar.expand(expanded_sizes)
66  .addcmul(b.expand(expanded_sizes), c.expand(expanded_sizes))));
67 }
68 
69 // old fallback behavior yields error
70 void TestOut3OldFallback(Type& T) {
71  auto a = randn({3, 2, 5}, T);
72  auto b = randn({2, 3, 5}, T);
73  auto c = randn({5, 3, 2}, T);
74  ASSERT_ANY_THROW(a.addcmul(b, c));
75 }
76 
77 // with mismatched sizes
78 void TestOut3MismatchedSizes(Type& T) {
79  auto a = randn({3, 2, 5}, T);
80  auto b = randn({2, 3, 5}, T);
81  auto c = randn({5, 5, 5}, T);
82  ASSERT_ANY_THROW(a.addcmul(b, c));
83 }
84 
85 // in-place function with 2 args
86 void TestIn2Basic(Type& T) {
87  auto a = randn({3, 5}, T);
88  auto b = randn({3, 1}, T);
89  ASSERT_TRUE((a + b).equal(a + b.expand({3, 5})));
90 }
91 
92 // with scalar
93 void TestIn2WithScalar(Type& T) {
94  auto a = randn({3, 5}, T);
95  auto bScalar = ones({1}, T);
96  bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
97  ASSERT_TRUE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
98 }
99 
100 // error: would have to expand inplace arg
101 void TestIn2ExpandError(Type& T) {
102  auto a = randn({1, 5}, T);
103  auto b = randn({3, 1}, T);
104  ASSERT_ANY_THROW(a.add_(b));
105 }
106 
107 // in-place function with 3 args
108 void TestIn3Basic(Type& T) {
109  auto a = randn({3, 5, 2}, T);
110  auto b = randn({3, 1, 2}, T);
111  auto c = randn({1, 5, 1}, T);
112  auto aClone = a.clone();
113  ASSERT_TRUE(a.addcmul_(b, c).equal(
114  aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
115 }
116 
117 // with scalar
118 void TestIn3WithScalar(Type& T) {
119  auto a = randn({3, 5, 2}, T);
120  auto b = randn({3, 1, 2}, T);
121  auto c = randn({1, 5, 1}, T);
122  auto aClone = a.clone();
123  auto bScalar = ones({1}, T);
124  bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
125  ASSERT_TRUE(a.addcmul_(bScalar, c)
126  .equal(aClone.addcmul_(
127  bScalar.expand(a.sizes()), c.expand(a.sizes()))));
128 }
129 
130 // error: would have to expand inplace arg
131 void TestIn3ExpandError(Type& T) {
132  auto a = randn({1, 3, 5}, T);
133  auto b = randn({4, 1, 1}, T);
134  auto c = randn({1, 3, 1}, T);
135  ASSERT_ANY_THROW(a.addcmul_(b, c));
136 }
137 
138 // explicit dim specification
139 void TestExplicitDimBasic(Type& T) {
140  auto a = randn({1}, T);
141  auto b = randn({5, 3}, T);
142  auto c = randn({3, 7}, T);
143  ASSERT_TRUE(a.addmm(b, c).equal(a.expand({5, 7}).addmm(b, c)));
144 }
145 
146 // with scalar
147 void TestExplicitDimWithScalar(Type& T) {
148  auto a = randn({1}, T);
149  auto b = randn({5, 3}, T);
150  auto c = randn({3, 7}, T);
151  Tensor aScalar = ones({1}, T);
152  aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
153  ASSERT_TRUE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
154 }
155 
156 // with mismatched sizes
157 void TestExplicitDimWithMismatchedSizes(Type& T) {
158  auto b = randn({5, 3}, T);
159  auto c = randn({3, 7}, T);
160  auto a = randn({3, 3}, T);
161  ASSERT_ANY_THROW(a.addmm(b, c));
162 }
163 
164 TEST(BroadcastTest, Broadcast) {
165  manual_seed(123);
166  Type& T = CPU(kFloat);
167 
168  TestEmptyTensor(T);
169 
170  TestOut2Basic(T);
171  TestOut2WithScalar(T);
172  TestOut2OldFallback(T);
173  TestOut2MismatchedSizes(T);
174 
175  TestOut3Basic(T);
176  TestOut3WithScalar(T);
177  TestOut3OldFallback(T);
178  TestOut3MismatchedSizes(T);
179 
180  TestIn2Basic(T);
181  TestIn2WithScalar(T);
182  TestIn2ExpandError(T);
183 
184  TestIn3Basic(T);
185  TestIn3WithScalar(T);
186  TestIn3ExpandError(T);
187 
188  TestExplicitDimBasic(T);
189  TestExplicitDimWithScalar(T);
190  TestExplicitDimWithMismatchedSizes(T);
191 }
Definition: Type.h:107
virtual TensorImpl * maybe_zero_dim(bool condition_when_zero_dim)
If condition_when_zero_dim is true, and the tensor is a 1-dim, 1-size tensor, reshape the tensor into a 0-dim tensor (scalar).
Definition: TensorImpl.cpp:105
Flush-To-Zero and Denormals-Are-Zero mode.