41 #include <c10/util/C++17.h> 46 namespace c10 {
namespace guts {
49 template<
typename _Tp, std::
size_t _Nm>
51 using _Type = _Tp[_Nm];
53 static constexpr _Tp& _S_ref(
const _Type& __t, std::size_t __n) noexcept {
54 return const_cast<_Tp&
>(__t[__n]);
57 static constexpr _Tp* _S_ptr(
const _Type& __t) noexcept {
58 return const_cast<_Tp*
>(__t);
62 template<
typename _Tp>
64 struct _Type final {};
66 static constexpr _Tp& _S_ref(
const _Type& __t, std::size_t) noexcept {
70 static constexpr _Tp* _S_ptr(
const _Type&) noexcept {
75 [[noreturn]]
inline void __throw_out_of_range(std::string msg) {
76 throw std::out_of_range(std::move(msg));
80 template<
typename _Tp, std::
size_t _Nm>
83 using value_type = _Tp;
84 using pointer = value_type*;
85 using const_pointer =
const value_type*;
86 using reference = value_type&;
87 using const_reference =
const value_type&;
88 using iterator = value_type*;
89 using const_iterator =
const value_type*;
90 using size_type = std::size_t;
91 using difference_type = std::ptrdiff_t;
92 using reverse_iterator = std::reverse_iterator<iterator>;
93 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
98 typename _AT_Type::_Type _M_elems;
104 AT_CPP14_CONSTEXPR
void fill(
const value_type& __u)
105 { std::fill_n(begin(), size(), __u); }
107 AT_CPP14_CONSTEXPR
void swap(
array& __other)
108 { std::swap_ranges(begin(), end(), __other.begin()); }
111 AT_CPP14_CONSTEXPR iterator begin() noexcept
112 {
return iterator(data()); }
114 constexpr const_iterator begin()
const noexcept
115 {
return const_iterator(data()); }
117 AT_CPP14_CONSTEXPR iterator end() noexcept
118 {
return iterator(data() + _Nm); }
120 constexpr const_iterator end()
const noexcept
121 {
return const_iterator(data() + _Nm); }
123 AT_CPP14_CONSTEXPR reverse_iterator rbegin() noexcept
124 {
return reverse_iterator(end()); }
126 constexpr const_reverse_iterator rbegin()
const noexcept
127 {
return const_reverse_iterator(end()); }
129 AT_CPP14_CONSTEXPR reverse_iterator rend() noexcept
130 {
return reverse_iterator(begin()); }
132 constexpr const_reverse_iterator rend()
const noexcept
133 {
return const_reverse_iterator(begin()); }
135 constexpr const_iterator cbegin()
const noexcept
136 {
return const_iterator(data()); }
138 constexpr const_iterator cend()
const noexcept
139 {
return const_iterator(data() + _Nm); }
141 constexpr const_reverse_iterator crbegin()
const noexcept
142 {
return const_reverse_iterator(end()); }
144 constexpr const_reverse_iterator crend()
const noexcept
145 {
return const_reverse_iterator(begin()); }
148 constexpr size_type size()
const noexcept {
return _Nm; }
150 constexpr size_type max_size()
const noexcept {
return _Nm; }
152 constexpr
bool empty()
const noexcept {
return size() == 0; }
155 AT_CPP14_CONSTEXPR reference operator[](size_type __n) noexcept
156 {
return _AT_Type::_S_ref(_M_elems, __n); }
158 constexpr const_reference operator[](size_type __n)
const noexcept
159 {
return _AT_Type::_S_ref(_M_elems, __n); }
161 AT_CPP14_CONSTEXPR reference
at(size_type __n) {
163 detail::__throw_out_of_range(std::string() +
164 "array::at: __n (which is " + to_string(__n) +
") " +
165 ">= _Nm (which is " + to_string(_Nm) +
")");
167 return _AT_Type::_S_ref(_M_elems, __n);
170 constexpr const_reference
at(size_type __n)
const {
173 return __n < _Nm ? _AT_Type::_S_ref(_M_elems, __n)
174 : (detail::__throw_out_of_range(std::string() +
175 "array::at: __n (which is " + to_string(__n) +
") " +
176 ">= _Nm (which is " + to_string(_Nm) +
")"),
177 _AT_Type::_S_ref(_M_elems, 0));
180 AT_CPP14_CONSTEXPR reference front() noexcept
183 constexpr const_reference front()
const noexcept
184 {
return _AT_Type::_S_ref(_M_elems, 0); }
186 AT_CPP14_CONSTEXPR reference back() noexcept
187 {
return _Nm ? *(end() - 1) : *end(); }
189 constexpr const_reference back()
const noexcept
191 return _Nm ? _AT_Type::_S_ref(_M_elems, _Nm - 1)
192 : _AT_Type::_S_ref(_M_elems, 0);
195 AT_CPP14_CONSTEXPR pointer data() noexcept
196 {
return _AT_Type::_S_ptr(_M_elems); }
198 constexpr const_pointer data()
const noexcept
199 {
return _AT_Type::_S_ptr(_M_elems); }
202 #if defined(__cpp_deduction_guides) && __cpp_deduction_guides >= 201606 203 template<
typename _Tp,
typename... _Up>
204 array(_Tp, _Up...) ->
210 template<
class T,
size_t N>
211 constexpr
inline bool array_equals_(
const array<T, N>& lhs,
const array<T, N>& rhs,
size_t current_index) {
212 return (current_index == N)
214 : (lhs.at(current_index) == rhs.at(current_index) && array_equals_(lhs, rhs, current_index + 1));
216 template<
class T,
size_t N>
218 return (current_index == N)
220 : (lhs.at(current_index) < rhs.at(current_index) || array_less_(lhs, rhs, current_index + 1));
223 template<
typename _Tp, std::
size_t _Nm>
225 {
return detail::array_equals_(__one, __two, 0); }
227 template<
typename _Tp, std::
size_t _Nm>
229 {
return !(__one == __two); }
231 template<
typename _Tp, std::
size_t _Nm>
232 constexpr
inline bool operator<(const array<_Tp, _Nm>& __a,
const array<_Tp, _Nm>& __b)
233 {
return detail::array_less_(__a, __b, 0); }
235 template<
typename _Tp, std::
size_t _Nm>
237 {
return __two < __one; }
239 template<
typename _Tp, std::
size_t _Nm>
240 constexpr
inline bool operator<=(const array<_Tp, _Nm>& __one,
const array<_Tp, _Nm>& __two)
241 {
return !(__one > __two); }
243 template<
typename _Tp, std::
size_t _Nm>
245 {
return !(__one < __two); }
248 template<
typename _Tp, std::
size_t _Nm>
250 { __one.swap(__two); }
252 template<std::
size_t _Int,
typename _Tp, std::
size_t _Nm>
254 static_assert(_Int < _Nm,
"array index is within bounds");
258 template<std::
size_t _Int,
typename _Tp, std::
size_t _Nm>
261 static_assert(_Int < _Nm,
"array index is within bounds");
262 return guts::move(get<_Int>(__arr));
265 template<std::
size_t _Int,
typename _Tp, std::
size_t _Nm>
268 static_assert(_Int < _Nm,
"array index is within bounds");
280 template<
class T,
size_t N,
size_t... I>
282 static_assert(
sizeof...(I) == N-1,
"invariant");
283 return {{get<I+1>(arg)...}};
286 template<
class T,
size_t N>
288 static_assert(N > 0,
"Can only call tail() on an array with at least one element");
289 return detail::tail_(arg, guts::make_index_sequence<N-1>());
293 template<
class T,
size_t N,
size_t... I>
295 return {{guts::forward<T>(head), get<I>(tail)...}};
298 template<
class T,
size_t N>
300 return detail::prepend_(guts::forward<T>(head), tail, guts::make_index_sequence<N>());
311 template<
class T,
size_t N,
size_t... I>
313 return {{arr[I]...}};
317 template<
class T,
size_t N>
318 constexpr
array<T, N> to_array(
const T (&arr)[N]) {
319 return detail::to_array_(arr, guts::make_index_sequence<N>());
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Flush-To-Zero and Denormals-Are-Zero mode.