3 #include <ATen/miopen/Exceptions.h> 5 #include <ATen/miopen/miopen-wrapper.h> 7 #include <ATen/TensorUtils.h> 9 namespace at {
namespace native {
11 inline int dataSize(miopenDataType_t dataType)
14 case miopenHalf:
return 2;
15 case miopenFloat:
return 4;
22 static inline void fixSizeOneDimStride(
int dim,
const int *size,
int *stride) {
24 for(
int d = dim-1; d >= 0; d--)
34 template <
typename T, miopenStatus_t (*dtor)(T*)>
35 struct DescriptorDeleter {
36 void operator()(
T* x) {
38 MIOPEN_CHECK(dtor(x));
52 template <
typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
60 T* desc()
const {
return desc_.get(); }
61 T* desc() {
return desc_.get(); }
67 T* mut_desc() { init();
return desc_.get(); }
70 if (desc_ ==
nullptr) {
72 MIOPEN_CHECK(ctor(&raw_desc));
73 desc_.reset(raw_desc);
77 std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
80 class TensorDescriptor
81 :
public Descriptor<miopenTensorDescriptor,
82 &miopenCreateTensorDescriptor,
83 &miopenDestroyTensorDescriptor>
87 explicit TensorDescriptor(
const at::Tensor &t,
size_t pad = 0) {
92 void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides,
size_t pad = 0);
97 void set(miopenDataType_t dataType,
int dim,
int* size,
int* stride) {
98 fixSizeOneDimStride(dim, size, stride);
99 MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
103 std::ostream& operator<<(std::ostream & out,
const TensorDescriptor& d);
105 class FilterDescriptor
106 :
public Descriptor<miopenTensorDescriptor,
107 &miopenCreateTensorDescriptor,
108 &miopenDestroyTensorDescriptor>
111 void set(
const at::Tensor &t, int64_t pad = 0);
114 void set(miopenDataType_t dataType,
int dim,
int* size,
int* stride) {
115 MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
119 struct ConvolutionDescriptor
120 :
public Descriptor<miopenConvolutionDescriptor,
121 &miopenCreateConvolutionDescriptor,
122 &miopenDestroyConvolutionDescriptor>
124 void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode,
int dim,
int* pad,
int* stride,
int * upscale ,
int groups) {
125 MIOPEN_CHECK(miopenInitConvolutionDescriptor(mut_desc(), c_mode, pad[0], pad[1], stride[0], stride[1], upscale[0], upscale[1]));
126 MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
134 Constant(miopenDataType_t dataType,
double value) {
135 if (dataType == miopenHalf || dataType == miopenFloat) {
136 f =
static_cast<float>(value);
Flush-To-Zero and Denormals-Are-Zero mode.