Caffe2 - C++ API
A deep learning, cross-platform ML framework
LossCTC.cpp
1 #include <ATen/ATen.h>
2 #include <ATen/NativeFunctions.h>
3 #include <ATen/Config.h>
4 #include <ATen/cuda/CUDAConfig.h>
5 #if AT_CUDNN_ENABLED()
6  #include <ATen/cudnn/Descriptors.h>
7 #endif
8 
9 
10 #if !AT_CUDNN_ENABLED()
11 
12 namespace at { namespace native {
13 
14 // See Note [ATen preprocessor philosophy]
15 
16 std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool deterministic, bool zero_infinity) {
17  AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support");
18 }
19 
20 }}
21 
22 #else // AT_CUDNN_ENABLED
23 
24 #include <ATen/cudnn/Descriptors.h>
25 #include <ATen/cudnn/Types.h>
26 #include <ATen/cudnn/Utils.h>
27 
28 #include <ATen/TensorUtils.h>
29 
30 namespace at { namespace native {
31 
32 namespace {
33 
34 } // namespace
35 
36 std::tuple<Tensor, Tensor> _cudnn_ctc_loss(const Tensor& log_probs_t, const Tensor& targets_t, IntArrayRef input_lengths_, IntArrayRef target_lengths_, int64_t BLANK, bool deterministic, bool zero_infinity) {
37  (void)zero_infinity; // only used for backward
38  CheckedFrom c = "cudnn_ctc_loss";
39  TensorArg log_probs { log_probs_t, "log_probs", 1 };
40  TensorArg targets { targets_t, "targets", 2 };
41  checkDim(c, log_probs, 3);
42  checkScalarType(c, log_probs, kFloat);
43  checkDim(c, targets, 1);
44  checkScalarType(c, targets, kInt);
45  checkContiguous(c, targets); // ?
46  checkBackend(c, {*log_probs}, Backend::CUDA);
47  checkBackend(c, {*targets}, Backend::CPU);
48  int64_t batch_size = log_probs->size(1);
49  AT_CHECK(input_lengths_.size() == batch_size, "input_lengths needs to have size to match batch_size");
50  AT_CHECK(target_lengths_.size() == batch_size, "target_lengths needs to have size to match batch_size");
51 
52  std::vector<int> input_lengths(input_lengths_.begin(), input_lengths_.end());
53  std::vector<int> target_lengths(target_lengths_.begin(), target_lengths_.end());
54 
55  setCuDNNStreamToCurrent();
56  AT_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
57  // checked in dispatch:
58  // assert other conditions for cudnnCTCLoss: all label lengths <= 256
59  // all input lengths = logprob.size(0)
60 
61  auto handle = getCudnnHandle();
62 
63  cudnnCTCLossAlgo_t algo = (deterministic ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC);
64 
65  Tensor probs = log_probs->softmax(2);
66  TensorDescriptor probs_desc{probs};
67  Tensor grad = at::empty_like(probs);
68  TensorDescriptor grad_desc{grad};
69 
70  CTCLossDescriptor ctc_loss_desc;
71  ctc_loss_desc.set(CUDNN_DATA_FLOAT);
72 
73  size_t workspace_size;
74  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(handle, probs_desc.desc(), grad_desc.desc(),
75  targets->data<int>(), target_lengths.data(), input_lengths.data(),
76  algo, ctc_loss_desc.desc(), &workspace_size));
77 
78 
79  Tensor workspace = at::empty(workspace_size, log_probs->options().dtype(kByte));
80  Tensor costs = at::empty({log_probs->size(1)}, log_probs->options());
81 
82  AT_CUDNN_CHECK(cudnnCTCLoss(handle, probs_desc.desc(), probs.data_ptr(),
83  targets->data<int>(), target_lengths.data(), input_lengths.data(),
84  costs.data_ptr(), grad_desc.desc(), grad.data_ptr(), algo,
85  ctc_loss_desc.desc(), workspace.data_ptr(), workspace_size));
86 
87  return std::make_tuple(costs, grad);
88 }
89 
90 
91 }} // namespace at::native
92 
93 #endif
Flush-To-Zero and Denormals-Are-Zero mode.