Caffe2 - C++ API
A deep learning, cross-platform ML framework
ctc_beam_search_decoder_op.cc
#include "caffe2/operators/ctc_beam_search_decoder_op.h"

namespace caffe2 {

namespace {

// Returns a pointer to the activation vector (over the alphabet) for time
// step t and batch item n of a [T, N, alphabet_size] tensor.
const float* getTensorDataPtr(const Tensor& tensor, int t, int n) {
  const auto dims = tensor.sizes();
  CAFFE_ENFORCE_EQ(dims.size(), 3);
  int offset = (t * dims[1] + n) * dims[2];
  CAFFE_ENFORCE_LT(offset, tensor.numel());
  return tensor.template data<float>() + offset;
}

} // namespace

template <>
bool CTCBeamSearchDecoderOp<CPUContext>::RunOnDevice() {
  // INPUTS shape: max_activation_length x batch_size x alphabet_size
  auto& inputs = Input(INPUTS);
  // OUTPUT_LEN shape: batch_size
  // VALUES shape: sum over all decoded lengths

  const auto inputs_dims = inputs.sizes();
  int32_t max_activation_length = inputs_dims[0];
  int32_t batch_size = inputs_dims[1];
  int32_t alphabet_size = inputs_dims[2];
  // [batch_size]
  const int* seq_len_data =
      (InputSize() == 2) ? Input(SEQ_LEN).data<int>() : nullptr;

  vector<int32_t> values_cache;
  auto* output_len =
      Output(OUTPUT_LEN, vector<int64_t>{batch_size}, at::dtype<int>());
  int* output_len_data = output_len->mutable_data<int>();

  for (int32_t i = 0; i < batch_size; ++i) {
    const int32_t activation_length =
        (seq_len_data) ? seq_len_data[i] : max_activation_length;
    std::multimap<float, vector<int32_t>, std::greater<float>> A_next_inv;
    // For a given time step, Pb maps prefixes to the probability of all
    // candidate sequences that end in a blank and Pnb maps prefixes to the
    // probability of all candidate sequences that don't end in a blank.
    vector<std::map<vector<int32_t>, float>> Pb(
        activation_length + 1, std::map<vector<int32_t>, float>());
    vector<std::map<vector<int32_t>, float>> Pnb(
        activation_length + 1, std::map<vector<int32_t>, float>());
    set<vector<int32_t>> A_prev;
    Pb[0][vector<int32_t>()] = 1;
    Pnb[0][vector<int32_t>()] = 0;
    A_prev.insert(vector<int32_t>());

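    // Each time step below first prunes the alphabet to characters whose
    // probability exceeds prune_threshold_, then extends every prefix in the
    // beam (A_prev) with the surviving characters, updates Pb/Pnb, and
    // finally keeps the beam_width_ most probable prefixes for the next step.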
    for (int t = 0; t < activation_length; t++) {
      const float* ctc = getTensorDataPtr(inputs, t, i);

      vector<int32_t> pruned_alpha;
      for (int32_t c = 0; c < alphabet_size; c++) {
        if (ctc[c] > prune_threshold_) {
          pruned_alpha.push_back(c);
        }
      }

      // If the pruned alphabet is empty, don't use pruning.
      if (pruned_alpha.size() == 0) {
        pruned_alpha = vector<int32_t>(alphabet_size);
        std::iota(pruned_alpha.begin(), pruned_alpha.end(), 0);
      }

      for (auto const& l : A_prev) {
        // We skip the code handling the end character from the article since
        // our system does not support an end character.

        for (auto const c : pruned_alpha) {
          // Assumption: blank character always mapped to index 0
          if (c == 0) {
            Pb[t + 1][l] += ctc[c] * (Pb[t][l] + Pnb[t][l]);
          } else {
            vector<int32_t> l_plus = vector<int32_t>(l);
            l_plus.push_back(c);
            if (l.size() > 0 && c == l.back()) {
              Pnb[t + 1][l_plus] += ctc[c] * Pb[t][l];
              Pnb[t + 1][l] += ctc[c] * Pnb[t][l];
            } else {
              Pnb[t + 1][l_plus] += ctc[c] * (Pb[t][l] + Pnb[t][l]);
            }

            if (A_prev.find(l_plus) == A_prev.end()) {
              Pb[t + 1][l_plus] += ctc[0] * (Pb[t][l_plus] + Pnb[t][l_plus]);
              Pnb[t + 1][l_plus] += ctc[c] * Pnb[t][l_plus];
            }
          }
        }
      }

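      // Merge the blank and non-blank scores of each candidate prefix and
      // rank the candidates by total probability, highest first.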
      std::map<vector<int32_t>, float> A_next(Pb[t + 1]);
      for (auto& it : Pnb[t + 1]) {
        A_next[it.first] += it.second;
      }
      A_next_inv.clear();
      for (auto& it : A_next) {
        A_next_inv.insert({it.second, it.first});
      }

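      // Keep only the beam_width_ highest-scoring prefixes as the beam for
      // the next time step.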
      A_prev.clear();
      auto it = A_next_inv.begin();
      for (int j = 0; j < beam_width_; j++) {
        if (it == A_next_inv.end()) {
          break;
        }
        A_prev.insert(it->second);
        it++;
      }
    }

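    // After the final time step, the highest-probability candidate is the
    // decoded sequence for this batch item.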
    vector<int32_t> decoded =
        (A_next_inv.empty()) ? vector<int32_t>() : A_next_inv.begin()->second;

    output_len_data[i] = decoded.size();
    values_cache.insert(values_cache.end(), decoded.begin(), decoded.end());
  }

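  // Copy the concatenated decoded sequences into the flat VALUES output.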
  int32_t cache_size = values_cache.size();
  auto* values = Output(VALUES, vector<int64_t>{cache_size}, at::dtype<int>());
  int* values_data = values->mutable_data<int>();
  for (int i = 0; i < cache_size; ++i) {
    values_data[i] = values_cache.at(i);
  }
  values_cache.clear();

  return true;
}

REGISTER_CPU_OPERATOR(CTCBeamSearchDecoder, CTCBeamSearchDecoderOp<CPUContext>);
OPERATOR_SCHEMA(CTCBeamSearchDecoder)
    .NumInputs(1, 2)
    .NumOutputs(2)
    .SetDoc(
        "Prefix beam search decoder for connectionist temporal classification.")
    .Arg(
        "beam_width",
        "Maximum number of candidates to carry over to the next activation step.")
    .Arg(
        "prune_threshold",
        "Probability threshold below which outputs are ignored.")
    .Input(
        0,
        "INPUTS",
        "3D float Tensor sized [max_activation_length, batch_size, alphabet_size] "
        "of network logits (before softmax application).")
    .Input(
        1,
        "SEQ_LEN",
        "(optional) 1D int vector containing sequence lengths, "
        "having size [batch_size]; "
        "seq_len will be set to max_activation_length if not provided.")
    .Output(
        0,
        "OUTPUT_LEN",
        "Output_len vector, size (batch_size). "
        "Each index stores the final output length of its corresponding batch item.")
    .Output(
        1,
        "VALUES",
        "Values vector, size (total_decoded_outputs). "
        "The flattened vector of final output sequences, in batch order.")
    .InheritOnnxSchema();
SHOULD_NOT_DO_GRADIENT(CTCBeamSearchDecoder);

} // namespace caffe2
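
For reference, here is a minimal, hypothetical sketch of how this operator could be constructed and run from C++ through a Caffe2 workspace. It is not part of ctc_beam_search_decoder_op.cc; the blob names ("inputs", "output_len", "values"), the tensor sizes, and the RunDecoderExample helper are illustrative assumptions.

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

// Hypothetical usage sketch, not part of the operator implementation above.
void RunDecoderExample() {
  caffe2::Workspace ws;

  // INPUTS: [max_activation_length, batch_size, alphabet_size] float tensor.
  auto* blob = ws.CreateBlob("inputs");
  auto* activations = caffe2::BlobGetMutableTensor(blob, caffe2::CPU);
  activations->Resize(10, 1, 5);
  float* act_data = activations->mutable_data<float>();
  for (int64_t i = 0; i < activations->numel(); ++i) {
    act_data[i] = 0.2f;  // dummy uniform probabilities over a 5-symbol alphabet
  }

  // Build the operator definition: one input, two outputs, beam_width argument.
  caffe2::OperatorDef def;
  def.set_type("CTCBeamSearchDecoder");
  def.add_input("inputs");
  def.add_output("output_len");
  def.add_output("values");
  auto* beam_width = def.add_arg();
  beam_width->set_name("beam_width");
  beam_width->set_i(16);

  // Instantiate and run the operator against the workspace.
  auto op = caffe2::CreateOperator(def, &ws);
  op->Run();
}

After Run() returns, the "output_len" and "values" blobs in the workspace hold the per-item decoded lengths and the flattened decoded sequences described by the schema above.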