Mercurial > hgrepos > Python2 > PyMuPDF
comparison mupdf-source/thirdparty/tesseract/src/lstm/recodebeam.cpp @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: no version number in the expanded directory now.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 1:1d09e1dec1d9 | 2:b50eed0cc0ef |
|---|---|
| 1 /////////////////////////////////////////////////////////////////////// | |
| 2 // File: recodebeam.cpp | |
| 3 // Description: Beam search to decode from the re-encoded CJK as a sequence of | |
| 4 // smaller numbers in place of a single large code. | |
| 5 // Author: Ray Smith | |
| 6 // | |
| 7 // (C) Copyright 2015, Google Inc. | |
| 8 // Licensed under the Apache License, Version 2.0 (the "License"); | |
| 9 // you may not use this file except in compliance with the License. | |
| 10 // You may obtain a copy of the License at | |
| 11 // http://www.apache.org/licenses/LICENSE-2.0 | |
| 12 // Unless required by applicable law or agreed to in writing, software | |
| 13 // distributed under the License is distributed on an "AS IS" BASIS, | |
| 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 15 // See the License for the specific language governing permissions and | |
| 16 // limitations under the License. | |
| 17 // | |
| 18 /////////////////////////////////////////////////////////////////////// | |
| 19 | |
| 20 #include "recodebeam.h" | |
| 21 | |
| 22 #include "networkio.h" | |
| 23 #include "pageres.h" | |
| 24 #include "unicharcompress.h" | |
| 25 | |
| 26 #include <algorithm> // for std::reverse | |
| 27 | |
| 28 namespace tesseract { | |
| 29 | |
| 30 // The beam width at each code position. | |
| 31 const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = { | |
| 32 5, 10, 16, 16, 16, 16, 16, 16, 16, 16, | |
| 33 }; | |
| 34 | |
// Printable names used in debug output, indexed by the continuation value
// passed to tprintf in DebugBeams. Both the characters and the pointers are
// const so the file-local table cannot be rebound by accident.
static const char *const kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"};
| 36 | |
| 37 // Prints debug details of the node. | |
| 38 void RecodeNode::Print(int null_char, const UNICHARSET &unicharset, | |
| 39 int depth) const { | |
| 40 if (code == null_char) { | |
| 41 tprintf("null_char"); | |
| 42 } else { | |
| 43 tprintf("label=%d, uid=%d=%s", code, unichar_id, | |
| 44 unicharset.debug_str(unichar_id).c_str()); | |
| 45 } | |
| 46 tprintf(" score=%g, c=%g,%s%s%s perm=%d, hash=%" PRIx64, score, certainty, | |
| 47 start_of_dawg ? " DawgStart" : "", start_of_word ? " Start" : "", | |
| 48 end_of_word ? " End" : "", permuter, code_hash); | |
| 49 if (depth > 0 && prev != nullptr) { | |
| 50 tprintf(" prev:"); | |
| 51 prev->Print(null_char, unicharset, depth - 1); | |
| 52 } else { | |
| 53 tprintf("\n"); | |
| 54 } | |
| 55 } | |
| 56 | |
| 57 // Borrows the pointer, which is expected to survive until *this is deleted. | |
| 58 RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress &recoder, | |
| 59 int null_char, bool simple_text, Dict *dict) | |
| 60 : recoder_(recoder), | |
| 61 beam_size_(0), | |
| 62 top_code_(-1), | |
| 63 second_code_(-1), | |
| 64 dict_(dict), | |
| 65 space_delimited_(true), | |
| 66 is_simple_text_(simple_text), | |
| 67 null_char_(null_char) { | |
| 68 if (dict_ != nullptr && !dict_->IsSpaceDelimitedLang()) { | |
| 69 space_delimited_ = false; | |
| 70 } | |
| 71 } | |
| 72 | |
| 73 RecodeBeamSearch::~RecodeBeamSearch() { | |
| 74 for (auto data : beam_) { | |
| 75 delete data; | |
| 76 } | |
| 77 for (auto data : secondary_beam_) { | |
| 78 delete data; | |
| 79 } | |
| 80 } | |
| 81 | |
| 82 // Decodes the set of network outputs, storing the lattice internally. | |
| 83 void RecodeBeamSearch::Decode(const NetworkIO &output, double dict_ratio, | |
| 84 double cert_offset, double worst_dict_cert, | |
| 85 const UNICHARSET *charset, int lstm_choice_mode) { | |
| 86 beam_size_ = 0; | |
| 87 int width = output.Width(); | |
| 88 if (lstm_choice_mode) { | |
| 89 timesteps.clear(); | |
| 90 } | |
| 91 for (int t = 0; t < width; ++t) { | |
| 92 ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]); | |
| 93 DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert, | |
| 94 charset); | |
| 95 if (lstm_choice_mode) { | |
| 96 SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t); | |
| 97 } | |
| 98 } | |
| 99 } | |
| 100 void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float> &output, | |
| 101 double dict_ratio, double cert_offset, | |
| 102 double worst_dict_cert, | |
| 103 const UNICHARSET *charset) { | |
| 104 beam_size_ = 0; | |
| 105 int width = output.dim1(); | |
| 106 for (int t = 0; t < width; ++t) { | |
| 107 ComputeTopN(output[t], output.dim2(), kBeamWidths[0]); | |
| 108 DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset); | |
| 109 } | |
| 110 } | |
| 111 | |
| 112 void RecodeBeamSearch::DecodeSecondaryBeams( | |
| 113 const NetworkIO &output, double dict_ratio, double cert_offset, | |
| 114 double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode) { | |
| 115 for (auto data : secondary_beam_) { | |
| 116 delete data; | |
| 117 } | |
| 118 secondary_beam_.clear(); | |
| 119 if (character_boundaries_.size() < 2) { | |
| 120 return; | |
| 121 } | |
| 122 int width = output.Width(); | |
| 123 unsigned bucketNumber = 0; | |
| 124 for (int t = 0; t < width; ++t) { | |
| 125 while ((bucketNumber + 1) < character_boundaries_.size() && | |
| 126 t >= character_boundaries_[bucketNumber + 1]) { | |
| 127 ++bucketNumber; | |
| 128 } | |
| 129 ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t), | |
| 130 output.NumFeatures(), kBeamWidths[0]); | |
| 131 DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset, | |
| 132 worst_dict_cert, charset); | |
| 133 } | |
| 134 } | |
| 135 | |
| 136 void RecodeBeamSearch::SaveMostCertainChoices(const float *outputs, | |
| 137 int num_outputs, | |
| 138 const UNICHARSET *charset, | |
| 139 int xCoord) { | |
| 140 std::vector<std::pair<const char *, float>> choices; | |
| 141 for (int i = 0; i < num_outputs; ++i) { | |
| 142 if (outputs[i] >= 0.01f) { | |
| 143 const char *character; | |
| 144 if (i + 2 >= num_outputs) { | |
| 145 character = ""; | |
| 146 } else if (i > 0) { | |
| 147 character = charset->id_to_unichar_ext(i + 2); | |
| 148 } else { | |
| 149 character = charset->id_to_unichar_ext(i); | |
| 150 } | |
| 151 size_t pos = 0; | |
| 152 // order the possible choices within one timestep | |
| 153 // beginning with the most likely | |
| 154 while (choices.size() > pos && choices[pos].second > outputs[i]) { | |
| 155 pos++; | |
| 156 } | |
| 157 choices.insert(choices.begin() + pos, | |
| 158 std::pair<const char *, float>(character, outputs[i])); | |
| 159 } | |
| 160 } | |
| 161 timesteps.push_back(choices); | |
| 162 } | |
| 163 | |
| 164 void RecodeBeamSearch::segmentTimestepsByCharacters() { | |
| 165 for (unsigned i = 1; i < character_boundaries_.size(); ++i) { | |
| 166 std::vector<std::vector<std::pair<const char *, float>>> segment; | |
| 167 for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i]; | |
| 168 ++j) { | |
| 169 segment.push_back(timesteps[j]); | |
| 170 } | |
| 171 segmentedTimesteps.push_back(segment); | |
| 172 } | |
| 173 } | |
| 174 std::vector<std::vector<std::pair<const char *, float>>> | |
| 175 RecodeBeamSearch::combineSegmentedTimesteps( | |
| 176 std::vector<std::vector<std::vector<std::pair<const char *, float>>>> | |
| 177 *segmentedTimesteps) { | |
| 178 std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps; | |
| 179 for (auto &segmentedTimestep : *segmentedTimesteps) { | |
| 180 for (auto &j : segmentedTimestep) { | |
| 181 combined_timesteps.push_back(j); | |
| 182 } | |
| 183 } | |
| 184 return combined_timesteps; | |
| 185 } | |
| 186 | |
| 187 void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts, | |
| 188 std::vector<int> *ends, | |
| 189 std::vector<int> *char_bounds_, | |
| 190 int maxWidth) { | |
| 191 char_bounds_->push_back(0); | |
| 192 for (unsigned i = 0; i < ends->size(); ++i) { | |
| 193 int middle = ((*starts)[i + 1] - (*ends)[i]) / 2; | |
| 194 char_bounds_->push_back((*ends)[i] + middle); | |
| 195 } | |
| 196 char_bounds_->pop_back(); | |
| 197 char_bounds_->push_back(maxWidth); | |
| 198 } | |
| 199 | |
| 200 // Returns the best path as labels/scores/xcoords similar to simple CTC. | |
| 201 void RecodeBeamSearch::ExtractBestPathAsLabels( | |
| 202 std::vector<int> *labels, std::vector<int> *xcoords) const { | |
| 203 labels->clear(); | |
| 204 xcoords->clear(); | |
| 205 std::vector<const RecodeNode *> best_nodes; | |
| 206 ExtractBestPaths(&best_nodes, nullptr); | |
| 207 // Now just run CTC on the best nodes. | |
| 208 int t = 0; | |
| 209 int width = best_nodes.size(); | |
| 210 while (t < width) { | |
| 211 int label = best_nodes[t]->code; | |
| 212 if (label != null_char_) { | |
| 213 labels->push_back(label); | |
| 214 xcoords->push_back(t); | |
| 215 } | |
| 216 while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) { | |
| 217 } | |
| 218 } | |
| 219 xcoords->push_back(width); | |
| 220 } | |
| 221 | |
| 222 // Returns the best path as unichar-ids/certs/ratings/xcoords skipping | |
| 223 // duplicates, nulls and intermediate parts. | |
| 224 void RecodeBeamSearch::ExtractBestPathAsUnicharIds( | |
| 225 bool debug, const UNICHARSET *unicharset, std::vector<int> *unichar_ids, | |
| 226 std::vector<float> *certs, std::vector<float> *ratings, | |
| 227 std::vector<int> *xcoords) const { | |
| 228 std::vector<const RecodeNode *> best_nodes; | |
| 229 ExtractBestPaths(&best_nodes, nullptr); | |
| 230 ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords); | |
| 231 if (debug) { | |
| 232 DebugPath(unicharset, best_nodes); | |
| 233 DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings, | |
| 234 *xcoords); | |
| 235 } | |
| 236 } | |
| 237 | |
| 238 // Returns the best path as a set of WERD_RES. | |
| 239 void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box, | |
| 240 float scale_factor, bool debug, | |
| 241 const UNICHARSET *unicharset, | |
| 242 PointerVector<WERD_RES> *words, | |
| 243 int lstm_choice_mode) { | |
| 244 words->truncate(0); | |
| 245 std::vector<int> unichar_ids; | |
| 246 std::vector<float> certs; | |
| 247 std::vector<float> ratings; | |
| 248 std::vector<int> xcoords; | |
| 249 std::vector<const RecodeNode *> best_nodes; | |
| 250 std::vector<const RecodeNode *> second_nodes; | |
| 251 character_boundaries_.clear(); | |
| 252 ExtractBestPaths(&best_nodes, &second_nodes); | |
| 253 if (debug) { | |
| 254 DebugPath(unicharset, best_nodes); | |
| 255 ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings, | |
| 256 &xcoords); | |
| 257 tprintf("\nSecond choice path:\n"); | |
| 258 DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings, | |
| 259 xcoords); | |
| 260 } | |
| 261 // If lstm choice mode is required in granularity level 2, it stores the x | |
| 262 // Coordinates of every chosen character, to match the alternative choices to | |
| 263 // it. | |
| 264 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords, | |
| 265 &character_boundaries_); | |
| 266 int num_ids = unichar_ids.size(); | |
| 267 if (debug) { | |
| 268 DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings, | |
| 269 xcoords); | |
| 270 } | |
| 271 // Convert labels to unichar-ids. | |
| 272 int word_end = 0; | |
| 273 float prev_space_cert = 0.0f; | |
| 274 for (int word_start = 0; word_start < num_ids; word_start = word_end) { | |
| 275 for (word_end = word_start + 1; word_end < num_ids; ++word_end) { | |
| 276 // A word is terminated when a space character or start_of_word flag is | |
| 277 // hit. We also want to force a separate word for every non | |
| 278 // space-delimited character when not in a dictionary context. | |
| 279 if (unichar_ids[word_end] == UNICHAR_SPACE) { | |
| 280 break; | |
| 281 } | |
| 282 int index = xcoords[word_end]; | |
| 283 if (best_nodes[index]->start_of_word) { | |
| 284 break; | |
| 285 } | |
| 286 if (best_nodes[index]->permuter == TOP_CHOICE_PERM && | |
| 287 (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) || | |
| 288 !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1]))) { | |
| 289 break; | |
| 290 } | |
| 291 } | |
| 292 float space_cert = 0.0f; | |
| 293 if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) { | |
| 294 space_cert = certs[word_end]; | |
| 295 } | |
| 296 bool leading_space = | |
| 297 word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE; | |
| 298 // Create a WERD_RES for the output word. | |
| 299 WERD_RES *word_res = | |
| 300 InitializeWord(leading_space, line_box, word_start, word_end, | |
| 301 std::min(space_cert, prev_space_cert), unicharset, | |
| 302 xcoords, scale_factor); | |
| 303 for (int i = word_start; i < word_end; ++i) { | |
| 304 auto *choices = new BLOB_CHOICE_LIST; | |
| 305 BLOB_CHOICE_IT bc_it(choices); | |
| 306 auto *choice = new BLOB_CHOICE(unichar_ids[i], ratings[i], certs[i], -1, | |
| 307 1.0f, static_cast<float>(INT16_MAX), 0.0f, | |
| 308 BCC_STATIC_CLASSIFIER); | |
| 309 int col = i - word_start; | |
| 310 choice->set_matrix_cell(col, col); | |
| 311 bc_it.add_after_then_move(choice); | |
| 312 word_res->ratings->put(col, col, choices); | |
| 313 } | |
| 314 int index = xcoords[word_end - 1]; | |
| 315 word_res->FakeWordFromRatings(best_nodes[index]->permuter); | |
| 316 words->push_back(word_res); | |
| 317 prev_space_cert = space_cert; | |
| 318 if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) { | |
| 319 ++word_end; | |
| 320 } | |
| 321 } | |
| 322 } | |
| 323 | |
| 324 struct greater_than { | |
| 325 inline bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) const { | |
| 326 return (node1->score > node2->score); | |
| 327 } | |
| 328 }; | |
| 329 | |
| 330 void RecodeBeamSearch::PrintBeam2(bool uids, int num_outputs, | |
| 331 const UNICHARSET *charset, | |
| 332 bool secondary) const { | |
| 333 std::vector<std::vector<const RecodeNode *>> topology; | |
| 334 std::unordered_set<const RecodeNode *> visited; | |
| 335 const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_; | |
| 336 // create the topology | |
| 337 for (int step = beam.size() - 1; step >= 0; --step) { | |
| 338 std::vector<const RecodeNode *> layer; | |
| 339 topology.push_back(layer); | |
| 340 } | |
| 341 // fill the topology with depths first | |
| 342 for (int step = beam.size() - 1; step >= 0; --step) { | |
| 343 std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap(); | |
| 344 for (auto &&node : heaps) { | |
| 345 int backtracker = 0; | |
| 346 const RecodeNode *curr = &node.data(); | |
| 347 while (curr != nullptr && !visited.count(curr)) { | |
| 348 visited.insert(curr); | |
| 349 topology[step - backtracker].push_back(curr); | |
| 350 curr = curr->prev; | |
| 351 ++backtracker; | |
| 352 } | |
| 353 } | |
| 354 } | |
| 355 int ct = 0; | |
| 356 unsigned cb = 1; | |
| 357 for (const std::vector<const RecodeNode *> &layer : topology) { | |
| 358 if (cb >= character_boundaries_.size()) { | |
| 359 break; | |
| 360 } | |
| 361 if (ct == character_boundaries_[cb]) { | |
| 362 tprintf("***\n"); | |
| 363 ++cb; | |
| 364 } | |
| 365 for (const RecodeNode *node : layer) { | |
| 366 const char *code; | |
| 367 int intCode; | |
| 368 if (node->unichar_id != INVALID_UNICHAR_ID) { | |
| 369 code = charset->id_to_unichar(node->unichar_id); | |
| 370 intCode = node->unichar_id; | |
| 371 } else if (node->code == null_char_) { | |
| 372 intCode = 0; | |
| 373 code = " "; | |
| 374 } else { | |
| 375 intCode = 666; | |
| 376 code = "*"; | |
| 377 } | |
| 378 int intPrevCode = 0; | |
| 379 const char *prevCode; | |
| 380 float prevScore = 0; | |
| 381 if (node->prev != nullptr) { | |
| 382 prevScore = node->prev->score; | |
| 383 if (node->prev->unichar_id != INVALID_UNICHAR_ID) { | |
| 384 prevCode = charset->id_to_unichar(node->prev->unichar_id); | |
| 385 intPrevCode = node->prev->unichar_id; | |
| 386 } else if (node->code == null_char_) { | |
| 387 intPrevCode = 0; | |
| 388 prevCode = " "; | |
| 389 } else { | |
| 390 prevCode = "*"; | |
| 391 intPrevCode = 666; | |
| 392 } | |
| 393 } else { | |
| 394 prevCode = " "; | |
| 395 } | |
| 396 if (uids) { | |
| 397 tprintf("%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode, | |
| 398 node->score); | |
| 399 } else { | |
| 400 tprintf("%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score); | |
| 401 } | |
| 402 } | |
| 403 tprintf("-\n"); | |
| 404 ++ct; | |
| 405 } | |
| 406 tprintf("***\n"); | |
| 407 } | |
| 408 | |
| 409 void RecodeBeamSearch::extractSymbolChoices(const UNICHARSET *unicharset) { | |
| 410 if (character_boundaries_.size() < 2) { | |
| 411 return; | |
| 412 } | |
| 413 // For the first iteration the original beam is analyzed. After that a | |
| 414 // new beam is calculated based on the results from the original beam. | |
| 415 std::vector<RecodeBeam *> ¤tBeam = | |
| 416 secondary_beam_.empty() ? beam_ : secondary_beam_; | |
| 417 character_boundaries_[0] = 0; | |
| 418 for (unsigned j = 1; j < character_boundaries_.size(); ++j) { | |
| 419 std::vector<int> unichar_ids; | |
| 420 std::vector<float> certs; | |
| 421 std::vector<float> ratings; | |
| 422 std::vector<int> xcoords; | |
| 423 int backpath = character_boundaries_[j] - character_boundaries_[j - 1]; | |
| 424 std::vector<tesseract::RecodePair> &heaps = | |
| 425 currentBeam.at(character_boundaries_[j] - 1)->beams_->heap(); | |
| 426 std::vector<const RecodeNode *> best_nodes; | |
| 427 std::vector<const RecodeNode *> best; | |
| 428 // Scan the segmented node chain for valid unichar ids. | |
| 429 for (auto &&entry : heaps) { | |
| 430 bool validChar = false; | |
| 431 int backcounter = 0; | |
| 432 const RecodeNode *node = &entry.data(); | |
| 433 while (node != nullptr && backcounter < backpath) { | |
| 434 if (node->code != null_char_ && | |
| 435 node->unichar_id != INVALID_UNICHAR_ID) { | |
| 436 validChar = true; | |
| 437 break; | |
| 438 } | |
| 439 node = node->prev; | |
| 440 ++backcounter; | |
| 441 } | |
| 442 if (validChar) { | |
| 443 best.push_back(&entry.data()); | |
| 444 } | |
| 445 } | |
| 446 // find the best rated segmented node chain and extract the unichar id. | |
| 447 if (!best.empty()) { | |
| 448 std::sort(best.begin(), best.end(), greater_than()); | |
| 449 ExtractPath(best[0], &best_nodes, backpath); | |
| 450 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, | |
| 451 &xcoords); | |
| 452 } | |
| 453 if (!unichar_ids.empty()) { | |
| 454 int bestPos = 0; | |
| 455 for (unsigned i = 1; i < unichar_ids.size(); ++i) { | |
| 456 if (ratings[i] < ratings[bestPos]) { | |
| 457 bestPos = i; | |
| 458 } | |
| 459 } | |
| 460 #if 0 // TODO: bestCode is currently unused (see commit 2dd5d0d60). | |
| 461 int bestCode = -10; | |
| 462 for (auto &node : best_nodes) { | |
| 463 if (node->unichar_id == unichar_ids[bestPos]) { | |
| 464 bestCode = node->code; | |
| 465 } | |
| 466 } | |
| 467 #endif | |
| 468 // Exclude the best choice for the followup decoding. | |
| 469 std::unordered_set<int> excludeCodeList; | |
| 470 for (auto &best_node : best_nodes) { | |
| 471 if (best_node->code != null_char_) { | |
| 472 excludeCodeList.insert(best_node->code); | |
| 473 } | |
| 474 } | |
| 475 if (j - 1 < excludedUnichars.size()) { | |
| 476 for (auto elem : excludeCodeList) { | |
| 477 excludedUnichars[j - 1].insert(elem); | |
| 478 } | |
| 479 } else { | |
| 480 excludedUnichars.push_back(excludeCodeList); | |
| 481 } | |
| 482 // Save the best choice for the choice iterator. | |
| 483 if (j - 1 < ctc_choices.size()) { | |
| 484 int id = unichar_ids[bestPos]; | |
| 485 const char *result = unicharset->id_to_unichar_ext(id); | |
| 486 float rating = ratings[bestPos]; | |
| 487 ctc_choices[j - 1].push_back( | |
| 488 std::pair<const char *, float>(result, rating)); | |
| 489 } else { | |
| 490 std::vector<std::pair<const char *, float>> choice; | |
| 491 int id = unichar_ids[bestPos]; | |
| 492 const char *result = unicharset->id_to_unichar_ext(id); | |
| 493 float rating = ratings[bestPos]; | |
| 494 choice.emplace_back(result, rating); | |
| 495 ctc_choices.push_back(choice); | |
| 496 } | |
| 497 // fill the blank spot with an empty array | |
| 498 } else { | |
| 499 if (j - 1 >= excludedUnichars.size()) { | |
| 500 std::unordered_set<int> excludeCodeList; | |
| 501 excludedUnichars.push_back(excludeCodeList); | |
| 502 } | |
| 503 if (j - 1 >= ctc_choices.size()) { | |
| 504 std::vector<std::pair<const char *, float>> choice; | |
| 505 ctc_choices.push_back(choice); | |
| 506 } | |
| 507 } | |
| 508 } | |
| 509 for (auto data : secondary_beam_) { | |
| 510 delete data; | |
| 511 } | |
| 512 secondary_beam_.clear(); | |
| 513 } | |
| 514 | |
| 515 // Generates debug output of the content of the beams after a Decode. | |
| 516 void RecodeBeamSearch::DebugBeams(const UNICHARSET &unicharset) const { | |
| 517 for (int p = 0; p < beam_size_; ++p) { | |
| 518 for (int d = 0; d < 2; ++d) { | |
| 519 for (int c = 0; c < NC_COUNT; ++c) { | |
| 520 auto cont = static_cast<NodeContinuation>(c); | |
| 521 int index = BeamIndex(d, cont, 0); | |
| 522 if (beam_[p]->beams_[index].empty()) { | |
| 523 continue; | |
| 524 } | |
| 525 // Print all the best scoring nodes for each unichar found. | |
| 526 tprintf("Position %d: %s+%s beam\n", p, d ? "Dict" : "Non-Dict", | |
| 527 kNodeContNames[c]); | |
| 528 DebugBeamPos(unicharset, beam_[p]->beams_[index]); | |
| 529 } | |
| 530 } | |
| 531 } | |
| 532 } | |
| 533 | |
| 534 // Generates debug output of the content of a single beam position. | |
| 535 void RecodeBeamSearch::DebugBeamPos(const UNICHARSET &unicharset, | |
| 536 const RecodeHeap &heap) const { | |
| 537 std::vector<const RecodeNode *> unichar_bests(unicharset.size()); | |
| 538 const RecodeNode *null_best = nullptr; | |
| 539 int heap_size = heap.size(); | |
| 540 for (int i = 0; i < heap_size; ++i) { | |
| 541 const RecodeNode *node = &heap.get(i).data(); | |
| 542 if (node->unichar_id == INVALID_UNICHAR_ID) { | |
| 543 if (null_best == nullptr || null_best->score < node->score) { | |
| 544 null_best = node; | |
| 545 } | |
| 546 } else { | |
| 547 if (unichar_bests[node->unichar_id] == nullptr || | |
| 548 unichar_bests[node->unichar_id]->score < node->score) { | |
| 549 unichar_bests[node->unichar_id] = node; | |
| 550 } | |
| 551 } | |
| 552 } | |
| 553 for (auto &unichar_best : unichar_bests) { | |
| 554 if (unichar_best != nullptr) { | |
| 555 const RecodeNode &node = *unichar_best; | |
| 556 node.Print(null_char_, unicharset, 1); | |
| 557 } | |
| 558 } | |
| 559 if (null_best != nullptr) { | |
| 560 null_best->Print(null_char_, unicharset, 1); | |
| 561 } | |
| 562 } | |
| 563 | |
| 564 // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping | |
| 565 // duplicates, nulls and intermediate parts. | |
| 566 /* static */ | |
| 567 void RecodeBeamSearch::ExtractPathAsUnicharIds( | |
| 568 const std::vector<const RecodeNode *> &best_nodes, | |
| 569 std::vector<int> *unichar_ids, std::vector<float> *certs, | |
| 570 std::vector<float> *ratings, std::vector<int> *xcoords, | |
| 571 std::vector<int> *character_boundaries) { | |
| 572 unichar_ids->clear(); | |
| 573 certs->clear(); | |
| 574 ratings->clear(); | |
| 575 xcoords->clear(); | |
| 576 std::vector<int> starts; | |
| 577 std::vector<int> ends; | |
| 578 // Backtrack extracting only valid, non-duplicate unichar-ids. | |
| 579 int t = 0; | |
| 580 int width = best_nodes.size(); | |
| 581 while (t < width) { | |
| 582 double certainty = 0.0; | |
| 583 double rating = 0.0; | |
| 584 while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) { | |
| 585 double cert = best_nodes[t++]->certainty; | |
| 586 if (cert < certainty) { | |
| 587 certainty = cert; | |
| 588 } | |
| 589 rating -= cert; | |
| 590 } | |
| 591 starts.push_back(t); | |
| 592 if (t < width) { | |
| 593 int unichar_id = best_nodes[t]->unichar_id; | |
| 594 if (unichar_id == UNICHAR_SPACE && !certs->empty() && | |
| 595 best_nodes[t]->permuter != NO_PERM) { | |
| 596 // All the rating and certainty go on the previous character except | |
| 597 // for the space itself. | |
| 598 if (certainty < certs->back()) { | |
| 599 certs->back() = certainty; | |
| 600 } | |
| 601 ratings->back() += rating; | |
| 602 certainty = 0.0; | |
| 603 rating = 0.0; | |
| 604 } | |
| 605 unichar_ids->push_back(unichar_id); | |
| 606 xcoords->push_back(t); | |
| 607 do { | |
| 608 double cert = best_nodes[t++]->certainty; | |
| 609 // Special-case NO-PERM space to forget the certainty of the previous | |
| 610 // nulls. See long comment in ContinueContext. | |
| 611 if (cert < certainty || (unichar_id == UNICHAR_SPACE && | |
| 612 best_nodes[t - 1]->permuter == NO_PERM)) { | |
| 613 certainty = cert; | |
| 614 } | |
| 615 rating -= cert; | |
| 616 } while (t < width && best_nodes[t]->duplicate); | |
| 617 ends.push_back(t); | |
| 618 certs->push_back(certainty); | |
| 619 ratings->push_back(rating); | |
| 620 } else if (!certs->empty()) { | |
| 621 if (certainty < certs->back()) { | |
| 622 certs->back() = certainty; | |
| 623 } | |
| 624 ratings->back() += rating; | |
| 625 } | |
| 626 } | |
| 627 starts.push_back(width); | |
| 628 if (character_boundaries != nullptr) { | |
| 629 calculateCharBoundaries(&starts, &ends, character_boundaries, width); | |
| 630 } | |
| 631 xcoords->push_back(width); | |
| 632 } | |
| 633 | |
| 634 // Sets up a word with the ratings matrix and fake blobs with boxes in the | |
| 635 // right places. | |
| 636 WERD_RES *RecodeBeamSearch::InitializeWord(bool leading_space, | |
| 637 const TBOX &line_box, int word_start, | |
| 638 int word_end, float space_certainty, | |
| 639 const UNICHARSET *unicharset, | |
| 640 const std::vector<int> &xcoords, | |
| 641 float scale_factor) { | |
| 642 // Make a fake blob for each non-zero label. | |
| 643 C_BLOB_LIST blobs; | |
| 644 C_BLOB_IT b_it(&blobs); | |
| 645 for (int i = word_start; i < word_end; ++i) { | |
| 646 if (static_cast<unsigned>(i + 1) < character_boundaries_.size()) { | |
| 647 TBOX box(static_cast<int16_t>( | |
| 648 std::floor(character_boundaries_[i] * scale_factor)) + | |
| 649 line_box.left(), | |
| 650 line_box.bottom(), | |
| 651 static_cast<int16_t>( | |
| 652 std::ceil(character_boundaries_[i + 1] * scale_factor)) + | |
| 653 line_box.left(), | |
| 654 line_box.top()); | |
| 655 b_it.add_after_then_move(C_BLOB::FakeBlob(box)); | |
| 656 } | |
| 657 } | |
| 658 // Make a fake word from the blobs. | |
| 659 WERD *word = new WERD(&blobs, leading_space, nullptr); | |
| 660 // Make a WERD_RES from the word. | |
| 661 auto *word_res = new WERD_RES(word); | |
| 662 word_res->end = word_end - word_start + leading_space; | |
| 663 word_res->uch_set = unicharset; | |
| 664 word_res->combination = true; // Give it ownership of the word. | |
| 665 word_res->space_certainty = space_certainty; | |
| 666 word_res->ratings = new MATRIX(word_end - word_start, 1); | |
| 667 return word_res; | |
| 668 } | |
| 669 | |
| 670 // Fills top_n_flags_ with bools that are true iff the corresponding output | |
| 671 // is one of the top_n. | |
| 672 void RecodeBeamSearch::ComputeTopN(const float *outputs, int num_outputs, | |
| 673 int top_n) { | |
| 674 top_n_flags_.clear(); | |
| 675 top_n_flags_.resize(num_outputs, TN_ALSO_RAN); | |
| 676 top_code_ = -1; | |
| 677 second_code_ = -1; | |
| 678 top_heap_.clear(); | |
| 679 for (int i = 0; i < num_outputs; ++i) { | |
| 680 if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) { | |
| 681 TopPair entry(outputs[i], i); | |
| 682 top_heap_.Push(&entry); | |
| 683 if (top_heap_.size() > top_n) { | |
| 684 top_heap_.Pop(&entry); | |
| 685 } | |
| 686 } | |
| 687 } | |
| 688 while (!top_heap_.empty()) { | |
| 689 TopPair entry; | |
| 690 top_heap_.Pop(&entry); | |
| 691 if (top_heap_.size() > 1) { | |
| 692 top_n_flags_[entry.data()] = TN_TOPN; | |
| 693 } else { | |
| 694 top_n_flags_[entry.data()] = TN_TOP2; | |
| 695 if (top_heap_.empty()) { | |
| 696 top_code_ = entry.data(); | |
| 697 } else { | |
| 698 second_code_ = entry.data(); | |
| 699 } | |
| 700 } | |
| 701 } | |
| 702 top_n_flags_[null_char_] = TN_TOP2; | |
| 703 } | |
| 704 | |
| 705 void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList, | |
| 706 const float *outputs, int num_outputs, | |
| 707 int top_n) { | |
| 708 top_n_flags_.clear(); | |
| 709 top_n_flags_.resize(num_outputs, TN_ALSO_RAN); | |
| 710 top_code_ = -1; | |
| 711 second_code_ = -1; | |
| 712 top_heap_.clear(); | |
| 713 for (int i = 0; i < num_outputs; ++i) { | |
| 714 if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) && | |
| 715 !exList->count(i)) { | |
| 716 TopPair entry(outputs[i], i); | |
| 717 top_heap_.Push(&entry); | |
| 718 if (top_heap_.size() > top_n) { | |
| 719 top_heap_.Pop(&entry); | |
| 720 } | |
| 721 } | |
| 722 } | |
| 723 while (!top_heap_.empty()) { | |
| 724 TopPair entry; | |
| 725 top_heap_.Pop(&entry); | |
| 726 if (top_heap_.size() > 1) { | |
| 727 top_n_flags_[entry.data()] = TN_TOPN; | |
| 728 } else { | |
| 729 top_n_flags_[entry.data()] = TN_TOP2; | |
| 730 if (top_heap_.empty()) { | |
| 731 top_code_ = entry.data(); | |
| 732 } else { | |
| 733 second_code_ = entry.data(); | |
| 734 } | |
| 735 } | |
| 736 } | |
| 737 top_n_flags_[null_char_] = TN_TOP2; | |
| 738 } | |
| 739 | |
| 740 // Adds the computation for the current time-step to the beam. Call at each | |
| 741 // time-step in sequence from left to right. outputs is the activation vector | |
| 742 // for the current timestep. | |
| 743 void RecodeBeamSearch::DecodeStep(const float *outputs, int t, | |
| 744 double dict_ratio, double cert_offset, | |
| 745 double worst_dict_cert, | |
| 746 const UNICHARSET *charset, bool debug) { | |
| 747 if (t == static_cast<int>(beam_.size())) { | |
| 748 beam_.push_back(new RecodeBeam); | |
| 749 } | |
| 750 RecodeBeam *step = beam_[t]; | |
| 751 beam_size_ = t + 1; | |
| 752 step->Clear(); | |
| 753 if (t == 0) { | |
| 754 // The first step can only use singles and initials. | |
| 755 ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2, | |
| 756 charset, dict_ratio, cert_offset, worst_dict_cert, step); | |
| 757 if (dict_ != nullptr) { | |
| 758 ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, | |
| 759 TN_TOP2, charset, dict_ratio, cert_offset, | |
| 760 worst_dict_cert, step); | |
| 761 } | |
| 762 } else { | |
| 763 RecodeBeam *prev = beam_[t - 1]; | |
| 764 if (debug) { | |
| 765 int beam_index = BeamIndex(true, NC_ANYTHING, 0); | |
| 766 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { | |
| 767 std::vector<const RecodeNode *> path; | |
| 768 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); | |
| 769 tprintf("Step %d: Dawg beam %d:\n", t, i); | |
| 770 DebugPath(charset, path); | |
| 771 } | |
| 772 beam_index = BeamIndex(false, NC_ANYTHING, 0); | |
| 773 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { | |
| 774 std::vector<const RecodeNode *> path; | |
| 775 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); | |
| 776 tprintf("Step %d: Non-Dawg beam %d:\n", t, i); | |
| 777 DebugPath(charset, path); | |
| 778 } | |
| 779 } | |
| 780 int total_beam = 0; | |
| 781 // Work through the scores by group (top-2, top-n, the rest) while the beam | |
| 782 // is empty. This enables extending the context using only the top-n results | |
| 783 // first, which may have an empty intersection with the valid codes, so we | |
| 784 // fall back to the rest if the beam is empty. | |
| 785 for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) { | |
| 786 auto top_n = static_cast<TopNState>(tn); | |
| 787 for (int index = 0; index < kNumBeams; ++index) { | |
| 788 // Working backwards through the heaps doesn't guarantee that we see the | |
| 789 // best first, but it comes before a lot of the worst, so it is slightly | |
| 790 // more efficient than going forwards. | |
| 791 for (int i = prev->beams_[index].size() - 1; i >= 0; --i) { | |
| 792 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs, | |
| 793 top_n, charset, dict_ratio, cert_offset, | |
| 794 worst_dict_cert, step); | |
| 795 } | |
| 796 } | |
| 797 for (int index = 0; index < kNumBeams; ++index) { | |
| 798 if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) { | |
| 799 total_beam += step->beams_[index].size(); | |
| 800 } | |
| 801 } | |
| 802 } | |
| 803 // Special case for the best initial dawg. Push it on the heap if good | |
| 804 // enough, but there is only one, so it doesn't blow up the beam. | |
| 805 for (int c = 0; c < NC_COUNT; ++c) { | |
| 806 if (step->best_initial_dawgs_[c].code >= 0) { | |
| 807 int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0); | |
| 808 RecodeHeap *dawg_heap = &step->beams_[index]; | |
| 809 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c], | |
| 810 dawg_heap); | |
| 811 } | |
| 812 } | |
| 813 } | |
| 814 } | |
| 815 | |
| 816 void RecodeBeamSearch::DecodeSecondaryStep( | |
| 817 const float *outputs, int t, double dict_ratio, double cert_offset, | |
| 818 double worst_dict_cert, const UNICHARSET *charset, bool debug) { | |
| 819 if (t == static_cast<int>(secondary_beam_.size())) { | |
| 820 secondary_beam_.push_back(new RecodeBeam); | |
| 821 } | |
| 822 RecodeBeam *step = secondary_beam_[t]; | |
| 823 step->Clear(); | |
| 824 if (t == 0) { | |
| 825 // The first step can only use singles and initials. | |
| 826 ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2, | |
| 827 charset, dict_ratio, cert_offset, worst_dict_cert, step); | |
| 828 if (dict_ != nullptr) { | |
| 829 ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, | |
| 830 TN_TOP2, charset, dict_ratio, cert_offset, | |
| 831 worst_dict_cert, step); | |
| 832 } | |
| 833 } else { | |
| 834 RecodeBeam *prev = secondary_beam_[t - 1]; | |
| 835 if (debug) { | |
| 836 int beam_index = BeamIndex(true, NC_ANYTHING, 0); | |
| 837 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { | |
| 838 std::vector<const RecodeNode *> path; | |
| 839 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); | |
| 840 tprintf("Step %d: Dawg beam %d:\n", t, i); | |
| 841 DebugPath(charset, path); | |
| 842 } | |
| 843 beam_index = BeamIndex(false, NC_ANYTHING, 0); | |
| 844 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) { | |
| 845 std::vector<const RecodeNode *> path; | |
| 846 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path); | |
| 847 tprintf("Step %d: Non-Dawg beam %d:\n", t, i); | |
| 848 DebugPath(charset, path); | |
| 849 } | |
| 850 } | |
| 851 int total_beam = 0; | |
| 852 // Work through the scores by group (top-2, top-n, the rest) while the beam | |
| 853 // is empty. This enables extending the context using only the top-n results | |
| 854 // first, which may have an empty intersection with the valid codes, so we | |
| 855 // fall back to the rest if the beam is empty. | |
| 856 for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) { | |
| 857 auto top_n = static_cast<TopNState>(tn); | |
| 858 for (int index = 0; index < kNumBeams; ++index) { | |
| 859 // Working backwards through the heaps doesn't guarantee that we see the | |
| 860 // best first, but it comes before a lot of the worst, so it is slightly | |
| 861 // more efficient than going forwards. | |
| 862 for (int i = prev->beams_[index].size() - 1; i >= 0; --i) { | |
| 863 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs, | |
| 864 top_n, charset, dict_ratio, cert_offset, | |
| 865 worst_dict_cert, step); | |
| 866 } | |
| 867 } | |
| 868 for (int index = 0; index < kNumBeams; ++index) { | |
| 869 if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) { | |
| 870 total_beam += step->beams_[index].size(); | |
| 871 } | |
| 872 } | |
| 873 } | |
| 874 // Special case for the best initial dawg. Push it on the heap if good | |
| 875 // enough, but there is only one, so it doesn't blow up the beam. | |
| 876 for (int c = 0; c < NC_COUNT; ++c) { | |
| 877 if (step->best_initial_dawgs_[c].code >= 0) { | |
| 878 int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0); | |
| 879 RecodeHeap *dawg_heap = &step->beams_[index]; | |
| 880 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c], | |
| 881 dawg_heap); | |
| 882 } | |
| 883 } | |
| 884 } | |
| 885 } | |
| 886 | |
| 887 // Adds to the appropriate beams the legal (according to recoder) | |
| 888 // continuations of context prev, which is of the given length, using the | |
| 889 // given network outputs to provide scores to the choices. Uses only those | |
| 890 // choices for which top_n_flags[index] == top_n_flag. | |
| 891 void RecodeBeamSearch::ContinueContext( | |
| 892 const RecodeNode *prev, int index, const float *outputs, | |
| 893 TopNState top_n_flag, const UNICHARSET *charset, double dict_ratio, | |
| 894 double cert_offset, double worst_dict_cert, RecodeBeam *step) { | |
| 895 RecodedCharID prefix; | |
| 896 RecodedCharID full_code; | |
| 897 const RecodeNode *previous = prev; | |
| 898 int length = LengthFromBeamsIndex(index); | |
| 899 bool use_dawgs = IsDawgFromBeamsIndex(index); | |
| 900 NodeContinuation prev_cont = ContinuationFromBeamsIndex(index); | |
| 901 for (int p = length - 1; p >= 0 && previous != nullptr; --p) { | |
| 902 while (previous->duplicate || previous->code == null_char_) { | |
| 903 previous = previous->prev; | |
| 904 } | |
| 905 prefix.Set(p, previous->code); | |
| 906 full_code.Set(p, previous->code); | |
| 907 previous = previous->prev; | |
| 908 } | |
| 909 if (prev != nullptr && !is_simple_text_) { | |
| 910 if (top_n_flags_[prev->code] == top_n_flag) { | |
| 911 if (prev_cont != NC_NO_DUP) { | |
| 912 float cert = | |
| 913 NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset; | |
| 914 PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id, | |
| 915 cert, worst_dict_cert, dict_ratio, use_dawgs, | |
| 916 NC_ANYTHING, prev, step); | |
| 917 } | |
| 918 if (prev_cont == NC_ANYTHING && top_n_flag == TN_TOP2 && | |
| 919 prev->code != null_char_) { | |
| 920 float cert = NetworkIO::ProbToCertainty(outputs[prev->code] + | |
| 921 outputs[null_char_]) + | |
| 922 cert_offset; | |
| 923 PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id, | |
| 924 cert, worst_dict_cert, dict_ratio, use_dawgs, | |
| 925 NC_NO_DUP, prev, step); | |
| 926 } | |
| 927 } | |
| 928 if (prev_cont == NC_ONLY_DUP) { | |
| 929 return; | |
| 930 } | |
| 931 if (prev->code != null_char_ && length > 0 && | |
| 932 top_n_flags_[null_char_] == top_n_flag) { | |
| 933 // Allow nulls within multi code sequences, as the nulls within are not | |
| 934 // explicitly included in the code sequence. | |
| 935 float cert = | |
| 936 NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset; | |
| 937 PushDupOrNoDawgIfBetter(length, false, null_char_, INVALID_UNICHAR_ID, | |
| 938 cert, worst_dict_cert, dict_ratio, use_dawgs, | |
| 939 NC_ANYTHING, prev, step); | |
| 940 } | |
| 941 } | |
| 942 const std::vector<int> *final_codes = recoder_.GetFinalCodes(prefix); | |
| 943 if (final_codes != nullptr) { | |
| 944 for (int code : *final_codes) { | |
| 945 if (top_n_flags_[code] != top_n_flag) { | |
| 946 continue; | |
| 947 } | |
| 948 if (prev != nullptr && prev->code == code && !is_simple_text_) { | |
| 949 continue; | |
| 950 } | |
| 951 float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset; | |
| 952 if (cert < kMinCertainty && code != null_char_) { | |
| 953 continue; | |
| 954 } | |
| 955 full_code.Set(length, code); | |
| 956 int unichar_id = recoder_.DecodeUnichar(full_code); | |
| 957 // Map the null char to INVALID. | |
| 958 if (length == 0 && code == null_char_) { | |
| 959 unichar_id = INVALID_UNICHAR_ID; | |
| 960 } | |
| 961 if (unichar_id != INVALID_UNICHAR_ID && charset != nullptr && | |
| 962 !charset->get_enabled(unichar_id)) { | |
| 963 continue; // disabled by whitelist/blacklist | |
| 964 } | |
| 965 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio, | |
| 966 use_dawgs, NC_ANYTHING, prev, step); | |
| 967 if (top_n_flag == TN_TOP2 && code != null_char_) { | |
| 968 float prob = outputs[code] + outputs[null_char_]; | |
| 969 if (prev != nullptr && prev_cont == NC_ANYTHING && | |
| 970 prev->code != null_char_ && | |
| 971 ((prev->code == top_code_ && code == second_code_) || | |
| 972 (code == top_code_ && prev->code == second_code_))) { | |
| 973 prob += outputs[prev->code]; | |
| 974 } | |
| 975 cert = NetworkIO::ProbToCertainty(prob) + cert_offset; | |
| 976 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio, | |
| 977 use_dawgs, NC_ONLY_DUP, prev, step); | |
| 978 } | |
| 979 } | |
| 980 } | |
| 981 const std::vector<int> *next_codes = recoder_.GetNextCodes(prefix); | |
| 982 if (next_codes != nullptr) { | |
| 983 for (int code : *next_codes) { | |
| 984 if (top_n_flags_[code] != top_n_flag) { | |
| 985 continue; | |
| 986 } | |
| 987 if (prev != nullptr && prev->code == code && !is_simple_text_) { | |
| 988 continue; | |
| 989 } | |
| 990 float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset; | |
| 991 PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, cert, | |
| 992 worst_dict_cert, dict_ratio, use_dawgs, | |
| 993 NC_ANYTHING, prev, step); | |
| 994 if (top_n_flag == TN_TOP2 && code != null_char_) { | |
| 995 float prob = outputs[code] + outputs[null_char_]; | |
| 996 if (prev != nullptr && prev_cont == NC_ANYTHING && | |
| 997 prev->code != null_char_ && | |
| 998 ((prev->code == top_code_ && code == second_code_) || | |
| 999 (code == top_code_ && prev->code == second_code_))) { | |
| 1000 prob += outputs[prev->code]; | |
| 1001 } | |
| 1002 cert = NetworkIO::ProbToCertainty(prob) + cert_offset; | |
| 1003 PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, | |
| 1004 cert, worst_dict_cert, dict_ratio, use_dawgs, | |
| 1005 NC_ONLY_DUP, prev, step); | |
| 1006 } | |
| 1007 } | |
| 1008 } | |
| 1009 } | |
| 1010 | |
| 1011 // Continues for a new unichar, using dawg or non-dawg as per flag. | |
| 1012 void RecodeBeamSearch::ContinueUnichar(int code, int unichar_id, float cert, | |
| 1013 float worst_dict_cert, float dict_ratio, | |
| 1014 bool use_dawgs, NodeContinuation cont, | |
| 1015 const RecodeNode *prev, | |
| 1016 RecodeBeam *step) { | |
| 1017 if (use_dawgs) { | |
| 1018 if (cert > worst_dict_cert) { | |
| 1019 ContinueDawg(code, unichar_id, cert, cont, prev, step); | |
| 1020 } | |
| 1021 } else { | |
| 1022 RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)]; | |
| 1023 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, false, | |
| 1024 false, false, false, cert * dict_ratio, prev, nullptr, | |
| 1025 nodawg_heap); | |
| 1026 if (dict_ != nullptr && | |
| 1027 ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) || | |
| 1028 !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) { | |
| 1029 // Any top choice position that can start a new word, ie a space or | |
| 1030 // any non-space-delimited character, should also be considered | |
| 1031 // by the dawg search, so push initial dawg to the dawg heap. | |
| 1032 float dawg_cert = cert; | |
| 1033 PermuterType permuter = TOP_CHOICE_PERM; | |
| 1034 // Since we use the space either side of a dictionary word in the | |
| 1035 // certainty of the word, (to properly handle weak spaces) and the | |
| 1036 // space is coming from a non-dict word, we need special conditions | |
| 1037 // to avoid degrading the certainty of the dict word that follows. | |
| 1038 // With a space we don't multiply the certainty by dict_ratio, and we | |
| 1039 // flag the space with NO_PERM to indicate that we should not use the | |
| 1040 // predecessor nulls to generate the confidence for the space, as they | |
| 1041 // have already been multiplied by dict_ratio, and we can't go back to | |
| 1042 // insert more entries in any previous heaps. | |
| 1043 if (unichar_id == UNICHAR_SPACE) { | |
| 1044 permuter = NO_PERM; | |
| 1045 } else { | |
| 1046 dawg_cert *= dict_ratio; | |
| 1047 } | |
| 1048 PushInitialDawgIfBetter(code, unichar_id, permuter, false, false, | |
| 1049 dawg_cert, cont, prev, step); | |
| 1050 } | |
| 1051 } | |
| 1052 } | |
| 1053 | |
| 1054 // Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev, | |
| 1055 // appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id | |
| 1056 // is a valid continuation of whatever is in prev. | |
| 1057 void RecodeBeamSearch::ContinueDawg(int code, int unichar_id, float cert, | |
| 1058 NodeContinuation cont, | |
| 1059 const RecodeNode *prev, RecodeBeam *step) { | |
| 1060 RecodeHeap *dawg_heap = &step->beams_[BeamIndex(true, cont, 0)]; | |
| 1061 RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)]; | |
| 1062 if (unichar_id == INVALID_UNICHAR_ID) { | |
| 1063 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, NO_PERM, false, false, | |
| 1064 false, false, cert, prev, nullptr, dawg_heap); | |
| 1065 return; | |
| 1066 } | |
| 1067 // Avoid dictionary probe if score a total loss. | |
| 1068 float score = cert; | |
| 1069 if (prev != nullptr) { | |
| 1070 score += prev->score; | |
| 1071 } | |
| 1072 if (dawg_heap->size() >= kBeamWidths[0] && | |
| 1073 score <= dawg_heap->PeekTop().data().score && | |
| 1074 nodawg_heap->size() >= kBeamWidths[0] && | |
| 1075 score <= nodawg_heap->PeekTop().data().score) { | |
| 1076 return; | |
| 1077 } | |
| 1078 const RecodeNode *uni_prev = prev; | |
| 1079 // Prev may be a partial code, null_char, or duplicate, so scan back to the | |
| 1080 // last valid unichar_id. | |
| 1081 while (uni_prev != nullptr && | |
| 1082 (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) { | |
| 1083 uni_prev = uni_prev->prev; | |
| 1084 } | |
| 1085 if (unichar_id == UNICHAR_SPACE) { | |
| 1086 if (uni_prev != nullptr && uni_prev->end_of_word) { | |
| 1087 // Space is good. Push initial state, to the dawg beam and a regular | |
| 1088 // space to the top choice beam. | |
| 1089 PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false, | |
| 1090 false, cert, cont, prev, step); | |
| 1091 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter, | |
| 1092 false, false, false, false, cert, prev, nullptr, | |
| 1093 nodawg_heap); | |
| 1094 } | |
| 1095 return; | |
| 1096 } else if (uni_prev != nullptr && uni_prev->start_of_dawg && | |
| 1097 uni_prev->unichar_id != UNICHAR_SPACE && | |
| 1098 dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) && | |
| 1099 dict_->getUnicharset().IsSpaceDelimited(unichar_id)) { | |
| 1100 return; // Can't break words between space delimited chars. | |
| 1101 } | |
| 1102 DawgPositionVector initial_dawgs; | |
| 1103 auto *updated_dawgs = new DawgPositionVector; | |
| 1104 DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM); | |
| 1105 bool word_start = false; | |
| 1106 if (uni_prev == nullptr) { | |
| 1107 // Starting from beginning of line. | |
| 1108 dict_->default_dawgs(&initial_dawgs, false); | |
| 1109 word_start = true; | |
| 1110 } else if (uni_prev->dawgs != nullptr) { | |
| 1111 // Continuing a previous dict word. | |
| 1112 dawg_args.active_dawgs = uni_prev->dawgs; | |
| 1113 word_start = uni_prev->start_of_dawg; | |
| 1114 } else { | |
| 1115 return; // Can't continue if not a dict word. | |
| 1116 } | |
| 1117 auto permuter = static_cast<PermuterType>(dict_->def_letter_is_okay( | |
| 1118 &dawg_args, dict_->getUnicharset(), unichar_id, false)); | |
| 1119 if (permuter != NO_PERM) { | |
| 1120 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false, | |
| 1121 word_start, dawg_args.valid_end, false, cert, prev, | |
| 1122 dawg_args.updated_dawgs, dawg_heap); | |
| 1123 if (dawg_args.valid_end && !space_delimited_) { | |
| 1124 // We can start another word right away, so push initial state as well, | |
| 1125 // to the dawg beam, and the regular character to the top choice beam, | |
| 1126 // since non-dict words can start here too. | |
| 1127 PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true, | |
| 1128 cert, cont, prev, step); | |
| 1129 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false, | |
| 1130 word_start, true, false, cert, prev, nullptr, | |
| 1131 nodawg_heap); | |
| 1132 } | |
| 1133 } else { | |
| 1134 delete updated_dawgs; | |
| 1135 } | |
| 1136 } | |
| 1137 | |
| 1138 // Adds a RecodeNode composed of the tuple (code, unichar_id, | |
| 1139 // initial-dawg-state, prev, cert) to the given heap if/ there is room or if | |
| 1140 // better than the current worst element if already full. | |
| 1141 void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id, | |
| 1142 PermuterType permuter, | |
| 1143 bool start, bool end, float cert, | |
| 1144 NodeContinuation cont, | |
| 1145 const RecodeNode *prev, | |
| 1146 RecodeBeam *step) { | |
| 1147 RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont]; | |
| 1148 float score = cert; | |
| 1149 if (prev != nullptr) { | |
| 1150 score += prev->score; | |
| 1151 } | |
| 1152 if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) { | |
| 1153 auto *initial_dawgs = new DawgPositionVector; | |
| 1154 dict_->default_dawgs(initial_dawgs, false); | |
| 1155 RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert, | |
| 1156 score, prev, initial_dawgs, | |
| 1157 ComputeCodeHash(code, false, prev)); | |
| 1158 *best_initial_dawg = node; | |
| 1159 } | |
| 1160 } | |
| 1161 | |
| 1162 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, | |
| 1163 // false, false, false, false, cert, prev, nullptr) to heap if there is room | |
| 1164 // or if better than the current worst element if already full. | |
| 1165 /* static */ | |
| 1166 void RecodeBeamSearch::PushDupOrNoDawgIfBetter( | |
| 1167 int length, bool dup, int code, int unichar_id, float cert, | |
| 1168 float worst_dict_cert, float dict_ratio, bool use_dawgs, | |
| 1169 NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step) { | |
| 1170 int index = BeamIndex(use_dawgs, cont, length); | |
| 1171 if (use_dawgs) { | |
| 1172 if (cert > worst_dict_cert) { | |
| 1173 PushHeapIfBetter(kBeamWidths[length], code, unichar_id, | |
| 1174 prev ? prev->permuter : NO_PERM, false, false, false, | |
| 1175 dup, cert, prev, nullptr, &step->beams_[index]); | |
| 1176 } | |
| 1177 } else { | |
| 1178 cert *= dict_ratio; | |
| 1179 if (cert >= kMinCertainty || code == null_char_) { | |
| 1180 PushHeapIfBetter(kBeamWidths[length], code, unichar_id, | |
| 1181 prev ? prev->permuter : TOP_CHOICE_PERM, false, false, | |
| 1182 false, dup, cert, prev, nullptr, &step->beams_[index]); | |
| 1183 } | |
| 1184 } | |
| 1185 } | |
| 1186 | |
| 1187 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter, | |
| 1188 // dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room | |
| 1189 // or if better than the current worst element if already full. | |
| 1190 void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id, | |
| 1191 PermuterType permuter, bool dawg_start, | |
| 1192 bool word_start, bool end, bool dup, | |
| 1193 float cert, const RecodeNode *prev, | |
| 1194 DawgPositionVector *d, | |
| 1195 RecodeHeap *heap) { | |
| 1196 float score = cert; | |
| 1197 if (prev != nullptr) { | |
| 1198 score += prev->score; | |
| 1199 } | |
| 1200 if (heap->size() < max_size || score > heap->PeekTop().data().score) { | |
| 1201 uint64_t hash = ComputeCodeHash(code, dup, prev); | |
| 1202 RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end, | |
| 1203 dup, cert, score, prev, d, hash); | |
| 1204 if (UpdateHeapIfMatched(&node, heap)) { | |
| 1205 return; | |
| 1206 } | |
| 1207 RecodePair entry(score, node); | |
| 1208 heap->Push(&entry); | |
| 1209 ASSERT_HOST(entry.data().dawgs == nullptr); | |
| 1210 if (heap->size() > max_size) { | |
| 1211 heap->Pop(&entry); | |
| 1212 } | |
| 1213 } else { | |
| 1214 delete d; | |
| 1215 } | |
| 1216 } | |
| 1217 | |
| 1218 // Adds a RecodeNode to heap if there is room | |
| 1219 // or if better than the current worst element if already full. | |
| 1220 void RecodeBeamSearch::PushHeapIfBetter(int max_size, RecodeNode *node, | |
| 1221 RecodeHeap *heap) { | |
| 1222 if (heap->size() < max_size || node->score > heap->PeekTop().data().score) { | |
| 1223 if (UpdateHeapIfMatched(node, heap)) { | |
| 1224 return; | |
| 1225 } | |
| 1226 RecodePair entry(node->score, *node); | |
| 1227 heap->Push(&entry); | |
| 1228 ASSERT_HOST(entry.data().dawgs == nullptr); | |
| 1229 if (heap->size() > max_size) { | |
| 1230 heap->Pop(&entry); | |
| 1231 } | |
| 1232 } | |
| 1233 } | |
| 1234 | |
| 1235 // Searches the heap for a matching entry, and updates the score with | |
| 1236 // reshuffle if needed. Returns true if there was a match. | |
| 1237 bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node, | |
| 1238 RecodeHeap *heap) { | |
| 1239 // TODO(rays) consider hash map instead of linear search. | |
| 1240 // It might not be faster because the hash map would have to be updated | |
| 1241 // every time a heap reshuffle happens, and that would be a lot of overhead. | |
| 1242 std::vector<RecodePair> &nodes = heap->heap(); | |
| 1243 for (auto &i : nodes) { | |
| 1244 RecodeNode &node = i.data(); | |
| 1245 if (node.code == new_node->code && node.code_hash == new_node->code_hash && | |
| 1246 node.permuter == new_node->permuter && | |
| 1247 node.start_of_dawg == new_node->start_of_dawg) { | |
| 1248 if (new_node->score > node.score) { | |
| 1249 // The new one is better. Update the entire node in the heap and | |
| 1250 // reshuffle. | |
| 1251 node = *new_node; | |
| 1252 i.key() = node.score; | |
| 1253 heap->Reshuffle(&i); | |
| 1254 } | |
| 1255 return true; | |
| 1256 } | |
| 1257 } | |
| 1258 return false; | |
| 1259 } | |
| 1260 | |
| 1261 // Computes and returns the code-hash for the given code and prev. | |
| 1262 uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup, | |
| 1263 const RecodeNode *prev) const { | |
| 1264 uint64_t hash = prev == nullptr ? 0 : prev->code_hash; | |
| 1265 if (!dup && code != null_char_) { | |
| 1266 int num_classes = recoder_.code_range(); | |
| 1267 uint64_t carry = (((hash >> 32) * num_classes) >> 32); | |
| 1268 hash *= num_classes; | |
| 1269 hash += carry; | |
| 1270 hash += code; | |
| 1271 } | |
| 1272 return hash; | |
| 1273 } | |
| 1274 | |
| 1275 // Backtracks to extract the best path through the lattice that was built | |
| 1276 // during Decode. On return the best_nodes vector essentially contains the set | |
| 1277 // of code, score pairs that make the optimal path with the constraint that | |
| 1278 // the recoder can decode the code sequence back to a sequence of unichar-ids. | |
| 1279 void RecodeBeamSearch::ExtractBestPaths( | |
| 1280 std::vector<const RecodeNode *> *best_nodes, | |
| 1281 std::vector<const RecodeNode *> *second_nodes) const { | |
| 1282 // Scan both beams to extract the best and second best paths. | |
| 1283 const RecodeNode *best_node = nullptr; | |
| 1284 const RecodeNode *second_best_node = nullptr; | |
| 1285 const RecodeBeam *last_beam = beam_[beam_size_ - 1]; | |
| 1286 for (int c = 0; c < NC_COUNT; ++c) { | |
| 1287 if (c == NC_ONLY_DUP) { | |
| 1288 continue; | |
| 1289 } | |
| 1290 auto cont = static_cast<NodeContinuation>(c); | |
| 1291 for (int is_dawg = 0; is_dawg < 2; ++is_dawg) { | |
| 1292 int beam_index = BeamIndex(is_dawg, cont, 0); | |
| 1293 int heap_size = last_beam->beams_[beam_index].size(); | |
| 1294 for (int h = 0; h < heap_size; ++h) { | |
| 1295 const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data(); | |
| 1296 if (is_dawg) { | |
| 1297 // dawg_node may be a null_char, or duplicate, so scan back to the | |
| 1298 // last valid unichar_id. | |
| 1299 const RecodeNode *dawg_node = node; | |
| 1300 while (dawg_node != nullptr && | |
| 1301 (dawg_node->unichar_id == INVALID_UNICHAR_ID || | |
| 1302 dawg_node->duplicate)) { | |
| 1303 dawg_node = dawg_node->prev; | |
| 1304 } | |
| 1305 if (dawg_node == nullptr || | |
| 1306 (!dawg_node->end_of_word && | |
| 1307 dawg_node->unichar_id != UNICHAR_SPACE)) { | |
| 1308 // Dawg node is not valid. | |
| 1309 continue; | |
| 1310 } | |
| 1311 } | |
| 1312 if (best_node == nullptr || node->score > best_node->score) { | |
| 1313 second_best_node = best_node; | |
| 1314 best_node = node; | |
| 1315 } else if (second_best_node == nullptr || | |
| 1316 node->score > second_best_node->score) { | |
| 1317 second_best_node = node; | |
| 1318 } | |
| 1319 } | |
| 1320 } | |
| 1321 } | |
| 1322 if (second_nodes != nullptr) { | |
| 1323 ExtractPath(second_best_node, second_nodes); | |
| 1324 } | |
| 1325 ExtractPath(best_node, best_nodes); | |
| 1326 } | |
| 1327 | |
| 1328 // Helper backtracks through the lattice from the given node, storing the | |
| 1329 // path and reversing it. | |
| 1330 void RecodeBeamSearch::ExtractPath( | |
| 1331 const RecodeNode *node, std::vector<const RecodeNode *> *path) const { | |
| 1332 path->clear(); | |
| 1333 while (node != nullptr) { | |
| 1334 path->push_back(node); | |
| 1335 node = node->prev; | |
| 1336 } | |
| 1337 std::reverse(path->begin(), path->end()); | |
| 1338 } | |
| 1339 | |
| 1340 void RecodeBeamSearch::ExtractPath(const RecodeNode *node, | |
| 1341 std::vector<const RecodeNode *> *path, | |
| 1342 int limiter) const { | |
| 1343 int pathcounter = 0; | |
| 1344 path->clear(); | |
| 1345 while (node != nullptr && pathcounter < limiter) { | |
| 1346 path->push_back(node); | |
| 1347 node = node->prev; | |
| 1348 ++pathcounter; | |
| 1349 } | |
| 1350 std::reverse(path->begin(), path->end()); | |
| 1351 } | |
| 1352 | |
| 1353 // Helper prints debug information on the given lattice path. | |
| 1354 void RecodeBeamSearch::DebugPath( | |
| 1355 const UNICHARSET *unicharset, | |
| 1356 const std::vector<const RecodeNode *> &path) const { | |
| 1357 for (unsigned c = 0; c < path.size(); ++c) { | |
| 1358 const RecodeNode &node = *path[c]; | |
| 1359 tprintf("%u ", c); | |
| 1360 node.Print(null_char_, *unicharset, 1); | |
| 1361 } | |
| 1362 } | |
| 1363 | |
| 1364 // Helper prints debug information on the given unichar path. | |
| 1365 void RecodeBeamSearch::DebugUnicharPath( | |
| 1366 const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path, | |
| 1367 const std::vector<int> &unichar_ids, const std::vector<float> &certs, | |
| 1368 const std::vector<float> &ratings, const std::vector<int> &xcoords) const { | |
| 1369 auto num_ids = unichar_ids.size(); | |
| 1370 double total_rating = 0.0; | |
| 1371 for (unsigned c = 0; c < num_ids; ++c) { | |
| 1372 int coord = xcoords[c]; | |
| 1373 tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c], | |
| 1374 unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c], | |
| 1375 path[coord]->start_of_word, path[coord]->end_of_word, | |
| 1376 path[coord]->permuter); | |
| 1377 total_rating += ratings[c]; | |
| 1378 } | |
| 1379 tprintf("Path total rating = %g\n", total_rating); | |
| 1380 } | |
| 1381 | |
| 1382 } // namespace tesseract. |
