1 #include <gtest/gtest.h> 6 #define _USE_MATH_DEFINES 10 #include <ATen/Dispatch.h> 15 constexpr
auto Float = ScalarType::Float;
17 template<
typename scalar_type>
22 ss <<
"hello, dispatch: " << a.type().toString() << s <<
"\n";
23 auto data = (scalar_type*)a.data_ptr();
32 void test_overflow() {
34 ASSERT_EQ(s1.toFloat(),
static_cast<float>(M_PI));
38 ASSERT_EQ(s1.toFloat(), 100000.0);
39 ASSERT_EQ(s1.toInt(), 100000);
41 ASSERT_THROW(s1.toHalf(), std::domain_error);
44 ASSERT_TRUE(std::isnan(s1.toFloat()));
45 ASSERT_THROW(s1.toInt(), std::domain_error);
48 ASSERT_TRUE(std::isinf(s1.toFloat()));
49 ASSERT_THROW(s1.toInt(), std::domain_error);
52 TEST(TestScalar, TestScalar) {
57 Half h = bar.toHalf();
59 cout <<
"H2: " << h2.toDouble() <<
" " << what.toFloat() <<
" " 60 << bar.toDouble() <<
" " << what.isIntegral() <<
"\n";
61 Generator& gen = at::globalContext().defaultGenerator(at::kCPU);
62 ASSERT_NO_THROW(gen.seed());
63 auto&&
C = at::globalContext();
65 auto t2 = zeros({4, 4}, at::kCUDA);
68 auto t = ones({4, 4});
70 auto wha2 = zeros({4, 4}).add(t).sum();
71 ASSERT_EQ(wha2.item<
double>(), 16.0);
73 ASSERT_EQ(t.sizes()[0], 4);
74 ASSERT_EQ(t.sizes()[1], 4);
75 ASSERT_EQ(t.strides()[0], 4);
76 ASSERT_EQ(t.strides()[1], 1);
79 Tensor x = randn({1, 10}, options);
80 Tensor prev_h = randn({1, 20}, options);
81 Tensor W_h = randn({20, 20}, options);
82 Tensor W_x = randn({20, 10}, options);
83 Tensor i2h = at::mm(W_x, x.t());
84 Tensor h2h = at::mm(W_h, prev_h.t());
85 Tensor next_h = i2h.add(h2h);
86 next_h = next_h.tanh();
88 ASSERT_ANY_THROW(
Tensor{}.item());
93 auto r = CUDA(Float).copy(next_h);
94 ASSERT_TRUE(CPU(Float).copy(r).equal(next_h));
96 ASSERT_NO_THROW(randn({10, 10, 2}, options));
99 ASSERT_EQ(scalar_to_tensor(bar).scalar_type(), kDouble);
100 ASSERT_EQ(scalar_to_tensor(what).scalar_type(), kLong);
101 ASSERT_EQ(scalar_to_tensor(ones({}).item()).scalar_type(), kDouble);
103 if (x.scalar_type() != ScalarType::Half) {
104 AT_DISPATCH_ALL_TYPES(x.scalar_type(),
"foo", [&] {
106 std::stringstream ss;
108 ss <<
"hello, dispatch" << x.type().toString() << s <<
"\n");
109 auto data = (scalar_t*)x.data_ptr();
116 auto x = ones({1, 2}, options);
117 ASSERT_ANY_THROW(x.item<
float>());
119 auto float_one = ones({}, options);
120 ASSERT_EQ(float_one.item<
float>(), 1);
121 ASSERT_EQ(float_one.item<int32_t>(), 1);
122 ASSERT_EQ(float_one.item<
at::Half>(), 1);
Scalar represents a 0-dimensional tensor which contains a single element.
Flush-To-Zero and Denormals-Are-Zero mode.