// NOTE(review): this listing is an excerpt; the numbers embedded in the text
// skip, so statements between the visible ones are not shown.
// The next line carries three things: the gtest include, the visible part of
// the TRY_CATCH_ELSE(fn, catc, els) macro (a `_passed` flag that starts false
// and an ASSERT_FALSE(_passed) inside a std::exception handler), and the
// opening of require_equal_size_dim.
1 #include <gtest/gtest.h> 10 #define TRY_CATCH_ELSE(fn, catc, els) \ 13 bool _passed = false; \ 18 } catch (std::exception & e) { \ 19 ASSERT_FALSE(_passed); \ 24 void require_equal_size_dim(
const Tensor &lhs,
const Tensor &rhs) {
// Both tensors must report the same number of dimensions...
25 ASSERT_EQ(lhs.dim(), rhs.dim());
// ...and their size lists must compare equal.
26 ASSERT_TRUE(lhs.sizes().
equals(rhs.sizes()));
// Fragment of a size-compatibility check over `from_size` and `to_size`
// (presumably the should_expand helper called by the expand tests; TODO
// confirm, since the function header and the branch bodies are not visible).
// Case 1: `from_size` has more dimensions than `to_size`.
30 if (from_size.
size() > to_size.
size()) {
// Walk `from_size` from its last dimension towards its first.
33 for (
auto from_dim_it = from_size.rbegin(); from_dim_it != from_size.rend();
// Inside that loop, walk `to_size` in the same reverse order.
35 for (
auto to_dim_it = to_size.rbegin(); to_dim_it != to_size.rend();
// Case 2: the `from` dimension is neither 1 nor equal to the `to` dimension.
37 if (*from_dim_it != 1 && *from_dim_it != *to_dim_it) {
// Shapes under test: an empty shape, {0}, {1}, {1, 1} and {2}.
46 std::vector<std::vector<int64_t>> sizes = {{}, {0}, {1}, {1, 1}, {2}};
// Per-shape checks; `s` iterates over `sizes`.
// NOTE(review): the construction of `t` is not visible in this excerpt.
49 for (
auto s = sizes.begin(); s != sizes.end(); ++s) {
// The rank reported by dim() and ndimension() must equal the shape's length.
52 ASSERT_EQ((
size_t)t.dim(), s->size());
53 ASSERT_EQ((
size_t)t.ndimension(), s->size());
// sizes() must equal the shape, and there must be one stride per dimension.
54 ASSERT_TRUE(t.sizes().
equals(*s));
55 ASSERT_EQ(t.strides().
size(), s->size());
// Product of the shape's entries; the left-hand side of this statement is not
// visible, but the following assertion compares `numel` with t.numel().
// NOTE(review): the initial value `1` is an int, so std::accumulate folds in
// int rather than int64_t. Confirm that is intended.
57 std::accumulate(s->begin(), s->end(), 1, std::multiplies<int64_t>());
58 ASSERT_EQ(t.numel(), numel);
// Streaming the tensor into `ss` must not throw.
61 ASSERT_NO_THROW(ss << t << std::endl);
// A tensor built from *s is compared against a {0}-shaped tensor; the step
// between these two statements is not visible in this excerpt.
64 auto t2 = ones(*s, T);
66 require_equal_size_dim(t2, ones({0}, T));
// unsqueeze(0) must add exactly one dimension.
69 ASSERT_EQ(t.unsqueeze(0).dim(), t.dim() + 1);
// In-place unsqueeze_(0) on a fresh tensor: the returned tensor has one more
// dimension than `t`.
73 auto t2 = ones(*s, T);
74 auto r = t2.unsqueeze_(0);
75 ASSERT_EQ(r.dim(), t.dim() + 1);
// squeeze(0): when `t` is 0-dim or its first size is 1, the rank drops by one
// (clamped at 0).
79 if (t.dim() == 0 || t.sizes()[0] == 1) {
80 ASSERT_EQ(t.squeeze(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
// Presumably the else branch (not visible): the rank is unchanged.
84 ASSERT_EQ(t.squeeze(0).dim(), t.dim());
// squeeze() without arguments: collect sizes into `size_without_ones`
// (presumably only those != 1; the filtering condition is not visible) and
// require the result to have exactly that shape.
89 std::vector<int64_t> size_without_ones;
90 for (
auto size : *s) {
92 size_without_ones.push_back(size);
95 auto result = t.squeeze();
96 require_equal_size_dim(result, ones(size_without_ones, T));
// In-place squeeze_(0) on a fresh tensor, with the same expectations as
// squeeze(0) above; the expected rank is derived from `t`.
101 auto t2 = ones(*s, T);
102 if (t2.dim() == 0 || t2.sizes()[0] == 1) {
103 ASSERT_EQ(t2.squeeze_(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
107 ASSERT_EQ(t2.squeeze_(0).dim(), t.dim());
// In-place squeeze_() without arguments: `t2` itself must end up with the
// `size_without_ones` shape.
113 auto t2 = ones(*s, T);
114 std::vector<int64_t> size_without_ones;
115 for (
auto size : *s) {
117 size_without_ones.push_back(size);
120 auto r = t2.squeeze_();
121 require_equal_size_dim(t2, ones(size_without_ones, T));
// sum(0): for a non-empty tensor the rank drops by one (clamped at 0).
125 if (t.numel() != 0) {
126 ASSERT_EQ(t.sum(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
// Presumably the empty-tensor branch: sum(0) equals a 0-dim zero tensor.
128 ASSERT_TRUE(t.sum(0).equal(at::zeros({}, T)));
// Reduction returning a pair (`ret` is defined on a line not shown): for a
// non-empty tensor both elements have the reduced rank.
132 if (t.numel() != 0) {
134 ASSERT_EQ(std::get<0>(ret).dim(), std::max<int64_t>(t.dim() - 1, 0));
135 ASSERT_EQ(std::get<1>(ret).dim(), std::max<int64_t>(t.dim() - 1, 0));
// Presumably the empty-tensor branch: min(0) must throw.
137 ASSERT_ANY_THROW(t.min(0));
// Indexing: t[0] on a non-empty tensor with at least one dimension has one
// dimension fewer.
141 if (t.dim() > 0 && t.numel() != 0) {
142 ASSERT_EQ(t[0].dim(), std::max<int64_t>(t.dim() - 1, 0));
// Presumably the else branch: indexing must throw.
144 ASSERT_ANY_THROW(t[0]);
// Looks like the (fn, catc, els) argument list of a TRY_CATCH_ELSE call whose
// opening is not visible: fill_ with t.sum(0), expecting t.dim() > 1 on the
// exception path and t.dim() <= 1 otherwise. TODO confirm.
149 t.fill_(t.sum(0)), ASSERT_GT(t.dim(), 1), ASSERT_LE(t.dim(), 1));
// Pairwise checks: every ordered (lhs, rhs) combination of the shapes in
// `sizes`.
152 for (
auto lhs_it = sizes.begin(); lhs_it != sizes.end(); ++lhs_it) {
153 for (
auto rhs_it = sizes.begin(); rhs_it != sizes.end(); ++rhs_it) {
// is_same_size: tensors built from different shapes must not compare as
// same-sized, in either direction.
156 auto lhs = ones(*lhs_it, T);
157 auto rhs = ones(*rhs_it, T);
158 if (*lhs_it != *rhs_it) {
159 ASSERT_FALSE(lhs.is_same_size(rhs));
160 ASSERT_FALSE(rhs.is_same_size(lhs));
// resize_: after resizing lhs to rhs's shape, both must agree in rank and
// sizes.
165 {
auto lhs = ones(*lhs_it, T);
166 auto rhs = ones(*rhs_it, T);
167 lhs.resize_(*rhs_it);
168 require_equal_size_dim(lhs, rhs);
// NOTE(review): the operation applied to lhs between construction and the
// check below is not visible in this excerpt.
172 auto lhs = ones(*lhs_it, T);
173 auto rhs = ones(*rhs_it, T);
175 require_equal_size_dim(lhs, rhs);
// NOTE(review): same situation; the operation under test is on a line that is
// not shown.
181 auto lhs = ones(*lhs_it, T);
182 auto rhs = ones(*rhs_it, T);
184 require_equal_size_dim(lhs, rhs);
// set_ from rhs's storage only: lhs must not end up 0-dimensional.
188 auto lhs = ones(*lhs_it, T);
189 auto rhs = ones(*rhs_it, T);
190 lhs.set_(rhs.storage());
193 ASSERT_NE(lhs.dim(), 0);
// set_ with rhs's storage, offset, sizes and strides: lhs must take rhs's
// shape.
197 auto lhs = ones(*lhs_it, T);
198 auto rhs = ones(*rhs_it, T);
199 lhs.set_(rhs.storage(), rhs.storage_offset(), rhs.sizes(), rhs.strides());
200 require_equal_size_dim(lhs, rhs);
// view: going by the macro's parameter names (fn, catc, els), an exception is
// expected only when the element counts differ; otherwise the counts match
// and the result has rhs's shape.
207 auto lhs = ones(*lhs_it, T);
208 auto rhs = ones(*rhs_it, T);
209 auto rhs_size = *rhs_it;
210 TRY_CATCH_ELSE(
auto result = lhs.view(rhs_size),
211 ASSERT_NE(lhs.numel(), rhs.numel()),
212 ASSERT_EQ(lhs.numel(), rhs.numel());
213 require_equal_size_dim(result, rhs););
// take: rhs is a zero-filled tensor converted to Long and passed to take().
// On the exception path lhs must be empty and rhs non-empty; otherwise the
// result has rhs's shape.
218 auto lhs = ones(*lhs_it, T);
219 auto rhs = zeros(*rhs_it, T).toType(ScalarType::Long);
220 TRY_CATCH_ELSE(
auto result = lhs.take(rhs), ASSERT_EQ(lhs.numel(), 0);
221 ASSERT_NE(rhs.numel(), 0),
222 require_equal_size_dim(result, rhs));
// ger: the exception-path condition is that either operand is empty or not
// 1-dimensional (the assertion wrapping this condition is not visible).
// Otherwise the result is compared with a {dim0, dim1} tensor, where a
// 0-dim operand counts as length 1.
227 auto lhs = ones(*lhs_it, T);
228 auto rhs = ones(*rhs_it, T);
229 TRY_CATCH_ELSE(
auto result = lhs.ger(rhs),
231 (lhs.numel() == 0 || rhs.numel() == 0 ||
232 lhs.dim() != 1 || rhs.dim() != 1)),
234 int64_t dim0 = lhs.dim() == 0 ? 1 : lhs.size(0);
235 int64_t dim1 = rhs.dim() == 0 ? 1 : rhs.size(0);
236 require_equal_size_dim(
237 result, at::empty({dim0, dim1}, result.options()));
// expand: should_expand(lhs_size, rhs_size) predicts whether expanding lhs to
// rhs's shape succeeds; on success the result has rhs's shape.
243 auto lhs = ones(*lhs_it, T);
244 auto lhs_size = *lhs_it;
245 auto rhs = ones(*rhs_it, T);
246 auto rhs_size = *rhs_it;
247 bool should_pass = should_expand(lhs_size, rhs_size);
248 TRY_CATCH_ELSE(
auto result = lhs.expand(rhs_size),
249 ASSERT_FALSE(should_pass),
250 ASSERT_TRUE(should_pass);
251 require_equal_size_dim(result, rhs););
// In-place add_: the prediction uses the reversed argument order
// (should_expand(rhs_size, lhs_size)); on success lhs keeps its original
// shape.
257 bool should_pass_inplace = should_expand(rhs_size, lhs_size);
258 TRY_CATCH_ELSE(lhs.add_(rhs),
259 ASSERT_FALSE(should_pass_inplace),
260 ASSERT_TRUE(should_pass_inplace);
261 require_equal_size_dim(lhs, ones(*lhs_it, T)););
// gtest cases of the TestScalarTensor suite, one named for CPU and one for
// CUDA. NOTE(review): their bodies are not visible in this excerpt.
268 TEST(TestScalarTensor, TestScalarTensorCPU) {
273 TEST(TestScalarTensor, TestScalarTensorCUDA) {
constexpr bool equals(ArrayRef RHS) const
equals - Check for element-wise equality.
constexpr size_t size() const
size - Get the array size.
Flush-To-Zero and Denormals-Are-Zero mode.