2 #include <ATen/NativeFunctions.h> 3 #include <ATen/Config.h> 5 #if !AT_MKLDNN_ENABLED() 7 namespace at {
namespace native {
11 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
12 AT_ERROR(
"mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
17 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool bias_defined) {
18 AT_ERROR(
"mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
21 std::tuple<at::Tensor,at::Tensor> mkldnn_convolution_backward_weights(
23 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool bias_defined) {
24 AT_ERROR(
"mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
27 std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
29 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask) {
30 AT_ERROR(
"mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
35 #else // AT_MKLDNN_EBABLED 37 #include <ATen/mkldnn/Runtime.h> 41 namespace at {
namespace native {
// Dimension indices for the NCHW activation layout and OIHW weight layout
// used throughout this file.
constexpr int input_batch_size_dim = 0;       // N of input
constexpr int input_channels_dim = 1;         // C of input
constexpr int output_batch_size_dim = 0;      // N of output
constexpr int output_channels_dim = 1;        // C of output
constexpr int weight_output_channels_dim = 0; // O of weight
constexpr int weight_input_channels_dim = 1;  // I of weight

// Number of spatial dimensions handled below (2-D convolution: H, W, plus
// one extra slot; kept as in the original code).
constexpr int max_dim = 3;
53 static std::vector<int64_t> conv_output_size(
54 IntArrayRef input_size, IntArrayRef weight_size,
55 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
57 auto dim = input_size.size();
58 std::vector<int64_t> output_size(dim);
59 output_size[0] = input_size[input_batch_size_dim];
60 output_size[1] = weight_size[weight_output_channels_dim];
61 for (
size_t d = 2; d < dim; ++d) {
62 auto kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
63 output_size[d] = (input_size[d] + (2 * padding[d - 2])
64 - kernel) / stride[d - 2] + 1;
71 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
73 auto output = at::empty(conv_output_size(
74 input.sizes(), weight.sizes(), padding, stride, dilation, groups), input.
options());
76 auto cpu_engine = CpuEngine::Instance().get_engine();
80 int32_t n = input.size(0);
81 int32_t ic = input.size(1);
82 int32_t ih = input.size(2);
83 int32_t iw = input.size(3);
85 int32_t oc = output.size(1);
86 int32_t oh = output.size(2);
87 int32_t ow = output.size(3);
89 int32_t kh = weight.size(2);
90 int32_t kw = weight.size(3);
92 int32_t sh = stride[0];
93 int32_t sw = stride[1];
94 int32_t ph = padding[0];
95 int32_t pw = padding[1];
97 auto data_t = memory::data_type::f32;
98 auto format_any = memory::format::any;
99 auto format_nchw = memory::format::nchw;
100 auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
101 auto format_x = memory::format::x;
103 memory::dims input_tz = {n, ic, ih, iw};
104 memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
105 memory::dims bias_tz = {oc};
106 memory::dims output_tz = {n, oc, oh, ow};
107 memory::dims _stride = {sh, sw};
108 memory::dims _padding = {ph, pw};
110 auto input_md = memory::desc({input_tz}, data_t, format_any);
111 auto weight_md = memory::desc({weight_tz}, data_t, format_any);
112 auto bias_md = memory::desc({bias_tz}, data_t, format_any);
113 auto output_md = memory::desc({output_tz}, data_t, format_any);
115 std::shared_ptr<convolution_forward::desc> conv_forward_desc;
116 if (bias.defined()) {
117 conv_forward_desc.reset(
new convolution_forward::desc(prop_kind::forward,
118 convolution_direct, input_md, weight_md, bias_md, output_md,
119 _stride, _padding, _padding, padding_kind::zero));
121 conv_forward_desc.reset(
new convolution_forward::desc(prop_kind::forward,
122 convolution_direct, input_md, weight_md, output_md,
123 _stride, _padding, _padding, padding_kind::zero));
126 std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
127 conv_forward_pd.reset(
new convolution_forward::primitive_desc(
128 *conv_forward_desc, cpu_engine));
130 auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
132 auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
134 auto output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
137 std::vector<primitive> net;
139 auto input_pd = conv_forward_pd->src_primitive_desc();
140 auto input_memory = input_usr_memory;
141 if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
142 input_memory = memory(input_pd);
143 net.push_back(reorder(input_usr_memory, input_memory));
146 auto weight_pd = conv_forward_pd->weights_primitive_desc();
147 auto weight_memory = weight_usr_memory;
148 if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
149 weight_memory = memory(weight_pd);
150 net.push_back(reorder(weight_usr_memory, weight_memory));
153 auto output_pd = conv_forward_pd->dst_primitive_desc();
154 auto output_memory = output_usr_memory;
155 if (output_usr_memory.get_primitive_desc() != memory::primitive_desc(output_pd)) {
156 output_memory = memory(output_pd);
159 std::shared_ptr<convolution_forward> conv_forward;
160 std::shared_ptr<memory> bias_usr_memory;
161 if (bias.defined()) {
162 bias_usr_memory.reset(
new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
164 conv_forward.reset(
new convolution_forward(*conv_forward_pd, input_memory,
165 weight_memory, *bias_usr_memory, output_memory));
167 conv_forward.reset(
new convolution_forward(*conv_forward_pd, input_memory,
168 weight_memory, output_memory));
170 net.push_back(*conv_forward);
172 if (output_memory != output_usr_memory) {
173 net.push_back(reorder(output_memory, output_usr_memory));
176 Stream::Instance().get_stream().submit(net);
181 Tensor mkldnn_convolution_backward_input(
183 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool bias_defined)
185 auto grad_input = at::empty(input_size, grad_output.
options());
187 auto cpu_engine = CpuEngine::Instance().get_engine();
191 int32_t n = grad_input.size(0);
192 int32_t ic = grad_input.size(1);
193 int32_t ih = grad_input.size(2);
194 int32_t iw = grad_input.size(3);
196 int32_t oc = grad_output.size(1);
197 int32_t oh = grad_output.size(2);
198 int32_t ow = grad_output.size(3);
200 int32_t kh = weight.size(2);
201 int32_t kw = weight.size(3);
203 int32_t sh = stride[0];
204 int32_t sw = stride[1];
205 int32_t ph = padding[0];
206 int32_t pw = padding[1];
208 auto data_t = memory::data_type::f32;
209 auto format_any = memory::format::any;
210 auto format_nchw = memory::format::nchw;
211 auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
213 memory::dims input_tz = {n, ic, ih, iw};
214 memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
215 memory::dims bias_tz = {oc};
216 memory::dims output_tz = {n, oc, oh, ow};
217 memory::dims _stride = {sh, sw};
218 memory::dims _padding = {ph, pw};
220 auto input_md = memory::desc({input_tz}, data_t, format_any);
221 auto weight_md = memory::desc({weight_tz}, data_t, format_any);
222 auto bias_md = memory::desc({bias_tz}, data_t, format_any);
223 auto output_md = memory::desc({output_tz}, data_t, format_any);
226 std::shared_ptr<convolution_forward::desc> conv_forward_desc;
228 conv_forward_desc.reset(
new convolution_forward::desc(prop_kind::forward,
229 convolution_direct, input_md, weight_md, bias_md, output_md,
230 _stride, _padding, _padding, padding_kind::zero));
232 conv_forward_desc.reset(
new convolution_forward::desc(prop_kind::forward,
233 convolution_direct, input_md, weight_md, output_md,
234 _stride, _padding, _padding, padding_kind::zero));
237 std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
238 conv_forward_pd.reset(
new convolution_forward::primitive_desc(
239 *conv_forward_desc, cpu_engine));
241 std::shared_ptr<convolution_backward_data::desc> conv_backward_data_desc;
242 conv_backward_data_desc.reset(
new convolution_backward_data::desc(
243 convolution_direct, input_md, weight_md, output_md,
244 _stride, _padding, _padding, padding_kind::zero));
246 std::shared_ptr<convolution_backward_data::primitive_desc> conv_backward_data_pd;
247 conv_backward_data_pd.reset(
new convolution_backward_data::primitive_desc(
248 *conv_backward_data_desc, cpu_engine, *conv_forward_pd));
250 auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
251 grad_output.data_ptr());
252 auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
254 auto grad_input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
255 grad_input.data_ptr());
257 std::vector<primitive> net;
259 auto grad_output_pd = conv_backward_data_pd->diff_dst_primitive_desc();
260 auto grad_output_memory = grad_output_usr_memory;
261 if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
262 grad_output_memory = memory(grad_output_pd);
263 net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
266 auto weight_pd = conv_backward_data_pd->weights_primitive_desc();
267 auto weight_memory = weight_usr_memory;
268 if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
269 weight_memory = memory(weight_pd);
270 net.push_back(reorder(weight_usr_memory, weight_memory));
273 auto grad_input_pd = conv_backward_data_pd->diff_src_primitive_desc();
274 auto grad_input_memory = grad_input_usr_memory;
275 if (grad_input_memory.get_primitive_desc() != memory::primitive_desc(grad_input_pd)) {
276 grad_input_memory = memory(grad_input_pd);
279 std::shared_ptr<convolution_backward_data> conv_backward_data;
280 conv_backward_data.reset(
new convolution_backward_data(*conv_backward_data_pd,
281 grad_output_memory, weight_memory, grad_input_memory));
282 net.push_back(*conv_backward_data);
284 if (grad_input_memory != grad_input_usr_memory) {
285 net.push_back(reorder(grad_input_memory, grad_input_usr_memory));
288 Stream::Instance().get_stream().submit(net);
293 std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
295 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool bias_defined)
297 auto grad_weight = at::empty(weight_size, grad_output.
options());
301 grad_bias = at::empty({grad_output.size(1)}, grad_output.
options());
304 auto cpu_engine = CpuEngine::Instance().get_engine();
308 int32_t n = input.size(0);
309 int32_t ic = input.size(1);
310 int32_t ih = input.size(2);
311 int32_t iw = input.size(3);
313 int32_t oc = grad_output.size(1);
314 int32_t oh = grad_output.size(2);
315 int32_t ow = grad_output.size(3);
317 int32_t kh = grad_weight.size(2);
318 int32_t kw = grad_weight.size(3);
320 int32_t sh = stride[0];
321 int32_t sw = stride[1];
322 int32_t ph = padding[0];
323 int32_t pw = padding[1];
325 auto data_t = memory::data_type::f32;
326 auto format_any = memory::format::any;
327 auto format_nchw = memory::format::nchw;
328 auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
329 auto format_x = memory::format::x;
331 memory::dims input_tz = {n, ic, ih, iw};
332 memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
333 memory::dims bias_tz = {oc};
334 memory::dims output_tz = {n, oc, oh, ow};
335 memory::dims _stride = {sh, sw};
336 memory::dims _padding = {ph, pw};
338 memory::desc input_md({input_tz}, data_t, format_any);
339 memory::desc weight_md({weight_tz}, data_t, format_any);
340 memory::desc bias_md({bias_tz}, data_t, format_any);
341 memory::desc output_md({output_tz}, data_t, format_any);
344 std::shared_ptr<convolution_forward::desc> conv_forward_desc;
346 conv_forward_desc.reset(
new convolution_forward::desc(prop_kind::forward,
347 convolution_direct, input_md, weight_md, bias_md, output_md,
348 _stride, _padding, _padding, padding_kind::zero));
350 conv_forward_desc.reset(
new convolution_forward::desc(prop_kind::forward,
351 convolution_direct, input_md, weight_md, output_md,
352 _stride, _padding, _padding, padding_kind::zero));
355 std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
356 conv_forward_pd.reset(
new convolution_forward::primitive_desc(
357 *conv_forward_desc, cpu_engine));
359 std::shared_ptr<convolution_backward_weights::desc> conv_backward_weight_desc;
361 conv_backward_weight_desc.reset(
new convolution_backward_weights::desc(
362 convolution_direct, input_md, weight_md, bias_md, output_md,
363 _stride, _padding, _padding, padding_kind::zero));
365 conv_backward_weight_desc.reset(
new convolution_backward_weights::desc(
366 convolution_direct, input_md, weight_md, output_md,
367 _stride, _padding, _padding, padding_kind::zero));
370 std::shared_ptr<convolution_backward_weights::primitive_desc> conv_backward_weight_pd;
371 conv_backward_weight_pd.reset(
new convolution_backward_weights::primitive_desc(
372 *conv_backward_weight_desc, cpu_engine, *conv_forward_pd));
374 auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
376 auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
377 grad_output.data_ptr());
378 auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
379 grad_weight.data_ptr());
380 std::shared_ptr<memory> grad_bias_memory;
382 std::vector<primitive> net;
384 auto input_pd = conv_backward_weight_pd->src_primitive_desc();
385 auto input_memory = input_usr_memory;
386 if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
387 input_memory = memory(input_pd);
388 net.push_back(reorder(input_usr_memory, input_memory));
391 auto grad_output_pd = conv_backward_weight_pd->diff_dst_primitive_desc();
392 auto grad_output_memory = grad_output_usr_memory;
393 if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
394 grad_output_memory = memory(grad_output_pd);
395 net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
398 auto grad_weight_pd = conv_backward_weight_pd->diff_weights_primitive_desc();
399 auto grad_weight_memory = grad_weight_usr_memory;
400 if (grad_weight_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_weight_pd)) {
401 grad_weight_memory = memory(grad_weight_pd);
404 std::shared_ptr<convolution_backward_weights> conv_backward_weight;
406 grad_bias_memory.reset(
new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
407 grad_bias.data_ptr()));
408 conv_backward_weight.reset(
new convolution_backward_weights(*conv_backward_weight_pd,
409 input_memory, grad_output_memory, grad_weight_memory, *grad_bias_memory));
411 conv_backward_weight.reset(
new convolution_backward_weights(*conv_backward_weight_pd,
412 input_memory, grad_output_memory, grad_weight_memory));
415 net.push_back(*conv_backward_weight);
417 if (grad_weight_memory != grad_weight_usr_memory) {
418 net.push_back(reorder(grad_weight_memory, grad_weight_usr_memory));
421 Stream::Instance().get_stream().submit(net);
423 return std::tuple<at::Tensor, at::Tensor>{grad_weight, grad_bias};
426 std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
428 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
430 Tensor grad_output = grad_output_t.contiguous();
432 Tensor grad_input, grad_weight, grad_bias;
433 if (output_mask[0]) {
434 grad_input = at::mkldnn_convolution_backward_input(
435 input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
437 if (output_mask[1] || output_mask[2]) {
438 std::tie(grad_weight, grad_bias) = at::mkldnn_convolution_backward_weights(
439 weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
442 return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
// NOTE(review): the lines below are documentation residue fused into this file
// by an extraction/merge (they describe Tensor::options() and CPU FTZ/DAZ
// modes, and are not part of this translation unit). Kept as a comment so the
// file remains compilable; confirm and remove.
//   TensorOptions options() const
//   Returns the TensorOptions corresponding to this Tensor.
//   Flush-To-Zero and Denormals-Are-Zero mode.