#include <ATen/native/RNN.h>

#include <ATen/Config.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/Exception.h>

#include <sstream>

#if !AT_CUDNN_ENABLED()

namespace at { namespace native {
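// Fallback path: ATen was built without cuDNN, so every entry point below is a
// stub that only raises an error. The real implementations live in the #else
// branch of the AT_CUDNN_ENABLED() check further down in this file.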
Tensor _cudnn_rnn_flatten_weight(
    TensorList weight_arr, int64_t weight_stride0,
    int64_t input_size,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, bool fn_bidirectional) {
  AT_ERROR("_cudnn_rnn_flatten_weight: ATen not compiled with cuDNN support");
}
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
    const Tensor& input_r,
    TensorList weight, int64_t weight_stride0,
    const Tensor& weight_buf_r, const Tensor& hx, const Tensor& cx,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, double fn_dropout,
    bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
    const Tensor& fn_dropout_state) {
  AT_ERROR("_cudnn_rnn: ATen not compiled with cuDNN support");
}
std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
    const Tensor& input, TensorList weight, int64_t weight_stride0,
    const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output, const Tensor& grad_output_r,
    const Tensor& grad_hy_r, const Tensor& grad_cy_r,
    int64_t mode, int64_t hidden_size,
    int64_t num_layers, bool batch_first, double dropout,
    bool train, bool bidirectional, IntArrayRef batch_sizes,
    const Tensor& dropout_state, const Tensor& reserve,
    std::array<bool, 4> output_mask) {
  AT_ERROR("_cudnn_rnn_backward: ATen not compiled with cuDNN support");
}
Tensor _cudnn_init_dropout_state(
    double dropout, bool train, int64_t dropout_seed,
    const TensorOptions& options) {
  AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support");
}

}} // namespace at::native
#else // AT_CUDNN_ENABLED()

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

namespace at { namespace native {
struct DropoutDescriptorParams {
  bool train;
  double dropout;
  Tensor dropout_state;

  DropoutDescriptorParams() {}

  void set(bool train_, double dropout_, Tensor dropout_state_) {
    train = train_;
    dropout = dropout_;
    dropout_state = dropout_state_;
  }

  DropoutDescriptor descriptor(cudnnHandle_t handle) const {
    auto dropout_p = train ? dropout : 0;
    DropoutDescriptor dropout_desc;
    if (dropout_p == 0) {
      dropout_desc.set_no_dropout(handle);
    } else {
      dropout_desc.set(handle, dropout_p, dropout_state);
    }
    return dropout_desc;
  }
};
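// Note: descriptor() keys off `train`, not just `dropout`: in eval mode the
// effective dropout probability is forced to 0 and a cheap no-dropout
// descriptor is built, so the dropout state tensor is only touched when
// dropout can actually fire.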
struct RNNDescriptorParams {
  int64_t hidden_size;
  int64_t num_layers;
  cudnnDirectionMode_t bidirectional;
  cudnnRNNMode_t mode;
  cudnnDataType_t datatype;
  cudnnDataType_t input_datatype;
  cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
  cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;

  int64_t num_directions() const {
    return bidirectional ? 2 : 1;
  }

  void set_mode(int64_t fn_mode) {
    switch (fn_mode) {
      case CUDNN_RNN_RELU:
        mode = CUDNN_RNN_RELU;
        break;
      case CUDNN_RNN_TANH:
        mode = CUDNN_RNN_TANH;
        break;
      case CUDNN_LSTM:
        mode = CUDNN_LSTM;
        break;
      case CUDNN_GRU:
        mode = CUDNN_GRU;
        break;
      default: {
        std::ostringstream oss;
        oss << "unrecognized cuDNN RNN mode " << fn_mode;
        AT_ERROR(oss.str());
      }
    }
  }

  void set_bidirectional(bool fn_bidirectional) {
    bidirectional = fn_bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
  }

  void set_algo(cudnnRNNAlgo_t algo) {
    this->algo = algo;
  }

  void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional,
           cudnnDataType_t datatype, cudnnDataType_t input_datatype) {
    this->set_mode(mode);
    this->hidden_size = hidden_size;
    this->num_layers = num_layers;
    this->set_bidirectional(bidirectional);
    this->datatype = datatype;
    this->input_datatype = input_datatype;
  }

  RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
    RNNDescriptor rnn_desc;
    rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode,
                 bidirectional, mode, datatype, input_datatype, algo);
    return rnn_desc;
  }

  // Some uses of the RNN descriptor do not depend on dropout, so fake up a
  // no-dropout descriptor to get the initialization through.
  RNNDescriptor descriptor(cudnnHandle_t handle) const {
    DropoutDescriptor dropout_desc;
    dropout_desc.set_no_dropout(handle);
    return descriptor(handle, std::move(dropout_desc));
  }
};
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
  std::vector<TensorDescriptor> descriptors(batch_sizes.size());
  size_t i = 0;
  // To be mutated in the loop below.
  auto batch_tensor_size = tensor.sizes().vec();
  for (auto batch_size : batch_sizes) {
    batch_tensor_size[0] = batch_size;
    // NB: the cuDNN RNN API does not support 2d descriptors, so pad out to 3d.
    descriptors[i].set(getCudnnDataType(tensor), batch_tensor_size, tensor.strides(), 3);
    i++;
  }
  return descriptors;
}

std::vector<TensorDescriptor> rnn_descriptor(const Tensor& tensor, int64_t N) {
  std::vector<TensorDescriptor> descriptors(N);
  for (int64_t i = 0; i < N; i++) {
    descriptors[i].set(tensor, 5);
  }
  return descriptors;
}
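// The two helpers above mirror cuDNN's expectations for the time dimension:
// packed (variable-length) input gets one 3-d descriptor per time step, with
// the batch dimension shrunk to that step's batch size, while un-packed input
// reuses a single descriptor (padded to 5 dims) seq_length times.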
struct TensorDescriptorListParams {
  IntArrayRef batch_sizes;
  int64_t seq_length;
  int64_t mini_batch;
  int64_t input_size;
  // Only valid when !is_input_packed()
  int64_t batch_sizes_sum;

  bool is_input_packed() const {
    return batch_sizes.size() != 0;
  }

  void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
    batch_sizes = batch_sizes_;
    if (is_input_packed()) {
      seq_length = batch_sizes.size();
      mini_batch = batch_sizes[0];
      // NB: for packed input, the outer dimension is the total number of time
      // steps summed over the batch, not the mini-batch size.
      batch_sizes_sum = input_sizes[0];
      input_size = input_sizes[1];
    } else {
      if (batch_first) {
        seq_length = input_sizes[1];
        mini_batch = input_sizes[0];
      } else {
        seq_length = input_sizes[0];
        mini_batch = input_sizes[1];
      }
      input_size = input_sizes[2];
      batch_sizes_sum = -1; // something bogus in case we accidentally use it
    }
  }

  std::vector<TensorDescriptor> descriptors(Tensor x) const {
    auto is_input_packed = batch_sizes.size() != 0;
    if (is_input_packed) {
      return rnn_descriptor_sequence(x, batch_sizes);
    } else {
      return rnn_descriptor(x[0], seq_length);
    }
  }
};
// Everything bundled together.
struct RNNParams {
  DropoutDescriptorParams dropout;
  RNNDescriptorParams rnn;
  TensorDescriptorListParams tensors;
};
struct RNNDescriptors {
  RNNDescriptor rnn_desc;
  // NB: these vectors do not lay out the descriptor pointers the way cuDNN
  // wants them; see get_descs() below.
  std::vector<TensorDescriptor> x_descs;
  std::vector<TensorDescriptor> y_descs;
  TensorDescriptor hx_desc;
  TensorDescriptor hy_desc;
  TensorDescriptor cx_desc;
  TensorDescriptor cy_desc;

  RNNDescriptors(const RNNParams& fn, cudnnHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
    rnn_desc = fn.rnn.descriptor(handle, fn.dropout.descriptor(handle));
    x_descs = fn.tensors.descriptors(x);
    y_descs = fn.tensors.descriptors(y);
    hx_desc.set(hx, 5);
    hy_desc.set(hx, 5);
    if (cx.defined()) {
      cx_desc.set(cx, 5);
      cy_desc.set(cx, 5);
    }
  }

  // cuDNN expects a contiguous array of cudnnTensorDescriptor_t.
  std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDescriptor>& descs) {
    std::vector<cudnnTensorDescriptor_t> r;
    r.reserve(descs.size());
    for (auto& desc : descs) {
      r.emplace_back(desc.desc());
    }
    return r;
  }

  std::vector<cudnnTensorDescriptor_t> get_x_descs() {
    return get_descs(x_descs);
  }

  std::vector<cudnnTensorDescriptor_t> get_y_descs() {
    return get_descs(y_descs);
  }
};
int64_t get_num_weights(cudnnHandle_t handle, const RNNDescriptor& rnn_desc,
                        const TensorDescriptor& x_desc, cudnnDataType_t datatype) {
  size_t weight_size;
  AT_CUDNN_CHECK(cudnnGetRNNParamsSize(handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype));
  auto elem_size = dataSize(datatype);
  AT_ASSERTM(weight_size % elem_size == 0,
             "cudnnGetRNNParamsSize returned nonsensical weight_size");
  return weight_size / elem_size;
}
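// cudnnGetRNNParamsSize reports the size in bytes; get_num_weights converts it
// to an element count so the flat weight buffer can be allocated with
// at::empty / at::zeros in the parameters' own dtype.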
int64_t _num_linear_layers(cudnnRNNMode_t mode) {
  switch (mode) {
    case CUDNN_LSTM: return 8;
    case CUDNN_GRU: return 6;
    case CUDNN_RNN_RELU: return 2;
    case CUDNN_RNN_TANH: return 2;
    default: AT_ERROR("unknown cuDNN RNN mode ", mode);
  }
}
/*
  Returns weight and bias tensors for each layer of the RNN. These tensors are
  views into the flat weight_buf, starting at the offsets reported by the cuDNN
  "LinLayer" query functions.
*/
std::pair<std::vector<Tensor>, size_t> // stride0
get_parameters(
    cudnnHandle_t handle,
    const RNNDescriptorParams& rnn,
    const RNNDescriptor& rnn_desc,
    const TensorDescriptor& x_desc,
    const FilterDescriptor& w_desc,
    const Tensor& weight_buf) {
  auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams };
  std::vector<Tensor> params;
  int64_t num_linear_layers = _num_linear_layers(rnn.mode);
  int64_t num_layers = rnn.num_directions() * rnn.num_layers;
  size_t cur_offset = 0;
  size_t global_layer_params_count = 0;
  for (int64_t layer = 0; layer < num_layers; layer++) {
    size_t layer_params_count = 0;
    for (auto cudnn_method : cudnn_methods) {
      for (int64_t linear_id = 0; linear_id < num_linear_layers; linear_id++) {
        FilterDescriptor lin_layer_mat_desc;
        void* matrix_pointer;
        AT_CUDNN_CHECK(cudnn_method(
              handle,
              rnn_desc.desc(),
              layer,
              x_desc.desc(),
              w_desc.desc(),
              weight_buf.data_ptr(),
              linear_id,
              lin_layer_mat_desc.mut_desc(),
              &matrix_pointer));
        cudnnDataType_t data_type;
        cudnnTensorFormat_t format;
        int nb_dims;
        constexpr int min_dim = 3;
        Tensor filter_dim_a = at::empty(min_dim, at::initialTensorOptions().dtype(kInt));
        AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor(
              lin_layer_mat_desc.desc(),
              min_dim,
              &data_type,
              &format,
              &nb_dims,
              filter_dim_a.data<int>()));

        AT_ASSERTM(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim);
        filter_dim_a = filter_dim_a.slice(0, 0, nb_dims);
        auto elem_size = dataSize(getCudnnDataType(weight_buf));
        auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr();
        AT_ASSERTM(offset_bytes % elem_size == 0,
                   "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size);
        size_t offset = offset_bytes / elem_size;

        // For all RNN types cuDNN provides, the ih weights of all gates are the
        // same size and allocated in one contiguous chunk (likewise the hh
        // weights and the biases), so merge them into a single tensor each.
        int mat_numel = *filter_dim_a.prod(at::ScalarType::Int).data<int>();
        if (linear_id == 0 || linear_id == num_linear_layers / 2) {
          std::initializer_list<int64_t> size = {
            mat_numel * num_linear_layers / 2, 1};
          // Generate a new parameter tensor which is a view into weight_buf.
          Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
          params.emplace_back(std::move(param));
          layer_params_count++;
        } else {
          AT_ASSERTM(cur_offset == offset,
                     "cur_offset = ", cur_offset, "; offset = ", offset);
        }
        cur_offset = offset + mat_numel;
      }
    } // for cudnn_method
    if (layer == 0) {
      global_layer_params_count = layer_params_count;
    } else {
      AT_ASSERTM(global_layer_params_count == layer_params_count,
                 "global_layer_params_count = ", global_layer_params_count,
                 "; layer_params_count = ", layer_params_count);
    }
  } // for layer
  return std::make_pair(params, global_layer_params_count);
}
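// Layout recovered by get_parameters, e.g. for a single-layer LSTM
// (num_linear_layers == 8): cuDNN stores the ih matrices for all four gates
// contiguously, then the hh matrices, then the two bias chunks. Only
// linear_id 0 and num_linear_layers / 2 therefore materialize a tensor, each
// view spanning mat_numel * num_linear_layers / 2 elements; the remaining ids
// are only checked (cur_offset == offset) to confirm the layout is contiguous.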
// Lightweight version of get_parameters that only computes the expected
// parameter data pointers, without creating any views.
std::vector<void*> get_expected_data_ptrs(
    const Tensor& weight_buf, cudnnHandle_t handle, const RNNDescriptorParams& rnn,
    const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, cudnnDataType_t datatype) {
  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  int64_t num_linear_layers = _num_linear_layers(rnn.mode);
  int64_t num_dir_layers = rnn.num_directions() * rnn.num_layers;
  const auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams };
  std::vector<void*> data_ptrs;
  data_ptrs.reserve(num_dir_layers * 2 * 2);
  for (int64_t layer = 0; layer < num_dir_layers; layer++) {
    for (auto cudnn_method : cudnn_methods) {
      // cuDNN returns a separate pointer for every gate, but the gates are
      // represented here as a single merged tensor, so only two offsets per
      // query method matter.
      const std::array<int64_t, 2> linear_offsets = { 0, num_linear_layers / 2 };
      for (int64_t linear_id : linear_offsets) {
        FilterDescriptor lin_layer_mat_desc;
        void* matrix_pointer;
        AT_CUDNN_CHECK(cudnn_method(
              handle,
              rnn_desc.desc(),
              layer,
              x_desc.desc(),
              w_desc.desc(),
              weight_buf.data_ptr(),
              linear_id,
              lin_layer_mat_desc.mut_desc(),
              &matrix_pointer));
        data_ptrs.push_back(matrix_pointer);
      }
    }
  }
  return data_ptrs;
}
void _viewOrCopyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to, bool copy) {
  AT_ASSERTM(params_from.size(0) == params_to.size(0), "number of layers mismatch");
  for (size_t i = 0; i < params_from.size(0); i++) {
    auto layer_params_from = params_from[i];
    auto layer_params_to = params_to[i];
    // NOTE: these lists have all weights before all biases, so if the layer
    // doesn't use biases, iteration will terminate once layer_params_from ends
    // and the biases are ignored.
    for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
         a != layer_params_from.end() && b != layer_params_to.end();
         ++a, ++b) {
      auto param_from = *a, param_to = *b;
      AT_ASSERTM(param_from.type() == param_to.type(), "parameter types mismatch");
      if (copy) {
        param_to.copy_(param_from.view_as(param_to));
      } else {
        param_from.resize_as_(param_to);
      }
    }
  }
}

void _copyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  _viewOrCopyParams(params_from, params_to, true);
}

void _viewParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  _viewOrCopyParams(params_from, params_to, false);
}
std::vector<int64_t> _input_size(const TensorDescriptorListParams& tensors) {
  if (tensors.is_input_packed()) {
    return {tensors.batch_sizes_sum, tensors.input_size};
  } else {
    return {tensors.seq_length, tensors.mini_batch, tensors.input_size};
  }
}

std::vector<int64_t> _hidden_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
  return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size};
}

std::vector<int64_t> _output_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
  if (tensors.is_input_packed()) {
    return {tensors.batch_sizes_sum, rnn.hidden_size * rnn.num_directions()};
  } else {
    return {tensors.seq_length, tensors.mini_batch, rnn.hidden_size * rnn.num_directions()};
  }
}
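// Worked example (hypothetical shapes): for a 2-layer bidirectional RNN with
// hidden_size = 100 and an un-packed input of shape [seq_length, 32, input_size],
// _hidden_size(...) is {2 * 2, 32, 100} = {4, 32, 100} and _output_size(...) is
// {seq_length, 32, 200}, since the two directions are concatenated on the last
// dimension.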
cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input) {
#if CUDNN_VERSION < 7200 || CUDA_VERSION < 9010
  return CUDNN_RNN_ALGO_STANDARD;
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  const int64_t bsize = tensors.mini_batch;
  if (prop->major == 7 && prop->minor != 5 && getCudnnDataType(input) == CUDNN_DATA_HALF && !tensors.is_input_packed()) {
    if (rnn.num_layers == 1 && rnn.hidden_size <= 1024 && rnn.num_directions() == 1 &&
        rnn.hidden_size % 128 == 0 && tensors.input_size % 128 == 0) {
      if ((bsize % 16 == 0 && bsize != 80 && bsize != 112) || bsize == 8) {
        if ((tensors.seq_length >= 40 && bsize <= 128) ||
            (tensors.seq_length >= 20 && bsize <= 96) ||
            (tensors.seq_length >= 10 && bsize <= 32)) {
          return CUDNN_RNN_ALGO_PERSIST_STATIC;
        }
      }
    }
  }
  return CUDNN_RNN_ALGO_STANDARD;
#endif
}
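// In short: the persistent-static algorithm is only attempted on sm_70-class
// GPUs (major == 7, minor != 5) running un-packed fp16 input, and only for a
// single-layer, single-direction RNN whose hidden and input sizes are
// multiples of 128, subject to the batch-size / sequence-length thresholds
// above; everything else falls back to CUDNN_RNN_ALGO_STANDARD.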
cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
#if CUDNN_VERSION != 7103
  // cuDNN 7.1.3 requires the RNN descriptor's math type to match the
  // input/weight type, so only promote half to float on other versions.
  if (dtype == CUDNN_DATA_HALF) {
    return CUDNN_DATA_FLOAT;
  }
#endif
  return dtype;
}
// NB: does an in-place update into the TensorList.
Tensor _cudnn_rnn_flatten_weight(
    TensorList weight_arr, int64_t weight_stride0,
    int64_t input_size,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, bool fn_bidirectional) {

  AT_CHECK(weight_arr.size() > 0,
           "_cudnn_rnn_flatten_weight_: cannot flatten empty weight list");

  auto any_param = weight_arr[0];
  auto datatype = getCudnnDataType(any_param);

  RNNDescriptorParams rnn;
  rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);

  auto handle = getCudnnHandle();
  RNNDescriptor rnn_desc = rnn.descriptor(handle);

  TensorGeometry x_geom({1, input_size});
  TensorDescriptor x_desc;
  x_desc.set(getCudnnDataType(any_param), x_geom.sizes(), x_geom.strides(), 5);

  auto num_weights = get_num_weights(handle, rnn_desc, x_desc, datatype);
  auto weight_buf = at::zeros(num_weights, any_param.options());

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  // Slice off views into weight_buf.
  std::vector<Tensor> params_arr;
  size_t params_stride0;
  std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);

  MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
                    params{params_arr, params_stride0};

  // Copy the weights into the flat buffer...
  _copyParams(weight, params);

  // ...and repoint the original parameters at the buffer's storage.
  for (size_t i = 0; i < weight.size(0); i++) {
    for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin();
         orig_param_it != weight[i].end() && new_param_it != params[i].end();
         orig_param_it++, new_param_it++) {
      auto orig_param = *orig_param_it, new_param = *new_param_it;
      orig_param.set_(new_param.view_as(orig_param));
    }
  }

  return weight_buf;
}
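// After _cudnn_rnn_flatten_weight returns, each tensor in weight_arr has been
// set_() to a view into weight_buf, so the per-layer parameters and the flat
// buffer share storage; later forward calls can then hand cuDNN the flat
// buffer directly instead of re-copying the individual weights.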
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
    const Tensor& input_r,
    TensorList weight, int64_t weight_stride0,
    const Tensor& weight_buf_r, const Tensor& hx, const Tensor& cx,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, double fn_dropout,
    bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
    const Tensor& fn_dropout_state) {

  check_device(input_r, weight, {hx, cx});
  auto input = input_r;
  auto weight_buf = weight_buf_r;
  if (fn_dropout_state.defined()) {
    auto input_arg = TensorArg(input, "input", 1);
    auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15);
    checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
  }

  RNNParams fn;
  auto datatype = getCudnnDataType(input);
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
  fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  if (fn.rnn.mode != CUDNN_LSTM) {
    AT_CHECK(!cx.defined(),
             "rnn: illegal defined cx for non-LSTM RNN");
  }

  auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
  }

  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
  auto output_size = _output_size(fn.rnn, fn.tensors);

  AT_CHECK(hx.is_contiguous(),
           "rnn: hx is not contiguous");
  AT_CHECK(!cx.defined() || cx.is_contiguous(),
           "rnn: cx is not contiguous");

  auto x = input.contiguous();
  auto output = at::empty(output_size, input.options());
  auto hy = at::empty(hidden_size, hx.options());
  Tensor cy;
  if (cx.defined()) {
    cy = at::empty(hidden_size, cx.options());
  } else {
    cy = at::empty({0}, hx.options()); // NB: not allowed to return an undefined tensor
  }
  auto y = output;

  auto handle = getCudnnHandle();
  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
  fn.rnn.set_algo(algo);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  if (!weight_buf.defined()) {
    auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
    weight_buf = at::empty(num_weights, x.options());
    w_desc.set(weight_buf, 3);
    std::vector<Tensor> params;
    size_t params_stride0;
    std::tie(params, params_stride0) = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf);
    _copyParams(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
                MatrixRef<Tensor>{params, params_stride0});
  } else {
    w_desc.set(weight_buf, 3);
  }

  AT_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
           "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());

  size_t workspace_size;
  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();
  AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(),
      &workspace_size));
  Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));

  Tensor reserve;
  if (fn_train) {
    size_t reserve_size;
    AT_CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(
        handle,
        descs.rnn_desc.desc(),
        fn.tensors.seq_length,
        x_descs_arr.data(),
        &reserve_size));
    reserve = at::empty(reserve_size, input.options().dtype(kByte));
    AT_CUDNN_CHECK(cudnnRNNForwardTraining(
        handle,
        descs.rnn_desc.desc(),
        fn.tensors.seq_length,
        x_descs_arr.data(), x.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
        w_desc.desc(), weight_buf.data_ptr(),
        y_descs_arr.data(), y.data_ptr(),
        descs.hy_desc.desc(), hy.data_ptr(),
        descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
        workspace.data_ptr(), workspace.size(0),
        reserve.data_ptr(), reserve.size(0)));
  } else { // inference
    reserve = at::empty({0}, input.options().dtype(kByte));
    AT_CUDNN_CHECK(cudnnRNNForwardInference(
        handle,
        descs.rnn_desc.desc(),
        fn.tensors.seq_length,
        x_descs_arr.data(), x.data_ptr(),
        descs.hx_desc.desc(), hx.data_ptr(),
        descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
        w_desc.desc(), weight_buf.data_ptr(),
        y_descs_arr.data(), y.data_ptr(),
        descs.hy_desc.desc(), hy.data_ptr(),
        descs.cy_desc.desc(), cy.defined() ? cy.data_ptr() : nullptr,
        workspace.data_ptr(), workspace.size(0)));
  }

  if (batch_first && !is_input_packed) {
    output.transpose_(0, 1);
  }

  return std::make_tuple(output, hy, cy, reserve, weight_buf);
}
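// The tuple returned above is (output, hy, cy, reserve, weight_buf): `reserve`
// is only non-empty in training mode (cudnnRNNForwardTraining needs it again
// during the backward pass), and `cy` is an empty placeholder for non-LSTM
// modes where no cell state exists.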
std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
    const Tensor& input_r, const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output_r, const Tensor& grad_output_r,
    const Tensor& grad_hy, const Tensor& grad_cy,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, double fn_dropout,
    bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
    const Tensor& fn_dropout_state, const Tensor& fn_reserve,
    std::array<bool, 3> output_mask) {

  auto input = input_r;
  auto grad_output = grad_output_r;
  auto output = output_r;

  RNNParams fn;
  auto datatype = getCudnnDataType(input);
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
  fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  auto handle = getCudnnHandle();

  if (fn.rnn.mode != CUDNN_LSTM) {
    AT_CHECK(!cx.defined(),
             "rnn: illegal defined cx for non-LSTM RNN");
  }

  auto is_input_packed = fn_batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
    grad_output = grad_output.transpose(0, 1);
    output = output.transpose(0, 1);
  }

  auto input_size = _input_size(fn.tensors);
  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
  auto output_size = _output_size(fn.rnn, fn.tensors);

  AT_CHECK(hx.is_contiguous(),
           "rnn: hx is not contiguous");
  AT_CHECK(!cx.defined() || cx.is_contiguous(),
           "rnn: cx is not contiguous");

  auto x = input.contiguous();
  auto dy = grad_output.contiguous();
  auto y = output;
  auto w = weight_buf;
  auto dx = at::empty(input.sizes(), input.options());
  auto dhy = grad_hy.contiguous().view(hidden_size);
  auto dcy = grad_cy.defined() ? grad_cy.contiguous().view(hidden_size) : Tensor();
  auto dhx = at::empty(hidden_size, hx.options());
  AT_ASSERTM(cx.defined() || !output_mask[2],
             "illegally required grad of cx for non-LSTM RNN");
  auto dcx = cx.defined() ? at::empty(hidden_size, cx.options()) : Tensor();

  AT_CHECK(fn_train,
           "cudnn RNN backward can only be called in training mode");

  AT_CHECK(input.sizes().equals(input_size),
           "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
  AT_CHECK(output.sizes().equals(output_size),
           "Expected output size ", IntArrayRef{output_size}, ", got ", output.sizes());

  AT_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
           "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());
  AT_CHECK(!cx.defined() || cx.sizes().equals(hidden_size),
           "Expected cell size ", IntArrayRef{hidden_size}, ", got ", cx.sizes());
  AT_CHECK(!dhy.defined() || dhy.sizes().equals(hidden_size),
           "Expected d_hidden size ", IntArrayRef{hidden_size}, ", got ", dhy.sizes());
  AT_CHECK(!dcy.defined() || dcy.sizes().equals(hidden_size),
           "Expected d_cell size ", IntArrayRef{hidden_size}, ", got ", dcy.sizes());

  AT_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
           "Gradients aren't CUDA tensors");

  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
  fn.rnn.set_algo(algo);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  size_t workspace_size;
  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();
  AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(),
      &workspace_size));
  Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));

  AT_CUDNN_CHECK(cudnnRNNBackwardData(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      y_descs_arr.data(), y.data_ptr(),
      y_descs_arr.data(), dy.data_ptr(),
      descs.hy_desc.desc(), dhy.data_ptr(),
      descs.cy_desc.desc(), cx.defined() ? dcy.data_ptr() : nullptr,
      w_desc.desc(), w.data_ptr(),
      descs.hx_desc.desc(), hx.data_ptr(),
      descs.cx_desc.desc(), cx.defined() ? cx.data_ptr() : nullptr,
      x_descs_arr.data(), dx.data_ptr(),
      descs.hx_desc.desc(), dhx.data_ptr(),
      descs.cx_desc.desc(), cx.defined() ? dcx.data_ptr() : nullptr,
      workspace.data_ptr(), workspace.size(0),
      fn_reserve.data_ptr(), fn_reserve.size(0)));

  if (batch_first && !is_input_packed) {
    dx = dx.transpose_(0, 1);
  }

  return std::make_tuple(dx, dhx, dcx);
}
std::vector<Tensor> _cudnn_rnn_backward_weight(
    const Tensor& input_r, TensorList weight_arr, int64_t weight_stride0,
    const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output_r,
    int64_t fn_mode, int64_t fn_hidden_size,
    int64_t fn_num_layers, bool batch_first, double fn_dropout,
    bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes,
    const Tensor& fn_dropout_state, const Tensor& fn_reserve) {

  MatrixRef<Tensor> weight{ weight_arr, static_cast<size_t>(weight_stride0) };

  auto input = input_r;
  auto output = output_r;

  RNNParams fn;
  auto datatype = getCudnnDataType(input);
  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
  fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
  fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);

  auto handle = getCudnnHandle();

  if (fn.rnn.mode != CUDNN_LSTM) {
    AT_CHECK(!cx.defined(),
             "rnn: illegal defined cx for non-LSTM RNN");
  }

  auto is_input_packed = fn_batch_sizes.size() != 0;
  if (batch_first && !is_input_packed) {
    input = input.transpose(0, 1);
    output = output.transpose(0, 1);
  }

  auto input_size = _input_size(fn.tensors);
  auto hidden_size = _hidden_size(fn.rnn, fn.tensors);

  AT_CHECK(fn_train,
           "cudnn RNN backward can only be called in training mode");

  AT_CHECK(input.sizes().equals(input_size),
           "Expected input size ", IntArrayRef{input_size}, ", got ", input.sizes());
  AT_CHECK(!hx.defined() || hx.sizes().equals(hidden_size),
           "Expected hidden size ", IntArrayRef{hidden_size}, ", got ", hx.sizes());

  AT_CHECK(hx.is_contiguous(),
           "rnn: hx is not contiguous");
  AT_CHECK(!cx.defined() || cx.is_contiguous(),
           "rnn: cx is not contiguous");

  auto x = input.contiguous();
  const auto& y = output;
  auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());

  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
  fn.rnn.set_algo(algo);
  RNNDescriptors descs(fn, handle, x, y, hx, cx);

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  size_t workspace_size;
  auto x_descs_arr = descs.get_x_descs();
  auto y_descs_arr = descs.get_y_descs();
  AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(),
      &workspace_size));
  Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));

  AT_CUDNN_CHECK(cudnnRNNBackwardWeights(
      handle,
      descs.rnn_desc.desc(),
      fn.tensors.seq_length,
      x_descs_arr.data(), x.data_ptr(),
      descs.hx_desc.desc(), hx.data_ptr(),
      y_descs_arr.data(), y.data_ptr(),
      workspace.data_ptr(), workspace.size(0),
      w_desc.desc(), dw.data_ptr(),
      fn_reserve.data_ptr(), fn_reserve.size(0)));

  std::vector<Tensor> grad_params_arr;
  size_t grad_params_stride0;
  std::tie(grad_params_arr, grad_params_stride0) = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, dw);
  if (grad_params_stride0 == static_cast<size_t>(weight_stride0)) {
    _viewParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
                MatrixRef<Tensor>{weight_arr, static_cast<size_t>(weight_stride0)});
    return grad_params_arr;
  } else {
    std::vector<Tensor> grad_weight_arr;
    grad_weight_arr.reserve( weight.numel() );
    for (const auto& w : weight_arr) {
      grad_weight_arr.emplace_back(at::empty(w.sizes(), w.options()));
    }
    _copyParams(MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
                MatrixRef<Tensor>{grad_weight_arr, static_cast<size_t>(weight_stride0)});
    return grad_weight_arr;
  }
}
std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
    const Tensor& input, TensorList weight, int64_t weight_stride0,
    const Tensor& weight_buf, const Tensor& hx, const Tensor& cx,
    const Tensor& output, const Tensor& grad_output_r,
    const Tensor& grad_hy_r, const Tensor& grad_cy_r,
    int64_t mode, int64_t hidden_size,
    int64_t num_layers, bool batch_first, double dropout,
    bool train, bool bidirectional, IntArrayRef batch_sizes,
    const Tensor& dropout_state, const Tensor& reserve,
    std::array<bool, 4> output_mask) {

  auto grad_output = grad_output_r.defined() ? grad_output_r : at::zeros_like(output);
  auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx);
  auto grad_cy = cx.defined() ? (grad_cy_r.defined() ? grad_cy_r : at::zeros_like(cx)) : grad_cy_r;

  Tensor dx, dhx, dcx;
  std::tie(dx, dhx, dcx) = at::native::_cudnn_rnn_backward_input(input, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, {output_mask[0], output_mask[1], output_mask[2]});
  std::vector<Tensor> dw;
  if (output_mask[3]) {
    dw = at::native::_cudnn_rnn_backward_weight(input, weight, weight_stride0, weight_buf, hx, cx, output, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve);
  }
  return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>{dx, dhx, dcx, dw};
}
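// Note the asymmetry above: the input/hidden gradients are always computed,
// while the weight gradients are only computed when output_mask[3] is set;
// both helpers consume the same `reserve` buffer produced by the forward
// training pass.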
Tensor _cudnn_init_dropout_state(
    double dropout, bool train, int64_t dropout_seed,
    const TensorOptions& options) {
  auto handle = getCudnnHandle();
  DropoutDescriptor dropout_desc;
  auto dropout_p = train ? dropout : 0;
  dropout_desc.initialize_rng(handle, dropout_p, dropout_seed, options);
  return dropout_desc.state;
}
// Helpers for working with the different hidden types used by the generic RNN
// dispatch (a single Tensor for RNN/GRU, a (Tensor, Tensor) pair for LSTM).

std::tuple<Tensor, Tensor> unpack_hidden(const Tensor& hidden) {
  return std::make_tuple(hidden, at::Tensor{});
}

std::tuple<Tensor, Tensor> unpack_hidden(const std::tuple<Tensor, Tensor>& hidden) {
  return hidden;
}

template<typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
  static_assert(std::is_same<hidden_type, void>::value, "pack_hidden not implemented for this type");
  AT_ERROR("NOT IMPLEMENTED");
}

template<>
Tensor pack_hidden<Tensor>(const Tensor& hx, const Tensor& cx) {
  AT_ASSERT(cx.numel() == 0);
  return hx;
}

template<>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor& hx, const Tensor& cx) {
  return std::make_tuple(hx, cx);
}
struct DropoutState {
  // buffer and event are lazily instantiated the first time a dropout state is
  // actually needed (e.g. not for RNNs run in eval mode).
  at::Tensor buffer;
  c10::optional<cuda::CUDAEvent> event;
  std::mutex mutex;

  // Every use of a dropout state has to synchronize with its event, so that
  // all previous uses finish before this one starts; once done, the event is
  // recorded again for the next run.
  void lock() {
    mutex.lock();
    if (event) {
      event->block(cuda::getCurrentCUDAStream());
    }
  }

  void unlock() {
    if (event) {
      event->record();
    }
    mutex.unlock();
  }
};
DropoutState& get_dropout_state(double dropout_p, bool train, TensorOptions options) {
  // Each state is initialized lazily, so caching one per device (and per
  // Tensor/Variable kind) is cheap.
  static std::vector<DropoutState> ten_dropout_state_cache { static_cast<size_t>(cuda::getNumGPUs()) };
  static std::vector<DropoutState> var_dropout_state_cache { static_cast<size_t>(cuda::getNumGPUs()) };
  static std::mutex state_cache_mut;

  int device = cuda::current_device();
  std::unique_lock<std::mutex> lock {state_cache_mut};
  auto& state = options.is_variable() ? var_dropout_state_cache.at(device)
                                      : ten_dropout_state_cache.at(device);
  if (train && dropout_p > 0 && !state.buffer.defined()) {
    std::unique_lock<std::mutex> lock {state.mutex};
    int64_t seed = at::empty({}, at::kLong).random_().item<int64_t>();
    state.buffer = at::_cudnn_init_dropout_state(
      dropout_p, train, seed, options.dtype(at::kByte));
    // NB: CUDA binds the event to a device at creation time, so it can only be
    // initialized now that we know we are on the right device.
    state.event.emplace();
  }
  return state;
}
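// Callers do not use this mutex directly: they wrap the returned DropoutState
// in std::unique_lock<DropoutState> (see _cudnn_impl below), which calls
// lock()/unlock() and thereby serializes cuDNN launches that share the same
// per-device dropout buffer via the recorded CUDA event.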
Tensor try_get_weight_buf(
    const Tensor& input, TensorList parameters, bool has_biases,
    cudnnRNNMode_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional) {
  // Prepare all relevant descriptors.
  auto handle = getCudnnHandle();
  auto datatype = getCudnnDataType(input);

  RNNDescriptorParams rnn;
  rnn.set(mode, hidden_size, num_layers, bidirectional, promote_rnn_math_type(datatype), datatype);
  RNNDescriptor rnn_desc = rnn.descriptor(handle);

  TensorGeometry x_geom ({1, input.size(-1)});
  TensorDescriptor x_desc;
  x_desc.set(datatype, x_geom.sizes(), x_geom.strides(), 5);

  auto num_params = get_num_weights(handle, rnn_desc, x_desc, datatype);

  // Try to get parameter storage.
  auto & any_param = parameters.at(0);
  auto param_storage = any_param.storage();
  auto weight_buf = at::empty({0}, any_param.options()).set_(param_storage);
  if (weight_buf.size(0) < num_params) {
    return {};
  } else if (weight_buf.size(0) > num_params) {
    weight_buf = weight_buf.narrow(0, 0, num_params);
  }

  // Check that the parameters actually alias the expected offsets.
  auto expected_data_ptrs = get_expected_data_ptrs(
      weight_buf, handle, rnn, rnn_desc, x_desc, datatype);

  int64_t num_parameters = parameters.size();
  int64_t num_ptrs = expected_data_ptrs.size();
  AT_ASSERT(num_ptrs == (num_parameters * (has_biases ? 1 : 2)));
  AT_ASSERT(num_ptrs % (has_biases ? 4 : 2) == 0);
  for (int64_t param_i = 0, ptr_i = 0;
       ptr_i < num_ptrs;
       ptr_i += (has_biases ? 2 : 4), param_i += 2) {
    if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr()) return {};
    if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr()) return {};
  }
  if (!parameters[num_parameters - 1].is_contiguous()) return {};
  return weight_buf;
}
const char * WEIGHT_FORMAT_WARN = "RNN module weights are not part of single contiguous "
                                  "chunk of memory. This means they need to be compacted "
                                  "at every call, possibly greatly increasing memory usage. "
                                  "To compact weights again call flatten_parameters().";
template<typename hidden_type>
std::pair<Tensor, hidden_type> _cudnn_impl(
    const Tensor& input, const Tensor& _batch_sizes, const hidden_type& hidden,
    TensorList params, bool has_biases, cudnnRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  Tensor hx, cx;
  std::tie(hx, cx) = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  auto weight_buf = try_get_weight_buf(
      input, params, has_biases, mode, hidden_size, num_layers, bidirectional);
  if (!weight_buf.defined()) {
    AT_WARN(WEIGHT_FORMAT_WARN);
  }

  AT_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D");
  IntArrayRef batch_sizes { _batch_sizes.data<int64_t>(), static_cast<size_t>(_batch_sizes.size(0)) };

  auto & dropout_state = get_dropout_state(dropout_p, train, input.options());
  std::unique_lock<DropoutState> lock { dropout_state };
  // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
  auto cudnn_output = at::_cudnn_rnn(
      input, params, has_biases ? 4 : 2, weight_buf,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers,
      /*batch_first=*/false, dropout_p, train, bidirectional, batch_sizes, dropout_state.buffer);

  return {std::get<0>(cudnn_output),
          pack_hidden<hidden_type>(std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
}
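// The `has_biases ? 4 : 2` argument above is the per-layer parameter stride:
// with biases each layer contributes {w_ih, w_hh, b_ih, b_hh}, without biases
// just {w_ih, w_hh}. The same convention is assumed by _cudnn_rnn and by
// try_get_weight_buf when it validates the expected data pointers.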
template<typename hidden_type>
std::pair<Tensor, hidden_type> _cudnn_impl(
    const Tensor& input, const hidden_type& hidden,
    TensorList params, bool has_biases, cudnnRNNMode_t mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  Tensor hx, cx;
  std::tie(hx, cx) = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  auto weight_buf = try_get_weight_buf(
      input, params, has_biases, mode, hidden_size, num_layers, bidirectional);
  if (!weight_buf.defined()) {
    AT_WARN(WEIGHT_FORMAT_WARN);
  }

  auto & dropout_state = get_dropout_state(dropout_p, train, input.options());
  std::unique_lock<DropoutState> lock { dropout_state };
  // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
  auto cudnn_output = at::_cudnn_rnn(
      input, params, has_biases ? 4 : 2, weight_buf,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, batch_first, dropout_p,
      train, bidirectional, /*batch_sizes=*/{}, dropout_state.buffer);

  return {std::get<0>(cudnn_output),
          pack_hidden<hidden_type>(std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
}
#define ONE_HIDDEN_RNN(NAME, MODE)                                            \
void NAME##_cudnn(Tensor& output, Tensor& hy,                                 \
      const Tensor& input, const Tensor& hx,                                  \
      TensorList params, bool has_biases,                                     \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { \
  std::tie(output, hy) = _cudnn_impl(input, hx, params, has_biases,           \
      MODE, num_layers, dropout_p, train, bidirectional, batch_first);        \
}                                                                             \
                                                                              \
void NAME##_packed_cudnn(Tensor& output, Tensor& hy,                          \
      const Tensor& data, const Tensor& batch_sizes, const Tensor& hx,        \
      TensorList params, bool has_biases,                                     \
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) { \
  std::tie(output, hy) = _cudnn_impl(data, batch_sizes, hx, params,           \
      has_biases, MODE, num_layers, dropout_p, train, bidirectional);         \
}                                                                             \
                                                                              \
REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn);                     \
REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn);

ONE_HIDDEN_RNN(gru, CUDNN_GRU)
ONE_HIDDEN_RNN(rnn_tanh, CUDNN_RNN_TANH)
ONE_HIDDEN_RNN(rnn_relu, CUDNN_RNN_RELU)
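// For reference, ONE_HIDDEN_RNN(gru, CUDNN_GRU) expands (roughly) to:
//
//   void gru_cudnn(Tensor& output, Tensor& hy,
//         const Tensor& input, const Tensor& hx,
//         TensorList params, bool has_biases,
//         int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
//     std::tie(output, hy) = _cudnn_impl(input, hx, params, has_biases,
//         CUDNN_GRU, num_layers, dropout_p, train, bidirectional, batch_first);
//   }
//
// plus the gru_packed_cudnn variant and the two REGISTER_CUDA_DISPATCH lines,
// so each single-hidden-state RNN flavor gets a plain and a packed-sequence
// entry point registered with the CUDA dispatcher.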
void lstm_cudnn(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& input, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto result = _cudnn_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
      CUDNN_LSTM, num_layers, dropout_p, train, bidirectional, batch_first);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}

void lstm_packed_cudnn(Tensor& output, Tensor& hy, Tensor& cy,
      const Tensor& data, const Tensor& batch_sizes, TensorList hx,
      TensorList params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  auto result = _cudnn_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]),
      params, has_biases, CUDNN_LSTM, num_layers, dropout_p, train, bidirectional);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}
REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn);
REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn);

}} // namespace at::native

#endif // AT_CUDNN_ENABLED()