Caffe2 - C++ API
A deep learning, cross platform ML framework
Descriptors.h
1 #pragma once
2 
3 #include <ATen/miopen/Exceptions.h>
4 
5 #include <ATen/miopen/miopen-wrapper.h>
6 #include <ATen/ATen.h>
7 #include <ATen/TensorUtils.h>
8 
9 namespace at { namespace native {
10 
11 inline int dataSize(miopenDataType_t dataType)
12 {
13  switch (dataType) {
14  case miopenHalf: return 2;
15  case miopenFloat: return 4;
16  default: return 8;
17  }
18 }
19 
// Rewrites `stride` in place: for every dimension whose size is exactly 1,
// its stride becomes the product of the sizes of all later dimensions
// (dimensions of size 1 contribute a factor of 1). Strides of dimensions whose
// size is not 1 are left untouched. `size` and `stride` must each point to at
// least `dim` ints; a non-positive `dim` leaves `stride` unchanged.
static inline void fixSizeOneDimStride(int dim, const int *size, int *stride) {
  int64_t trailing_elems = 1;
  int d = dim;
  // Walk from the innermost (last) dimension towards the outermost.
  while (d-- > 0) {
    if (size[d] != 1) {
      trailing_elems *= size[d];
      continue;
    }
    stride[d] = static_cast<int>(trailing_elems);
  }
}
33 
34 template <typename T, miopenStatus_t (*dtor)(T*)>
35 struct DescriptorDeleter {
36  void operator()(T* x) {
37  if (x != nullptr) {
38  MIOPEN_CHECK(dtor(x));
39  }
40  }
41 };
42 
43 // A generic class for wrapping MIOpen descriptor types. All you need
44 // is to give the underlying type the Descriptor_t points to (usually,
45 // if it's miopenTensorDescriptor_t it points to miopenTensorStruct),
46 // the constructor and the destructor. Subclasses are responsible
47 // for defining a set() function to actually set the descriptor.
48 //
49 // Descriptors default construct to a nullptr, and have a descriptor
50 // initialized the first time you call set() or any other initializing
51 // function.
52 template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
53 class Descriptor
54 {
55 public:
56  // Use desc() to access the underlying descriptor pointer in
57  // a read-only fashion. Most client code should use this.
58  // If the descriptor was never initialized, this will return
59  // nullptr.
60  T* desc() const { return desc_.get(); }
61  T* desc() { return desc_.get(); }
62 
63  // Use mut_desc() to access the underlying desciptor pointer
64  // if you intend to modify what it points to (e.g., using
65  // miopenSetFooDescriptor). This will ensure that the descriptor
66  // is initialized. Code in this file will use this function.
67  T* mut_desc() { init(); return desc_.get(); }
68 protected:
69  void init() {
70  if (desc_ == nullptr) {
71  T* raw_desc;
72  MIOPEN_CHECK(ctor(&raw_desc));
73  desc_.reset(raw_desc);
74  }
75  }
76 private:
77  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
78 };
79 
80 class TensorDescriptor
81  : public Descriptor<miopenTensorDescriptor,
82  &miopenCreateTensorDescriptor,
83  &miopenDestroyTensorDescriptor>
84 {
85 public:
86  TensorDescriptor() {}
87  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
88  set(t, pad);
89  }
90 
91  void set(const at::Tensor &t, size_t pad = 0);
92  void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
93 
94  void print();
95 
96 private:
97  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
98  fixSizeOneDimStride(dim, size, stride);
99  MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
100  }
101 };
102 
103 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
104 
105 class FilterDescriptor
106  : public Descriptor<miopenTensorDescriptor,
107  &miopenCreateTensorDescriptor,
108  &miopenDestroyTensorDescriptor>
109 {
110 public:
111  void set(const at::Tensor &t, int64_t pad = 0);
112 
113 private:
114  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
115  MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
116  }
117 };
118 
119 struct ConvolutionDescriptor
120  : public Descriptor<miopenConvolutionDescriptor,
121  &miopenCreateConvolutionDescriptor,
122  &miopenDestroyConvolutionDescriptor>
123 {
124  void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups) {
125  MIOPEN_CHECK(miopenInitConvolutionDescriptor(mut_desc(), c_mode, pad[0], pad[1], stride[0], stride[1], upscale[0], upscale[1]));
126  MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
127  }
128 };
129 
130 union Constant
131 {
132  float f;
133  double d;
134  Constant(miopenDataType_t dataType, double value) {
135  if (dataType == miopenHalf || dataType == miopenFloat) {
136  f = static_cast<float>(value);
137  } else {
138  d = value;
139  }
140  }
141 };
142 
143 }} // namespace
Flush-To-Zero and Denormals-Are-Zero mode.