#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cudnn/Descriptors.h>

#if !AT_CUDNN_ENABLED()

namespace at {
namespace native {
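
// Stub for builds without cuDNN: fail loudly instead of silently falling
// back, so callers know the cuDNN CTC kernel is unavailable.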
std::tuple<Tensor, Tensor> _cudnn_ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support");
}

} // namespace native
} // namespace at

#else // AT_CUDNN_ENABLED

#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/TensorUtils.h>

namespace at {
namespace native {
std::tuple<Tensor, Tensor> _cudnn_ctc_loss(
    const Tensor& log_probs_t,
    const Tensor& targets_t,
    IntArrayRef input_lengths_,
    IntArrayRef target_lengths_,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
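  // Note: zero_infinity does not affect the forward computation here; it is
  // only relevant when gradients are computed in the backward pass.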
  CheckedFrom c = "cudnn_ctc_loss";
  TensorArg log_probs { log_probs_t, "log_probs", 1 };
  TensorArg targets { targets_t, "targets", 2 };
  checkDim(c, log_probs, 3);
  checkScalarType(c, log_probs, kFloat);
  checkDim(c, targets, 1);
  checkScalarType(c, targets, kInt);
  checkContiguous(c, targets);
  checkBackend(c, {*log_probs}, Backend::CUDA);
  checkBackend(c, {*targets}, Backend::CPU);
  int64_t batch_size = log_probs->size(1);
  AT_CHECK(input_lengths_.size() == batch_size,
           "input_lengths must have size equal to batch_size");
  AT_CHECK(target_lengths_.size() == batch_size,
           "target_lengths must have size equal to batch_size");
  std::vector<int> input_lengths(input_lengths_.begin(), input_lengths_.end());
  std::vector<int> target_lengths(target_lengths_.begin(), target_lengths_.end());

  setCuDNNStreamToCurrent();
  AT_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");

  auto handle = getCudnnHandle();
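
  // Choose between cuDNN's deterministic and (typically faster)
  // non-deterministic CTC algorithms; the latter does not guarantee
  // bitwise-reproducible gradients across runs.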
  cudnnCTCLossAlgo_t algo = deterministic
      ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
      : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
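
  // cuDNN wants probabilities, not log-probabilities. Softmax over the class
  // dimension recovers them: softmax is shift-invariant, so
  // softmax(log_softmax(x)) == softmax(x).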
  Tensor probs = log_probs->softmax(2);
  TensorDescriptor probs_desc{probs};
  Tensor grad = at::empty_like(probs);
  TensorDescriptor grad_desc{grad};

  CTCLossDescriptor ctc_loss_desc;
  ctc_loss_desc.set(CUDNN_DATA_FLOAT);
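
  // Query the scratch-space requirement for these exact shapes and lengths,
  // then allocate it as a byte tensor so it comes from (and returns to) the
  // caching CUDA allocator.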
  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(
      handle, probs_desc.desc(), grad_desc.desc(),
      targets->data<int>(), target_lengths.data(), input_lengths.data(),
      algo, ctc_loss_desc.desc(), &workspace_size));
  Tensor workspace = at::empty(workspace_size, log_probs->options().dtype(kByte));
  Tensor costs = at::empty({log_probs->size(1)}, log_probs->options());

  AT_CUDNN_CHECK(cudnnCTCLoss(
      handle, probs_desc.desc(), probs.data_ptr(),
      targets->data<int>(), target_lengths.data(), input_lengths.data(),
      costs.data_ptr(), grad_desc.desc(), grad.data_ptr(), algo,
      ctc_loss_desc.desc(), workspace.data_ptr(), workspace_size));
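
  // Hand back the per-sample losses along with the gradient tensor that the
  // backward pass will consume.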
  return std::make_tuple(costs, grad);
}

} // namespace native
} // namespace at

#endif // AT_CUDNN_ENABLED
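
// Usage sketch (illustrative only, not part of the original file): calling
// the kernel through the generated at::_cudnn_ctc_loss binding. Shapes and
// placements follow the checks above; the sizes here are made up.
//
//   at::Tensor log_probs =
//       at::randn({50, 16, 20}, at::kCUDA).log_softmax(2);  // (T, N, C)
//   at::Tensor targets =
//       at::randint(1, 20, {16 * 30}, at::kInt);            // CPU, no label 0
//   std::vector<int64_t> input_lengths(16, 50);             // all equal to T
//   std::vector<int64_t> target_lengths(16, 30);
//   auto loss_and_grad = at::_cudnn_ctc_loss(
//       log_probs, targets, input_lengths, target_lengths,
//       /*BLANK=*/0, /*deterministic=*/true, /*zero_infinity=*/false);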