1 #include <ATen/miopen/Descriptors.h> 4 namespace at {
namespace native {
8 inline miopenDataType_t getDataType(
const at::Tensor& t) {
9 auto scalar_type = t.scalar_type();
10 if (scalar_type == at::kFloat) {
12 }
else if (scalar_type == at::kHalf) {
15 throw std::runtime_error(
"TensorDescriptor only supports float and half tensors");
21 void TensorDescriptor::set(
const at::Tensor &t,
size_t pad) {
22 set(getDataType(t), t.sizes(), t.strides(), pad);
// Maximum tensor rank handled by the descriptor helpers in this file. It is
// used as the bound of fixed-size local arrays, so it must be a compile-time
// constant (a plain `static int` would turn those arrays into non-standard
// variable-length arrays).
static constexpr int MIOPEN_DIM_MAX = 4;
27 void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides,
size_t pad) {
28 size_t dim = t_sizes.size();
29 if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
31 #define STR(X) _STR(X) 32 throw std::runtime_error(
"MIOpen supports only up to " STR(MIOPEN_DIM_MAX)
" dimensions");
35 int size[MIOPEN_DIM_MAX];
36 int stride[MIOPEN_DIM_MAX];
37 for (
size_t i = 0; i < dim; ++i) {
38 size[i] =
static_cast<int>(t_sizes[i]);
39 stride[i] =
static_cast<int>(t_strides[i]);
41 for (
size_t i = dim; i < pad; ++i) {
45 set(datatype,
static_cast<int>(std::max(dim, pad)), size, stride);
48 std::string miopenTypeToString(miopenDataType_t dtype) {
55 std::ostringstream oss;
56 oss <<
"(unknown data-type " <<
static_cast<int>(dtype) <<
")";
61 std::ostream& operator<<(std::ostream & out,
const TensorDescriptor& d) {
62 out <<
"TensorDescriptor " <<
static_cast<void*
>(d.desc()) <<
"\n";
64 int dimA[MIOPEN_DIM_MAX];
65 int strideA[MIOPEN_DIM_MAX];
66 miopenDataType_t dtype;
67 miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
68 out <<
" type = " << miopenTypeToString(dtype) <<
"\n";
69 out <<
" nbDims = " << nbDims <<
"\n";
72 for (
auto i : ArrayRef<int>{dimA,
static_cast<size_t>(nbDims)}) {
77 for (
auto i : ArrayRef<int>{strideA,
static_cast<size_t>(nbDims)}) {
84 void TensorDescriptor::print() { std::cout << *
this; }
86 void FilterDescriptor::set(
const at::Tensor &t, int64_t pad) {
87 auto dim = t.ndimension();
88 if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
90 #define STR(X) _STR(X) 91 throw std::runtime_error(
"MIOpen supports only up to " STR(MIOPEN_DIM_MAX)
" dimensions");
94 if (!t.is_contiguous()) {
95 throw std::runtime_error(
"MIOpen filters (a.k.a. weights) must be contiguous");
97 int size[MIOPEN_DIM_MAX];
98 int stride[MIOPEN_DIM_MAX];
99 for (
int i = 0; i < dim; ++i) {
100 size[i] = (int) t.size(i);
102 for (
int i = dim; i < pad; ++i) {
105 for (
int i = dim - 1; i >=0; --i) {
106 stride[i] = (i == dim - 1) ? 1 : stride[i+1] * size[i+1];
108 dim = std::max(dim, pad);
109 set(getDataType(t), (int) dim, size, stride);
// Flush-To-Zero and Denormals-Are-Zero mode.