1 #include "caffe2/operators/ctc_beam_search_decoder_op.h" 7 const float* getTensorDataPtr(
const Tensor& tensor,
int t,
int n) {
8 const auto dims = tensor.sizes();
9 CAFFE_ENFORCE_EQ(dims.size(), 3);
10 int offset = (t * dims[1] + n) * dims[2];
11 CAFFE_ENFORCE_LT(offset, tensor.numel());
12 return tensor.template data<float>() + offset;
// CPU implementation of the CTC beam search decoder.
//
// For every batch item i the code walks the time steps t, keeps per-prefix
// scores in Pb / Pnb, keeps the current candidate prefixes in A_prev, and at
// the end appends the top-ranked prefix to values_cache and stores its length
// in output_len_data[i]. After all items are processed, values_cache is
// copied into the VALUES output.
//
// NOTE(review): this listing appears to be missing lines (the embedded line
// numbers skip, and no closing braces are present). Visible symptoms:
//   - output_len is dereferenced but its declaration is not shown;
//   - l_plus is copied from l but never visibly extended;
//   - the iterator `it` in the beam-selection loop is never visibly advanced;
//   - no return statement is shown.
// Confirm against the complete file before relying on the comments below.
18 bool CTCBeamSearchDecoderOp<CPUContext>::RunOnDevice() {
20 auto& inputs = Input(INPUTS);
// The three input dimensions are read as
// [max_activation_length, batch_size, alphabet_size].
25 const auto inputs_dims = inputs.sizes();
26 int32_t max_activation_length = inputs_dims[0];
27 int32_t batch_size = inputs_dims[1];
28 int32_t alphabet_size = inputs_dims[2];
// Per-item sequence lengths are read only when exactly two inputs are given;
// otherwise the pointer is null and max_activation_length is used per item.
30 const int* seq_len_data =
31 (InputSize() == 2) ? Input(SEQ_LEN).data<
int>() :
nullptr;
// Collects the decoded labels of all batch items, in batch order, so the
// VALUES output can be sized once the total count is known.
33 vector<int32_t> values_cache;
// NOTE(review): the result of this Output(...) call is presumably assigned to
// `output_len` on a line not shown here - confirm.
35 Output(OUTPUT_LEN, vector<int64_t>{batch_size}, at::dtype<int>());
36 int* output_len_data = output_len->mutable_data<
int>();
38 for (int32_t i = 0; i < batch_size; ++i) {
39 const int32_t activation_length =
40 (seq_len_data) ? seq_len_data[i] : max_activation_length;
// Score -> prefix, ordered with std::greater so begin() is the highest score.
41 std::multimap<float, vector<int32_t>, std::greater<float>> A_next_inv;
// Pb and Pnb each hold one prefix -> score map per time step
// (activation_length + 1 entries).
// NOTE(review): presumably the "ends in blank" / "does not end in blank"
// probabilities of prefix beam search - confirm.
45 vector<std::map<vector<int32_t>,
float>> Pb(
46 activation_length + 1, std::map<vector<int32_t>,
float>());
47 vector<std::map<vector<int32_t>,
float>> Pnb(
48 activation_length + 1, std::map<vector<int32_t>,
float>());
// Step 0 is seeded with only the empty prefix: Pb = 1, Pnb = 0.
49 set<vector<int32_t>> A_prev;
50 Pb[0][vector<int32_t>()] = 1;
51 Pnb[0][vector<int32_t>()] = 0;
52 A_prev.insert(vector<int32_t>());
54 for (
int t = 0; t < activation_length; t++) {
// Scores for time step t of batch item i.
55 const float* ctc = getTensorDataPtr(inputs, t, i);
// Keep only the symbol indices whose score exceeds prune_threshold_.
57 vector<int32_t> pruned_alpha;
58 for (int32_t c = 0; c < alphabet_size; c++) {
59 if (ctc[c] > prune_threshold_) {
60 pruned_alpha.push_back(c);
// Nothing passed the threshold: fall back to the full alphabet
// 0 .. alphabet_size - 1.
65 if (pruned_alpha.size() == 0) {
66 pruned_alpha = vector<int32_t>(alphabet_size);
67 std::iota(pruned_alpha.begin(), pruned_alpha.end(), 0);
// Extend every surviving prefix l with every symbol c that was kept.
70 for (
auto const& l : A_prev) {
74 for (
auto const c : pruned_alpha) {
// NOTE(review): this update is presumably guarded by a check that c is the
// blank symbol (index 0, cf. ctc[0] further down) on a line not shown.
77 Pb[t + 1][l] += ctc[c] * (Pb[t][l] + Pnb[t][l]);
// NOTE(review): l_plus is presumably l with c appended; the append is not
// visible in this listing.
79 vector<int32_t> l_plus = vector<int32_t>(l);
// c repeats the last label of l.
// NOTE(review): the second Pnb[t + 1][l_plus] update below presumably sits in
// an else branch whose `else` line is not shown.
81 if (l.size() > 0 && c == l.back()) {
82 Pnb[t + 1][l_plus] += ctc[c] * Pb[t][l];
83 Pnb[t + 1][l] += ctc[c] * Pnb[t][l];
85 Pnb[t + 1][l_plus] += ctc[c] * (Pb[t][l] + Pnb[t][l]);
// l_plus was not among the previous candidates: also add the scores it
// carried at step t.
88 if (A_prev.find(l_plus) == A_prev.end()) {
89 Pb[t + 1][l_plus] += ctc[0] * (Pb[t][l_plus] + Pnb[t][l_plus]);
90 Pnb[t + 1][l_plus] += ctc[c] * Pnb[t][l_plus];
// A_next[prefix] = Pb[t + 1][prefix] + Pnb[t + 1][prefix].
96 std::map<vector<int32_t>,
float> A_next(Pb[t + 1]);
97 for (
auto& it : Pnb[t + 1]) {
98 A_next[it.first] += it.second;
// Re-key by score so the entries are ordered from highest to lowest.
101 for (
auto& it : A_next) {
102 A_next_inv.insert({it.second, it.first});
// Up to beam_width_ entries from the front become the next candidates.
// NOTE(review): the loop exit on end() and the iterator increment are
// presumably on lines not shown here.
106 auto it = A_next_inv.begin();
107 for (
int j = 0; j < beam_width_; j++) {
108 if (it == A_next_inv.end()) {
111 A_prev.insert(it->second);
// The decoded result is the first (highest-score) entry of A_next_inv, or an
// empty sequence when the multimap is empty.
116 vector<int32_t> decoded =
117 (A_next_inv.empty()) ? vector<int32_t>() : A_next_inv.begin()->second;
119 output_len_data[i] = decoded.size();
120 values_cache.insert(values_cache.end(), decoded.begin(), decoded.end());
// Size the VALUES output to the total number of decoded labels and copy the
// cache into it element by element.
123 int32_t cache_size = values_cache.size();
124 auto* values = Output(VALUES, vector<int64_t>{cache_size}, at::dtype<int>());
125 int* values_data = values->mutable_data<
int>();
126 for (
int i = 0; i < values_cache.size(); ++i) {
127 values_data[i] = values_cache.at(i);
129 values_cache.clear();
// Registers CTCBeamSearchDecoderOp<CPUContext> as the CPU implementation of
// the operator named CTCBeamSearchDecoder, declares the operator schema, and
// marks the operator with SHOULD_NOT_DO_GRADIENT.
//
// NOTE(review): the schema builder calls that own the string literals below
// are not visible in this listing; only the literals themselves and
// .InheritOnnxSchema() are shown.
// NOTE(review): the input description says "logits (before softmax
// application)", but RunOnDevice compares the values against
// prune_threshold_ and multiplies them together directly - confirm whether
// the input is expected to hold probabilities.
// NOTE(review): "max_time" in the seq_len description presumably refers to
// max_activation_length - confirm.
134 REGISTER_CPU_OPERATOR(CTCBeamSearchDecoder, CTCBeamSearchDecoderOp<CPUContext>);
135 OPERATOR_SCHEMA(CTCBeamSearchDecoder)
139 "Prefix beam search decoder for connectionist temporal classification.")
142 "Maximum number of candidates to carry over to next activation step.")
145 "Probability threshold below which outputs are ignored.")
149 "3D float Tensor sized [max_activation_length, batch_size, alphabet_size] " 150 "of network logits (before softmax application).")
154 "(optional) 1D int vector containing sequence lengths, " 155 "having size [batch_size] " 156 "seq_len will be set to max_time if not provided.")
160 "Output_len matrix size (batch_size). " 161 "Each index stores final output length of its corresponding batch item.")
165 "Values vector, size (total_decoded_outputs). " 166 "The flattened vector of final output sequences, in batch order.")
167 .InheritOnnxSchema();
168 SHOULD_NOT_DO_GRADIENT(CTCBeamSearchDecoder);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...