comparison mupdf-source/thirdparty/tesseract/src/wordrec/language_model.cpp @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: the expanded directory no longer carries a version number.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children | |
comparing revision 1:1d09e1dec1d9 with revision 2:b50eed0cc0ef
///////////////////////////////////////////////////////////////////////
// File:        language_model.cpp
// Description: Functions that utilize the knowledge about the properties,
//              structure and statistics of the language to help recognition.
// Author:      Daria Antonova
//
// (C) Copyright 2009, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////

#include "language_model.h"
#include <tesseract/unichar.h>       // for UNICHAR_ID, INVALID_UNICHAR_ID
#include <cassert>                   // for assert
#include <cmath>                     // for log2, pow
#include "blamer.h"                  // for BlamerBundle
#include "ccutil.h"                  // for CCUtil
#include "dawg.h"                    // for NO_EDGE, Dawg, Dawg::kPatternUn...
#include "errcode.h"                 // for ASSERT_HOST
#include "lm_state.h"                // for ViterbiStateEntry, ViterbiState...
#include "matrix.h"                  // for MATRIX_COORD
#include "pageres.h"                 // for WERD_RES
#include "params.h"                  // for IntParam, BoolParam, DoubleParam
#include "params_training_featdef.h" // for ParamsTrainingHypothesis, PTRAI...
#include "tprintf.h"                 // for tprintf
#include "unicharset.h"              // for UNICHARSET
#include "unicity_table.h"           // for UnicityTable

template <typename T>
class UnicityTable;

namespace tesseract {

class LMPainPoints;
struct FontInfo;

#if defined(ANDROID)
static inline double log2(double n) {
  return log(n) / log(2.0);
}
#endif // ANDROID

const float LanguageModel::kMaxAvgNgramCost = 25.0f;

LanguageModel::LanguageModel(const UnicityTable<FontInfo> *fontinfo_table, Dict *dict)
    : INT_MEMBER(language_model_debug_level, 0, "Language model debug level",
                 dict->getCCUtil()->params())
    , BOOL_INIT_MEMBER(language_model_ngram_on, false,
                       "Turn on/off the use of character ngram model", dict->getCCUtil()->params())
    , INT_MEMBER(language_model_ngram_order, 8, "Maximum order of the character ngram model",
                 dict->getCCUtil()->params())
    , INT_MEMBER(language_model_viterbi_list_max_num_prunable, 10,
                 "Maximum number of prunable (those for which"
                 " PrunablePath() is true) entries in each viterbi list"
                 " recorded in BLOB_CHOICEs",
                 dict->getCCUtil()->params())
    , INT_MEMBER(language_model_viterbi_list_max_size, 500,
                 "Maximum size of viterbi lists recorded in BLOB_CHOICEs",
                 dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_small_prob, 0.000001,
                    "To avoid overly small denominators use this as the "
                    "floor of the probability returned by the ngram model.",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_nonmatch_score, -40.0,
                    "Average classifier score of a non-matching unichar.",
                    dict->getCCUtil()->params())
    , BOOL_MEMBER(language_model_ngram_use_only_first_uft8_step, false,
                  "Use only the first UTF8 step of the given string"
                  " when computing log probabilities.",
                  dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_scale_factor, 0.03,
                    "Strength of the character ngram model relative to the"
                    " character classifier ",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_ngram_rating_factor, 16.0,
                    "Factor to bring log-probs into the same range as ratings"
                    " when multiplied by outline length ",
                    dict->getCCUtil()->params())
    , BOOL_MEMBER(language_model_ngram_space_delimited_language, true,
                  "Words are delimited by space", dict->getCCUtil()->params())
    , INT_MEMBER(language_model_min_compound_length, 3, "Minimum length of compound words",
                 dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_non_freq_dict_word, 0.1,
                    "Penalty for words not in the frequent word dictionary",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_non_dict_word, 0.15, "Penalty for non-dictionary words",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_punc, 0.2, "Penalty for inconsistent punctuation",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_case, 0.1, "Penalty for inconsistent case",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_script, 0.5, "Penalty for inconsistent script",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_chartype, 0.3, "Penalty for inconsistent character type",
                    dict->getCCUtil()->params())
    ,
    // TODO(daria, rays): enable font consistency checking
    // after improving font analysis.
    double_MEMBER(language_model_penalty_font, 0.00, "Penalty for inconsistent font",
                  dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_spacing, 0.05, "Penalty for inconsistent spacing",
                    dict->getCCUtil()->params())
    , double_MEMBER(language_model_penalty_increment, 0.01, "Penalty increment",
                    dict->getCCUtil()->params())
    , INT_MEMBER(wordrec_display_segmentations, 0, "Display Segmentations (ScrollView)",
                 dict->getCCUtil()->params())
    , BOOL_INIT_MEMBER(language_model_use_sigmoidal_certainty, false,
                       "Use sigmoidal score for certainty", dict->getCCUtil()->params())
    , dawg_args_(nullptr, new DawgPositionVector(), NO_PERM)
    , fontinfo_table_(fontinfo_table)
    , dict_(dict) {
  ASSERT_HOST(dict_ != nullptr);
}

LanguageModel::~LanguageModel() {
  delete dawg_args_.updated_dawgs;
}

void LanguageModel::InitForWord(const WERD_CHOICE *prev_word, bool fixed_pitch,
                                float max_char_wh_ratio, float rating_cert_scale) {
  fixed_pitch_ = fixed_pitch;
  max_char_wh_ratio_ = max_char_wh_ratio;
  rating_cert_scale_ = rating_cert_scale;
  acceptable_choice_found_ = false;
  correct_segmentation_explored_ = false;

  // Initialize vectors with beginning DawgInfos.
  very_beginning_active_dawgs_.clear();
  dict_->init_active_dawgs(&very_beginning_active_dawgs_, false);
  beginning_active_dawgs_.clear();
  dict_->default_dawgs(&beginning_active_dawgs_, false);

  // Fill prev_word_str_ with the last language_model_ngram_order
  // unichars from prev_word.
  if (language_model_ngram_on) {
    if (prev_word != nullptr && !prev_word->unichar_string().empty()) {
      prev_word_str_ = prev_word->unichar_string();
      if (language_model_ngram_space_delimited_language) {
        prev_word_str_ += ' ';
      }
    } else {
      prev_word_str_ = " ";
    }
    const char *str_ptr = prev_word_str_.c_str();
    const char *str_end = str_ptr + prev_word_str_.length();
    int step;
    prev_word_unichar_step_len_ = 0;
    while (str_ptr != str_end && (step = UNICHAR::utf8_step(str_ptr))) {
      str_ptr += step;
      ++prev_word_unichar_step_len_;
    }
    ASSERT_HOST(str_ptr == str_end);
  }
}
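
// Standalone sketch (illustrative, not upstream code) of the unichar-step
// counting done above with UNICHAR::utf8_step(). ToyUtf8Step is a
// hypothetical stand-in that returns the byte length of the code point
// starting at *s, or 0 on a bad lead byte.
static int ToyUtf8Step(const char *s) {
  const unsigned char c = static_cast<unsigned char>(*s);
  if (c < 0x80) {
    return 1; // ASCII
  }
  if ((c >> 5) == 0x6) {
    return 2; // 110xxxxx
  }
  if ((c >> 4) == 0xE) {
    return 3; // 1110xxxx
  }
  if ((c >> 3) == 0x1E) {
    return 4; // 11110xxx
  }
  return 0; // continuation or invalid lead byte
}

// Counts code points the same way the while-loop above counts
// prev_word_unichar_step_len_: "a\xC3\xA9" (2 code points, 3 bytes) -> 2.
static int ToyUnicharStepLen(const char *str) {
  int len = 0;
  int step;
  while (*str != '\0' && (step = ToyUtf8Step(str)) > 0) {
    str += step;
    ++len;
  }
  return len;
}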

/**
 * Helper scans the collection of predecessors for competing siblings that
 * have the same letter with the opposite case, setting competing_vse.
 */
static void ScanParentsForCaseMix(const UNICHARSET &unicharset, LanguageModelState *parent_node) {
  if (parent_node == nullptr) {
    return;
  }
  ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries);
  for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
    ViterbiStateEntry *vse = vit.data();
    vse->competing_vse = nullptr;
    UNICHAR_ID unichar_id = vse->curr_b->unichar_id();
    if (unicharset.get_isupper(unichar_id) || unicharset.get_islower(unichar_id)) {
      UNICHAR_ID other_case = unicharset.get_other_case(unichar_id);
      if (other_case == unichar_id) {
        continue; // Not in unicharset.
      }
      // Find other case in same list. There could be multiple entries with
      // the same unichar_id, but in theory, they should all point to the
      // same BLOB_CHOICE, and that is what we will be using to decide
      // which to keep.
      ViterbiStateEntry_IT vit2(&parent_node->viterbi_state_entries);
      for (vit2.mark_cycle_pt();
           !vit2.cycled_list() && vit2.data()->curr_b->unichar_id() != other_case; vit2.forward()) {
      }
      if (!vit2.cycled_list()) {
        vse->competing_vse = vit2.data();
      }
    }
  }
}

/**
 * Helper returns true if the given choice has a better case variant before
 * it in the choice_list that is not distinguishable by size.
 */
static bool HasBetterCaseVariant(const UNICHARSET &unicharset, const BLOB_CHOICE *choice,
                                 BLOB_CHOICE_LIST *choices) {
  UNICHAR_ID choice_id = choice->unichar_id();
  UNICHAR_ID other_case = unicharset.get_other_case(choice_id);
  if (other_case == choice_id || other_case == INVALID_UNICHAR_ID) {
    return false; // Not upper or lower or not in unicharset.
  }
  if (unicharset.SizesDistinct(choice_id, other_case)) {
    return false; // Can be separated by size.
  }
  BLOB_CHOICE_IT bc_it(choices);
  for (bc_it.mark_cycle_pt(); !bc_it.cycled_list(); bc_it.forward()) {
    BLOB_CHOICE *better_choice = bc_it.data();
    if (better_choice->unichar_id() == other_case) {
      return true; // Found an earlier instance of other_case.
    } else if (better_choice == choice) {
      return false; // Reached the original choice.
    }
  }
  return false; // Should never happen, but just in case.
}

/**
 * UpdateState has the job of combining the ViterbiStateEntry lists on each
 * of the choices on parent_list with each of the blob choices in curr_list,
 * making a new ViterbiStateEntry for each sensible path.
 *
 * This could be a huge set of combinations, creating a lot of work only to
 * be truncated by some beam limit, but only certain kinds of paths will
 * continue at the next step:
 * - paths that are liked by the language model: either a DAWG or the n-gram
 *   model, where active.
 * - paths that represent some kind of top choice. The old permuter permuted
 *   the top raw classifier score, the top upper case word and the top lower-
 *   case word. UpdateState now concentrates its top-choice paths on top
 *   lower-case, top upper-case (or caseless alpha), and top digit sequence,
 *   with allowance for continuation of these paths through blobs where such
 *   a character does not appear in the choices list.
 *
 * GetNextParentVSE enforces some of these models to minimize the number of
 * calls to AddViterbiStateEntry, even prior to looking at the language model.
 * Thus an n-blob sequence of [l1I] will produce 3n calls to
 * AddViterbiStateEntry instead of 3^n.
 *
 * Of course it isn't quite that simple as Title Case is handled by allowing
 * lower case to continue an upper case initial, but it has to be detected
 * in the combiner so it knows which upper case letters are initial alphas.
 */
bool LanguageModel::UpdateState(bool just_classified, int curr_col, int curr_row,
                                BLOB_CHOICE_LIST *curr_list, LanguageModelState *parent_node,
                                LMPainPoints *pain_points, WERD_RES *word_res,
                                BestChoiceBundle *best_choice_bundle, BlamerBundle *blamer_bundle) {
  if (language_model_debug_level > 0) {
    tprintf("\nUpdateState: col=%d row=%d %s", curr_col, curr_row,
            just_classified ? "just_classified" : "");
    if (language_model_debug_level > 5) {
      tprintf("(parent=%p)\n", static_cast<void *>(parent_node));
    } else {
      tprintf("\n");
    }
  }
  // Initialize helper variables.
  bool word_end = (curr_row + 1 >= word_res->ratings->dimension());
  bool new_changed = false;
  float denom = (language_model_ngram_on) ? ComputeDenom(curr_list) : 1.0f;
  const UNICHARSET &unicharset = dict_->getUnicharset();
  BLOB_CHOICE *first_lower = nullptr;
  BLOB_CHOICE *first_upper = nullptr;
  BLOB_CHOICE *first_digit = nullptr;
  bool has_alnum_mix = false;
  if (parent_node != nullptr) {
    int result = SetTopParentLowerUpperDigit(parent_node);
    if (result < 0) {
      if (language_model_debug_level > 0) {
        tprintf("No parents found to process\n");
      }
      return false;
    }
    if (result > 0) {
      has_alnum_mix = true;
    }
  }
  if (!GetTopLowerUpperDigit(curr_list, &first_lower, &first_upper, &first_digit)) {
    has_alnum_mix = false;
  }
  ScanParentsForCaseMix(unicharset, parent_node);
  if (language_model_debug_level > 3 && parent_node != nullptr) {
    parent_node->Print("Parent viterbi list");
  }
  LanguageModelState *curr_state = best_choice_bundle->beam[curr_row];

  // Call AddViterbiStateEntry() for each parent+child ViterbiStateEntry.
  ViterbiStateEntry_IT vit;
  BLOB_CHOICE_IT c_it(curr_list);
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    BLOB_CHOICE *choice = c_it.data();
    // TODO(antonova): make sure commenting this out is ok for ngram
    // model scoring (I think this was introduced to fix ngram model quirks).
    // Skip nullptr unichars unless it is the only choice.
    // if (!curr_list->singleton() && c_it.data()->unichar_id() == 0) continue;
    UNICHAR_ID unichar_id = choice->unichar_id();
    if (unicharset.get_fragment(unichar_id)) {
      continue; // Skip fragments.
    }
    // Set top choice flags.
    LanguageModelFlagsType blob_choice_flags = kXhtConsistentFlag;
    if (c_it.at_first() || !new_changed) {
      blob_choice_flags |= kSmallestRatingFlag;
    }
    if (first_lower == choice) {
      blob_choice_flags |= kLowerCaseFlag;
    }
    if (first_upper == choice) {
      blob_choice_flags |= kUpperCaseFlag;
    }
    if (first_digit == choice) {
      blob_choice_flags |= kDigitFlag;
    }

    if (parent_node == nullptr) {
      // Process the beginning of a word.
      // If there is a better case variant that is not distinguished by size,
      // skip this blob choice, as we have no choice but to accept the result
      // of the character classifier to distinguish between them, even if
      // followed by an upper case.
      // With words like iPoc, and other CamelBackWords, the lower-upper
      // transition can only be achieved if the classifier has the correct case
      // as the top choice, and leaving an initial I lower down the list
      // increases the chances of choosing IPoc simply because it doesn't
      // include such a transition. iPoc will beat iPOC and ipoc because
      // the other words are baseline/x-height inconsistent.
      if (HasBetterCaseVariant(unicharset, choice, curr_list)) {
        continue;
      }
      // Upper counts as lower at the beginning of a word.
      if (blob_choice_flags & kUpperCaseFlag) {
        blob_choice_flags |= kLowerCaseFlag;
      }
      new_changed |= AddViterbiStateEntry(blob_choice_flags, denom, word_end, curr_col, curr_row,
                                          choice, curr_state, nullptr, pain_points, word_res,
                                          best_choice_bundle, blamer_bundle);
    } else {
      // Get viterbi entries from each parent ViterbiStateEntry.
      vit.set_to_list(&parent_node->viterbi_state_entries);
      int vit_counter = 0;
      vit.mark_cycle_pt();
      ViterbiStateEntry *parent_vse = nullptr;
      LanguageModelFlagsType top_choice_flags;
      while ((parent_vse =
                  GetNextParentVSE(just_classified, has_alnum_mix, c_it.data(), blob_choice_flags,
                                   unicharset, word_res, &vit, &top_choice_flags)) != nullptr) {
        // Skip pruned entries and do not look at prunable entries if already
        // examined language_model_viterbi_list_max_num_prunable of those.
        if (PrunablePath(*parent_vse) &&
            (++vit_counter > language_model_viterbi_list_max_num_prunable ||
             (language_model_ngram_on && parent_vse->ngram_info->pruned))) {
          continue;
        }
        // If the parent has no alnum choice, (ie choice is the first in a
        // string of alnum), and there is a better case variant that is not
        // distinguished by size, skip this blob choice/parent, as with the
        // initial blob treatment above.
        if (!parent_vse->HasAlnumChoice(unicharset) &&
            HasBetterCaseVariant(unicharset, choice, curr_list)) {
          continue;
        }
        // Create a new ViterbiStateEntry if BLOB_CHOICE in c_it.data()
        // looks good according to the Dawgs or character ngram model.
        new_changed |= AddViterbiStateEntry(top_choice_flags, denom, word_end, curr_col, curr_row,
                                            c_it.data(), curr_state, parent_vse, pain_points,
                                            word_res, best_choice_bundle, blamer_bundle);
      }
    }
  }
  return new_changed;
}
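
// Standalone sketch (illustrative only) of the beam shape UpdateState()
// maintains: with at most max_prunable parent entries examined and C child
// choices per blob, each step costs O(max_prunable * C) regardless of word
// length, instead of growing exponentially. Plain float costs stand in for
// ViterbiStateEntry paths.
static float ToyCombineBestCost(const float *parent_costs, int num_parents,
                                const float *child_ratings, int num_children, int max_prunable) {
  float best = -1.0f;
  int examined = 0;
  for (int p = 0; p < num_parents; ++p) {
    if (++examined > max_prunable) {
      break; // beam cap, as with language_model_viterbi_list_max_num_prunable
    }
    for (int c = 0; c < num_children; ++c) {
      float cost = parent_costs[p] + child_ratings[c]; // one AddViterbiStateEntry() analogue
      if (best < 0.0f || cost < best) {
        best = cost;
      }
    }
  }
  return best; // -1.0f if there was nothing to combine
}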

/**
 * Finds the first lower and upper case letter and first digit in curr_list.
 * For non-upper/lower languages, alpha counts as upper.
 * Uses the first character in the list in place of empty results.
 * Returns true if both alpha and digits are found.
 */
bool LanguageModel::GetTopLowerUpperDigit(BLOB_CHOICE_LIST *curr_list, BLOB_CHOICE **first_lower,
                                          BLOB_CHOICE **first_upper,
                                          BLOB_CHOICE **first_digit) const {
  BLOB_CHOICE_IT c_it(curr_list);
  const UNICHARSET &unicharset = dict_->getUnicharset();
  BLOB_CHOICE *first_unichar = nullptr;
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    UNICHAR_ID unichar_id = c_it.data()->unichar_id();
    if (unicharset.get_fragment(unichar_id)) {
      continue; // skip fragments
    }
    if (first_unichar == nullptr) {
      first_unichar = c_it.data();
    }
    if (*first_lower == nullptr && unicharset.get_islower(unichar_id)) {
      *first_lower = c_it.data();
    }
    if (*first_upper == nullptr && unicharset.get_isalpha(unichar_id) &&
        !unicharset.get_islower(unichar_id)) {
      *first_upper = c_it.data();
    }
    if (*first_digit == nullptr && unicharset.get_isdigit(unichar_id)) {
      *first_digit = c_it.data();
    }
  }
  ASSERT_HOST(first_unichar != nullptr);
  bool mixed = (*first_lower != nullptr || *first_upper != nullptr) && *first_digit != nullptr;
  if (*first_lower == nullptr) {
    *first_lower = first_unichar;
  }
  if (*first_upper == nullptr) {
    *first_upper = first_unichar;
  }
  if (*first_digit == nullptr) {
    *first_digit = first_unichar;
  }
  return mixed;
}
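
// Standalone sketch (illustrative only) of the fallback rule above: remember
// the first entry of a category and substitute the overall first entry when
// the category never occurs. Plain chars stand in for BLOB_CHOICEs; n must
// be positive, mirroring the ASSERT_HOST on first_unichar.
static char ToyFirstDigitOrFirst(const char *s, int n) {
  char first_digit = '\0';
  for (int i = 0; i < n; ++i) {
    if (first_digit == '\0' && s[i] >= '0' && s[i] <= '9') {
      first_digit = s[i];
    }
  }
  return first_digit != '\0' ? first_digit : s[0];
}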

/**
 * Forces there to be at least one entry in the overall set of the
 * viterbi_state_entries of each element of parent_node that has the
 * top_choice_flag set for lower, upper and digit using the same rules as
 * GetTopLowerUpperDigit, setting the flag on the first found suitable
 * candidate, whether or not the flag is set on some other parent.
 * Returns 1 if both alpha and digits are found among the parents, -1 if no
 * parents are found at all (a legitimate case), and 0 otherwise.
 */
int LanguageModel::SetTopParentLowerUpperDigit(LanguageModelState *parent_node) const {
  if (parent_node == nullptr) {
    return -1;
  }
  UNICHAR_ID top_id = INVALID_UNICHAR_ID;
  ViterbiStateEntry *top_lower = nullptr;
  ViterbiStateEntry *top_upper = nullptr;
  ViterbiStateEntry *top_digit = nullptr;
  ViterbiStateEntry *top_choice = nullptr;
  float lower_rating = 0.0f;
  float upper_rating = 0.0f;
  float digit_rating = 0.0f;
  float top_rating = 0.0f;
  const UNICHARSET &unicharset = dict_->getUnicharset();
  ViterbiStateEntry_IT vit(&parent_node->viterbi_state_entries);
  for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
    ViterbiStateEntry *vse = vit.data();
    // INVALID_UNICHAR_ID should be treated like a zero-width joiner, so scan
    // back to the real character if needed.
    ViterbiStateEntry *unichar_vse = vse;
    UNICHAR_ID unichar_id = unichar_vse->curr_b->unichar_id();
    float rating = unichar_vse->curr_b->rating();
    while (unichar_id == INVALID_UNICHAR_ID && unichar_vse->parent_vse != nullptr) {
      unichar_vse = unichar_vse->parent_vse;
      unichar_id = unichar_vse->curr_b->unichar_id();
      rating = unichar_vse->curr_b->rating();
    }
    if (unichar_id != INVALID_UNICHAR_ID) {
      if (unicharset.get_islower(unichar_id)) {
        if (top_lower == nullptr || lower_rating > rating) {
          top_lower = vse;
          lower_rating = rating;
        }
      } else if (unicharset.get_isalpha(unichar_id)) {
        if (top_upper == nullptr || upper_rating > rating) {
          top_upper = vse;
          upper_rating = rating;
        }
      } else if (unicharset.get_isdigit(unichar_id)) {
        if (top_digit == nullptr || digit_rating > rating) {
          top_digit = vse;
          digit_rating = rating;
        }
      }
    }
    if (top_choice == nullptr || top_rating > rating) {
      top_choice = vse;
      top_rating = rating;
      top_id = unichar_id;
    }
  }
  if (top_choice == nullptr) {
    return -1;
  }
  bool mixed = (top_lower != nullptr || top_upper != nullptr) && top_digit != nullptr;
  if (top_lower == nullptr) {
    top_lower = top_choice;
  }
  top_lower->top_choice_flags |= kLowerCaseFlag;
  if (top_upper == nullptr) {
    top_upper = top_choice;
  }
  top_upper->top_choice_flags |= kUpperCaseFlag;
  if (top_digit == nullptr) {
    top_digit = top_choice;
  }
  top_digit->top_choice_flags |= kDigitFlag;
  top_choice->top_choice_flags |= kSmallestRatingFlag;
  if (top_id != INVALID_UNICHAR_ID && dict_->compound_marker(top_id) &&
      (top_choice->top_choice_flags & (kLowerCaseFlag | kUpperCaseFlag | kDigitFlag))) {
    // If the compound marker top choice carries any of the top alnum flags,
    // then give it all of them, allowing words like I-295 to be chosen.
    top_choice->top_choice_flags |= kLowerCaseFlag | kUpperCaseFlag | kDigitFlag;
  }
  return mixed ? 1 : 0;
}

/**
 * Finds the next ViterbiStateEntry with which the given unichar_id can
 * combine sensibly, taking into account any mixed alnum/mixed case
 * situation, and whether this combination has been inspected before.
 */
ViterbiStateEntry *LanguageModel::GetNextParentVSE(bool just_classified, bool mixed_alnum,
                                                   const BLOB_CHOICE *bc,
                                                   LanguageModelFlagsType blob_choice_flags,
                                                   const UNICHARSET &unicharset, WERD_RES *word_res,
                                                   ViterbiStateEntry_IT *vse_it,
                                                   LanguageModelFlagsType *top_choice_flags) const {
  for (; !vse_it->cycled_list(); vse_it->forward()) {
    ViterbiStateEntry *parent_vse = vse_it->data();
    // Only consider the parent if it has been updated or
    // if the current ratings cell has just been classified.
    if (!just_classified && !parent_vse->updated) {
      continue;
    }
    if (language_model_debug_level > 2) {
      parent_vse->Print("Considering");
    }
    // If the parent is non-alnum, then upper counts as lower.
    *top_choice_flags = blob_choice_flags;
    if ((blob_choice_flags & kUpperCaseFlag) && !parent_vse->HasAlnumChoice(unicharset)) {
      *top_choice_flags |= kLowerCaseFlag;
    }
    *top_choice_flags &= parent_vse->top_choice_flags;
    UNICHAR_ID unichar_id = bc->unichar_id();
    const BLOB_CHOICE *parent_b = parent_vse->curr_b;
    UNICHAR_ID parent_id = parent_b->unichar_id();
    // Digits do not bind to alphas if there is a mix in both parent and current
    // or if the alpha is not the top choice.
    if (unicharset.get_isdigit(unichar_id) && unicharset.get_isalpha(parent_id) &&
        (mixed_alnum || *top_choice_flags == 0)) {
      continue; // Digits don't bind to alphas.
    }
    // Likewise alphas do not bind to digits if there is a mix in both or if
    // the digit is not the top choice.
    if (unicharset.get_isalpha(unichar_id) && unicharset.get_isdigit(parent_id) &&
        (mixed_alnum || *top_choice_flags == 0)) {
      continue; // Alphas don't bind to digits.
    }
    // If there is a case mix of the same alpha in the parent list, then
    // competing_vse is non-null and will be used to determine whether
    // or not to bind the current blob choice.
    if (parent_vse->competing_vse != nullptr) {
      const BLOB_CHOICE *competing_b = parent_vse->competing_vse->curr_b;
      UNICHAR_ID other_id = competing_b->unichar_id();
      if (language_model_debug_level >= 5) {
        tprintf("Parent %s has competition %s\n", unicharset.id_to_unichar(parent_id),
                unicharset.id_to_unichar(other_id));
      }
      if (unicharset.SizesDistinct(parent_id, other_id)) {
        // If other_id matches bc wrt position and size, and parent_id doesn't,
        // don't bind to the current parent.
        if (bc->PosAndSizeAgree(*competing_b, word_res->x_height,
                                language_model_debug_level >= 5) &&
            !bc->PosAndSizeAgree(*parent_b, word_res->x_height, language_model_debug_level >= 5)) {
          continue; // Competing blobchoice has a better vertical match.
        }
      }
    }
    vse_it->forward();
    return parent_vse; // This one is good!
  }
  return nullptr; // Ran out of possibilities.
}

bool LanguageModel::AddViterbiStateEntry(LanguageModelFlagsType top_choice_flags, float denom,
                                         bool word_end, int curr_col, int curr_row, BLOB_CHOICE *b,
                                         LanguageModelState *curr_state,
                                         ViterbiStateEntry *parent_vse, LMPainPoints *pain_points,
                                         WERD_RES *word_res, BestChoiceBundle *best_choice_bundle,
                                         BlamerBundle *blamer_bundle) {
  ViterbiStateEntry_IT vit;
  if (language_model_debug_level > 1) {
    tprintf(
        "AddViterbiStateEntry for unichar %s rating=%.4f"
        " certainty=%.4f top_choice_flags=0x%x",
        dict_->getUnicharset().id_to_unichar(b->unichar_id()), b->rating(), b->certainty(),
        top_choice_flags);
    if (language_model_debug_level > 5) {
      tprintf(" parent_vse=%p\n", static_cast<void *>(parent_vse));
    } else {
      tprintf("\n");
    }
  }
  ASSERT_HOST(curr_state != nullptr);
  // Check whether the list is full.
  if (curr_state->viterbi_state_entries_length >= language_model_viterbi_list_max_size) {
    if (language_model_debug_level > 1) {
      tprintf("AddViterbiStateEntry: viterbi list is full!\n");
    }
    return false;
  }

  // Invoke Dawg language model component.
  LanguageModelDawgInfo *dawg_info = GenerateDawgInfo(word_end, curr_col, curr_row, *b, parent_vse);

  float outline_length = AssociateUtils::ComputeOutlineLength(rating_cert_scale_, *b);
  // Invoke Ngram language model component.
  LanguageModelNgramInfo *ngram_info = nullptr;
  if (language_model_ngram_on) {
    ngram_info =
        GenerateNgramInfo(dict_->getUnicharset().id_to_unichar(b->unichar_id()), b->certainty(),
                          denom, curr_col, curr_row, outline_length, parent_vse);
    ASSERT_HOST(ngram_info != nullptr);
  }
  bool liked_by_language_model =
      dawg_info != nullptr || (ngram_info != nullptr && !ngram_info->pruned);
  // Quick escape if not liked by the language model, can't be consistent
  // xheight, and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components very early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Check consistency of the path and set the relevant consistency_info.
  LMConsistencyInfo consistency_info(parent_vse != nullptr ? &parent_vse->consistency_info
                                                           : nullptr);
  // Start with just the x-height consistency, as it provides significant
  // pruning opportunity.
  consistency_info.ComputeXheightConsistency(
      b, dict_->getUnicharset().get_ispunctuation(b->unichar_id()));
  // Turn off xheight consistent flag if not consistent.
  if (consistency_info.InconsistentXHeight()) {
    top_choice_flags &= ~kXhtConsistentFlag;
  }

  // Quick escape if not liked by the language model, not consistent xheight,
  // and not top choice.
  if (!liked_by_language_model && top_choice_flags == 0) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components early pruned this entry\n");
    }
    delete ngram_info;
    delete dawg_info;
    return false;
  }

  // Compute the rest of the consistency info.
  FillConsistencyInfo(curr_col, word_end, b, parent_vse, word_res, &consistency_info);
  if (dawg_info != nullptr && consistency_info.invalid_punc) {
    consistency_info.invalid_punc = false; // do not penalize dict words
  }

  // Compute cost of associating the blobs that represent the current unichar.
  AssociateStats associate_stats;
  ComputeAssociateStats(curr_col, curr_row, max_char_wh_ratio_, parent_vse, word_res,
                        &associate_stats);
  if (parent_vse != nullptr) {
    associate_stats.shape_cost += parent_vse->associate_stats.shape_cost;
    associate_stats.bad_shape |= parent_vse->associate_stats.bad_shape;
  }

  // Create the new ViterbiStateEntry and compute the adjusted cost of the path.
  auto *new_vse = new ViterbiStateEntry(parent_vse, b, 0.0, outline_length, consistency_info,
                                        associate_stats, top_choice_flags, dawg_info, ngram_info,
                                        (language_model_debug_level > 0)
                                            ? dict_->getUnicharset().id_to_unichar(b->unichar_id())
                                            : nullptr);
  new_vse->cost = ComputeAdjustedPathCost(new_vse);
  if (language_model_debug_level >= 3) {
    tprintf("Adjusted cost = %g\n", new_vse->cost);
  }

  // Invoke Top Choice language model component to make the final adjustments
  // to new_vse->top_choice_flags.
  if (!curr_state->viterbi_state_entries.empty() && new_vse->top_choice_flags) {
    GenerateTopChoiceInfo(new_vse, parent_vse, curr_state);
  }

  // If language model components did not like this unichar - return.
  bool keep = new_vse->top_choice_flags || liked_by_language_model;
  if (!(top_choice_flags & kSmallestRatingFlag) && // no non-top choice paths
      consistency_info.inconsistent_script) {      // with inconsistent script
    keep = false;
  }
  if (!keep) {
    if (language_model_debug_level > 1) {
      tprintf("Language model components did not like this entry\n");
    }
    delete new_vse;
    return false;
  }

  // Discard this entry if it represents a prunable path and
  // language_model_viterbi_list_max_num_prunable such entries with a lower
  // cost have already been recorded.
  if (PrunablePath(*new_vse) &&
      (curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) &&
      new_vse->cost >= curr_state->viterbi_state_entries_prunable_max_cost) {
    if (language_model_debug_level > 1) {
      tprintf("Discarded ViterbiEntry with high cost %g max cost %g\n", new_vse->cost,
              curr_state->viterbi_state_entries_prunable_max_cost);
    }
    delete new_vse;
    return false;
  }

  // Update best choice if needed.
  if (word_end) {
    UpdateBestChoice(new_vse, pain_points, word_res, best_choice_bundle, blamer_bundle);
    // Discard the entry if UpdateBestChoice() found flaws in it.
    if (new_vse->cost >= WERD_CHOICE::kBadRating && new_vse != best_choice_bundle->best_vse) {
      if (language_model_debug_level > 1) {
        tprintf("Discarded ViterbiEntry with high cost %g\n", new_vse->cost);
      }
      delete new_vse;
      return false;
    }
  }

  // Add the new ViterbiStateEntry to curr_state->viterbi_state_entries.
  curr_state->viterbi_state_entries.add_sorted(ViterbiStateEntry::Compare, false, new_vse);
  curr_state->viterbi_state_entries_length++;
  if (PrunablePath(*new_vse)) {
    curr_state->viterbi_state_entries_prunable_length++;
  }

  // Update curr_state->viterbi_state_entries_prunable_max_cost and clear
  // top_choice_flags of entries with cost higher than new_vse->cost.
  if ((curr_state->viterbi_state_entries_prunable_length >=
       language_model_viterbi_list_max_num_prunable) ||
      new_vse->top_choice_flags) {
    ASSERT_HOST(!curr_state->viterbi_state_entries.empty());
    int prunable_counter = language_model_viterbi_list_max_num_prunable;
    vit.set_to_list(&(curr_state->viterbi_state_entries));
    for (vit.mark_cycle_pt(); !vit.cycled_list(); vit.forward()) {
      ViterbiStateEntry *curr_vse = vit.data();
      // Clear the appropriate top choice flags of the entries in the
      // list that have cost higher than new_vse->cost
      // (since they will not be top choices any more).
      if (curr_vse->top_choice_flags && curr_vse != new_vse && curr_vse->cost > new_vse->cost) {
        curr_vse->top_choice_flags &= ~(new_vse->top_choice_flags);
      }
      if (prunable_counter > 0 && PrunablePath(*curr_vse)) {
        --prunable_counter;
      }
      // Update curr_state->viterbi_state_entries_prunable_max_cost.
      if (prunable_counter == 0) {
        curr_state->viterbi_state_entries_prunable_max_cost = vit.data()->cost;
        if (language_model_debug_level > 1) {
          tprintf("Set viterbi_state_entries_prunable_max_cost to %g\n",
                  curr_state->viterbi_state_entries_prunable_max_cost);
        }
        prunable_counter = -1; // stop counting
      }
    }
  }

  // Print the newly created ViterbiStateEntry.
  if (language_model_debug_level > 2) {
    new_vse->Print("New");
    if (language_model_debug_level > 5) {
      curr_state->Print("Updated viterbi list");
    }
  }

  return true;
}
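
// Standalone sketch (illustrative only) of the pruning threshold kept at the
// end of AddViterbiStateEntry(): the cost of the k-th cheapest prunable entry
// in the sorted list becomes the admission bar for later insertions. Every
// entry is treated as prunable here for simplicity; costs must be sorted
// ascending.
static float ToyPrunableMaxCost(const float *sorted_costs, int n, int k) {
  int prunable_counter = k;
  for (int i = 0; i < n; ++i) {
    if (--prunable_counter == 0) {
      return sorted_costs[i]; // new viterbi_state_entries_prunable_max_cost
    }
  }
  return -1.0f; // fewer than k prunable entries: no threshold yet
}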

void LanguageModel::GenerateTopChoiceInfo(ViterbiStateEntry *new_vse,
                                          const ViterbiStateEntry *parent_vse,
                                          LanguageModelState *lms) {
  ViterbiStateEntry_IT vit(&(lms->viterbi_state_entries));
  for (vit.mark_cycle_pt();
       !vit.cycled_list() && new_vse->top_choice_flags && new_vse->cost >= vit.data()->cost;
       vit.forward()) {
    // Clear the appropriate flags if the list already contains
    // a top choice entry with a lower cost.
    new_vse->top_choice_flags &= ~(vit.data()->top_choice_flags);
  }
  if (language_model_debug_level > 2) {
    tprintf("GenerateTopChoiceInfo: top_choice_flags=0x%x\n", new_vse->top_choice_flags);
  }
}

LanguageModelDawgInfo *LanguageModel::GenerateDawgInfo(bool word_end, int curr_col, int curr_row,
                                                       const BLOB_CHOICE &b,
                                                       const ViterbiStateEntry *parent_vse) {
  // Initialize active_dawgs from parent_vse if it is not nullptr.
  // Otherwise use very_beginning_active_dawgs_.
  if (parent_vse == nullptr) {
    dawg_args_.active_dawgs = &very_beginning_active_dawgs_;
    dawg_args_.permuter = NO_PERM;
  } else {
    if (parent_vse->dawg_info == nullptr) {
      return nullptr; // not a dict word path
    }
    dawg_args_.active_dawgs = &parent_vse->dawg_info->active_dawgs;
    dawg_args_.permuter = parent_vse->dawg_info->permuter;
  }

  // Deal with hyphenated words.
  if (word_end && dict_->has_hyphen_end(&dict_->getUnicharset(), b.unichar_id(), curr_col == 0)) {
    if (language_model_debug_level > 0) {
      tprintf("Hyphenated word found\n");
    }
    return new LanguageModelDawgInfo(dawg_args_.active_dawgs, COMPOUND_PERM);
  }

  // Deal with compound words.
  if (dict_->compound_marker(b.unichar_id()) &&
      (parent_vse == nullptr || parent_vse->dawg_info->permuter != NUMBER_PERM)) {
    if (language_model_debug_level > 0) {
      tprintf("Found compound marker\n");
    }
    // Do not allow compound operators at the beginning and end of the word.
    // Do not allow more than one compound operator per word.
    // Do not allow compounding of words with lengths shorter than
    // language_model_min_compound_length.
    if (parent_vse == nullptr || word_end || dawg_args_.permuter == COMPOUND_PERM ||
        parent_vse->length < language_model_min_compound_length) {
      return nullptr;
    }

    // Check that the path terminating before the current character is a word.
    bool has_word_ending = false;
    for (unsigned i = 0; i < parent_vse->dawg_info->active_dawgs.size(); ++i) {
      const DawgPosition &pos = parent_vse->dawg_info->active_dawgs[i];
      const Dawg *pdawg = pos.dawg_index < 0 ? nullptr : dict_->GetDawg(pos.dawg_index);
      if (pdawg == nullptr || pos.back_to_punc) {
        continue;
      }
      if (pdawg->type() == DAWG_TYPE_WORD && pos.dawg_ref != NO_EDGE &&
          pdawg->end_of_word(pos.dawg_ref)) {
        has_word_ending = true;
        break;
      }
    }
    if (!has_word_ending) {
      return nullptr;
    }

    if (language_model_debug_level > 0) {
      tprintf("Compound word found\n");
    }
    return new LanguageModelDawgInfo(&beginning_active_dawgs_, COMPOUND_PERM);
  } // done dealing with compound words

  LanguageModelDawgInfo *dawg_info = nullptr;

  // Call LetterIsOkay().
  // Use the normalized IDs so that all shapes of ' can be allowed in words
  // like don't.
  const auto &normed_ids = dict_->getUnicharset().normed_ids(b.unichar_id());
  DawgPositionVector tmp_active_dawgs;
  for (unsigned i = 0; i < normed_ids.size(); ++i) {
    if (language_model_debug_level > 2) {
      tprintf("Test Letter OK for unichar %d, normed %d\n", b.unichar_id(), normed_ids[i]);
    }
    dict_->LetterIsOkay(&dawg_args_, dict_->getUnicharset(), normed_ids[i],
                        word_end && i == normed_ids.size() - 1);
    if (dawg_args_.permuter == NO_PERM) {
      break;
    } else if (i < normed_ids.size() - 1) {
      tmp_active_dawgs = *dawg_args_.updated_dawgs;
      dawg_args_.active_dawgs = &tmp_active_dawgs;
    }
    if (language_model_debug_level > 2) {
      tprintf("Letter was OK for unichar %d, normed %d\n", b.unichar_id(), normed_ids[i]);
    }
  }
  dawg_args_.active_dawgs = nullptr;
  if (dawg_args_.permuter != NO_PERM) {
    dawg_info = new LanguageModelDawgInfo(dawg_args_.updated_dawgs, dawg_args_.permuter);
  } else if (language_model_debug_level > 3) {
    tprintf("Letter %s not OK!\n", dict_->getUnicharset().id_to_unichar(b.unichar_id()));
  }

  return dawg_info;
}
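
// Standalone sketch (illustrative only) of the contract GenerateDawgInfo()
// relies on in Dict::LetterIsOkay(): extend the active dictionary state by
// one character and report whether any path survives. A real DAWG shares
// suffixes; this toy uses a prefix scan over a tiny word list instead.
// Relies on <cstring>, which this file already uses via strlen().
static bool ToyPrefixStillAlive(const char *prefix) {
  static const char *const kToyWords[] = {"do", "don't", "done"};
  const size_t len = strlen(prefix);
  for (const char *word : kToyWords) {
    if (strncmp(word, prefix, len) == 0) {
      return true; // at least one dictionary path continues
    }
  }
  return false; // analogous to dawg_args_.permuter dropping to NO_PERM
}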

LanguageModelNgramInfo *LanguageModel::GenerateNgramInfo(const char *unichar, float certainty,
                                                         float denom, int curr_col, int curr_row,
                                                         float outline_length,
                                                         const ViterbiStateEntry *parent_vse) {
  // Initialize parent context.
  const char *pcontext_ptr = "";
  int pcontext_unichar_step_len = 0;
  if (parent_vse == nullptr) {
    pcontext_ptr = prev_word_str_.c_str();
    pcontext_unichar_step_len = prev_word_unichar_step_len_;
  } else {
    pcontext_ptr = parent_vse->ngram_info->context.c_str();
    pcontext_unichar_step_len = parent_vse->ngram_info->context_unichar_step_len;
  }
  // Compute p(unichar | parent context).
  int unichar_step_len = 0;
  bool pruned = false;
  float ngram_cost;
  float ngram_and_classifier_cost = ComputeNgramCost(unichar, certainty, denom, pcontext_ptr,
                                                     &unichar_step_len, &pruned, &ngram_cost);
  // Normalize just the ngram_and_classifier_cost by outline_length.
  // The ngram_cost is used by the params_model, so it needs to be left as-is,
  // and the params model cost will be normalized by outline_length.
  ngram_and_classifier_cost *= outline_length / language_model_ngram_rating_factor;
  // Add the ngram_cost of the parent.
  if (parent_vse != nullptr) {
    ngram_and_classifier_cost += parent_vse->ngram_info->ngram_and_classifier_cost;
    ngram_cost += parent_vse->ngram_info->ngram_cost;
  }

  // Shorten parent context string by unichar_step_len unichars.
  int num_remove = (unichar_step_len + pcontext_unichar_step_len - language_model_ngram_order);
  if (num_remove > 0) {
    pcontext_unichar_step_len -= num_remove;
  }
  while (num_remove > 0 && *pcontext_ptr != '\0') {
    pcontext_ptr += UNICHAR::utf8_step(pcontext_ptr);
    --num_remove;
  }

  // Decide whether to prune this ngram path, propagating the parent's
  // pruned flag.
  if (parent_vse != nullptr && parent_vse->ngram_info->pruned) {
    pruned = true;
  }

  // Construct and return the new LanguageModelNgramInfo.
  auto *ngram_info = new LanguageModelNgramInfo(pcontext_ptr, pcontext_unichar_step_len, pruned,
                                                ngram_cost, ngram_and_classifier_cost);
  ngram_info->context += unichar;
  ngram_info->context_unichar_step_len += unichar_step_len;
  assert(ngram_info->context_unichar_step_len <= language_model_ngram_order);
  return ngram_info;
}
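
// Standalone sketch (illustrative only) of the context-window arithmetic in
// GenerateNgramInfo(): after appending the new unichar the context is trimmed
// from the front so its length in unichar steps never exceeds the model
// order. Plain step counts stand in for the UTF-8 pointer walking.
static int ToyTrimmedContextLen(int context_steps, int new_unichar_steps, int ngram_order) {
  int num_remove = context_steps + new_unichar_steps - ngram_order;
  if (num_remove > 0) {
    context_steps -= num_remove; // drop the oldest unichars
  }
  return context_steps + new_unichar_steps; // <= ngram_order once trimming kicks in
}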

float LanguageModel::ComputeNgramCost(const char *unichar, float certainty, float denom,
                                      const char *context, int *unichar_step_len,
                                      bool *found_small_prob, float *ngram_cost) {
  const char *context_ptr = context;
  char *modified_context = nullptr;
  char *modified_context_end = nullptr;
  const char *unichar_ptr = unichar;
  const char *unichar_end = unichar_ptr + strlen(unichar_ptr);
  float prob = 0.0f;
  int step = 0;
  while (unichar_ptr < unichar_end && (step = UNICHAR::utf8_step(unichar_ptr)) > 0) {
    if (language_model_debug_level > 1) {
      tprintf("prob(%s | %s)=%g\n", unichar_ptr, context_ptr,
              dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step));
    }
    prob += dict_->ProbabilityInContext(context_ptr, -1, unichar_ptr, step);
    ++(*unichar_step_len);
    if (language_model_ngram_use_only_first_uft8_step) {
      break;
    }
    unichar_ptr += step;
    // If there are multiple UTF8 characters present in unichar, context is
    // updated to include the previously examined characters from str,
    // unless use_only_first_uft8_step is true.
    if (unichar_ptr < unichar_end) {
      if (modified_context == nullptr) {
        size_t context_len = strlen(context);
        modified_context = new char[context_len + strlen(unichar_ptr) + step + 1];
        memcpy(modified_context, context, context_len);
        modified_context_end = modified_context + context_len;
        context_ptr = modified_context;
      }
      strncpy(modified_context_end, unichar_ptr - step, step);
      modified_context_end += step;
      *modified_context_end = '\0';
    }
  }
  prob /= static_cast<float>(*unichar_step_len); // normalize
  if (prob < language_model_ngram_small_prob) {
    if (language_model_debug_level > 0) {
      tprintf("Found small prob %g\n", prob);
    }
    *found_small_prob = true;
    prob = language_model_ngram_small_prob;
  }
  *ngram_cost = -1 * std::log2(prob);
  float ngram_and_classifier_cost = -1 * std::log2(CertaintyScore(certainty) / denom) +
                                    *ngram_cost * language_model_ngram_scale_factor;
  if (language_model_debug_level > 1) {
    tprintf("-log [ p(%s) * p(%s | %s) ] = -log2(%g*%g) = %g\n", unichar, unichar, context_ptr,
            CertaintyScore(certainty) / denom, prob, ngram_and_classifier_cost);
  }
  delete[] modified_context;
  return ngram_and_classifier_cost;
}
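
// Standalone sketch (illustrative only) of the formula above:
//   cost = -log2(CertaintyScore(certainty) / denom)
//          + (-log2(p_ngram)) * language_model_ngram_scale_factor.
// With toy numbers classifier_share = 0.5, p_ngram = 0.25, scale = 0.03:
// 1.0 + 2.0 * 0.03 = 1.06.
static float ToyNgramAndClassifierCost(float classifier_share, float p_ngram,
                                       float scale_factor) {
  float ngram_cost = -1.0f * std::log2(p_ngram);
  return -1.0f * std::log2(classifier_share) + ngram_cost * scale_factor;
}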

float LanguageModel::ComputeDenom(BLOB_CHOICE_LIST *curr_list) {
  if (curr_list->empty()) {
    return 1.0f;
  }
  float denom = 0.0f;
  int len = 0;
  BLOB_CHOICE_IT c_it(curr_list);
  for (c_it.mark_cycle_pt(); !c_it.cycled_list(); c_it.forward()) {
    ASSERT_HOST(c_it.data() != nullptr);
    ++len;
    denom += CertaintyScore(c_it.data()->certainty());
  }
  assert(len != 0);
  // The ideal situation would be to have the classifier scores for
  // classifying each position as each of the characters in the unicharset.
  // Since we cannot do this because of speed, we add a very crude estimate
  // of what these scores for the "missing" classifications would sum up to.
  denom +=
      (dict_->getUnicharset().size() - len) * CertaintyScore(language_model_ngram_nonmatch_score);

  return denom;
}
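
// Standalone sketch (illustrative only) of the estimate above: sum the
// converted scores that are actually present, then charge every unichar
// missing from the choice list a flat non-match score. 'scores' stands in
// for CertaintyScore() values of the BLOB_CHOICEs.
static float ToyDenom(const float *scores, int num_scores, int unicharset_size,
                      float nonmatch_score) {
  float denom = 0.0f;
  for (int i = 0; i < num_scores; ++i) {
    denom += scores[i];
  }
  denom += (unicharset_size - num_scores) * nonmatch_score; // crude estimate
  return denom;
}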
| 1020 | |
| 1021 void LanguageModel::FillConsistencyInfo(int curr_col, bool word_end, BLOB_CHOICE *b, | |
| 1022 ViterbiStateEntry *parent_vse, WERD_RES *word_res, | |
| 1023 LMConsistencyInfo *consistency_info) { | |
| 1024 const UNICHARSET &unicharset = dict_->getUnicharset(); | |
| 1025 UNICHAR_ID unichar_id = b->unichar_id(); | |
| 1026 BLOB_CHOICE *parent_b = parent_vse != nullptr ? parent_vse->curr_b : nullptr; | |
| 1027 | |
| 1028 // Check punctuation validity. | |
| 1029 if (unicharset.get_ispunctuation(unichar_id)) { | |
| 1030 consistency_info->num_punc++; | |
| 1031 } | |
| 1032 if (dict_->GetPuncDawg() != nullptr && !consistency_info->invalid_punc) { | |
| 1033 if (dict_->compound_marker(unichar_id) && parent_b != nullptr && | |
| 1034 (unicharset.get_isalpha(parent_b->unichar_id()) || | |
| 1035 unicharset.get_isdigit(parent_b->unichar_id()))) { | |
| 1036 // reset punc_ref for compound words | |
| 1037 consistency_info->punc_ref = NO_EDGE; | |
| 1038 } else { | |
| 1039 bool is_apos = dict_->is_apostrophe(unichar_id); | |
| 1040 bool prev_is_numalpha = | |
| 1041 (parent_b != nullptr && (unicharset.get_isalpha(parent_b->unichar_id()) || | |
| 1042 unicharset.get_isdigit(parent_b->unichar_id()))); | |
| 1043 UNICHAR_ID pattern_unichar_id = | |
| 1044 (unicharset.get_isalpha(unichar_id) || unicharset.get_isdigit(unichar_id) || | |
| 1045 (is_apos && prev_is_numalpha)) | |
| 1046 ? Dawg::kPatternUnicharID | |
| 1047 : unichar_id; | |
| 1048 if (consistency_info->punc_ref == NO_EDGE || pattern_unichar_id != Dawg::kPatternUnicharID || | |
| 1049 dict_->GetPuncDawg()->edge_letter(consistency_info->punc_ref) != | |
| 1050 Dawg::kPatternUnicharID) { | |
| 1051 NODE_REF node = Dict::GetStartingNode(dict_->GetPuncDawg(), consistency_info->punc_ref); | |
| 1052 consistency_info->punc_ref = (node != NO_EDGE) ? dict_->GetPuncDawg()->edge_char_of( | |
| 1053 node, pattern_unichar_id, word_end) | |
| 1054 : NO_EDGE; | |
| 1055 if (consistency_info->punc_ref == NO_EDGE) { | |
| 1056 consistency_info->invalid_punc = true; | |
| 1057 } | |
| 1058 } | |
| 1059 } | |
| 1060 } | |
| 1061 | |
| 1062 // Update case related counters. | |
| 1063 if (parent_vse != nullptr && !word_end && dict_->compound_marker(unichar_id)) { | |
| 1064 // Reset counters if we are dealing with a compound word. | |
| 1065 consistency_info->num_lower = 0; | |
| 1066 consistency_info->num_non_first_upper = 0; | |
| 1067 } else if (unicharset.get_islower(unichar_id)) { | |
| 1068 consistency_info->num_lower++; | |
| 1069 } else if ((parent_b != nullptr) && unicharset.get_isupper(unichar_id)) { | |
| 1070 if (unicharset.get_isupper(parent_b->unichar_id()) || consistency_info->num_lower > 0 || | |
| 1071 consistency_info->num_non_first_upper > 0) { | |
| 1072 consistency_info->num_non_first_upper++; | |
| 1073 } | |
| 1074 } | |
| 1075 | |
| 1076 // Initialize consistency_info->script_id (use script of unichar_id | |
| 1077 // if it is not Common, use script id recorded by the parent otherwise). | |
| 1078 // Set inconsistent_script to true if the script of the current unichar | |
| 1079 // is not consistent with that of the parent. | |
| 1080 consistency_info->script_id = unicharset.get_script(unichar_id); | |
| 1081 // Hiragana and Katakana can mix with Han. | |
| 1082 if (dict_->getUnicharset().han_sid() != dict_->getUnicharset().null_sid()) { | |
| 1083 if ((unicharset.hiragana_sid() != unicharset.null_sid() && | |
| 1084 consistency_info->script_id == unicharset.hiragana_sid()) || | |
| 1085 (unicharset.katakana_sid() != unicharset.null_sid() && | |
| 1086 consistency_info->script_id == unicharset.katakana_sid())) { | |
| 1087 consistency_info->script_id = dict_->getUnicharset().han_sid(); | |
| 1088 } | |
| 1089 } | |
| 1090 | |
| 1091 if (parent_vse != nullptr && | |
| 1092 (parent_vse->consistency_info.script_id != dict_->getUnicharset().common_sid())) { | |
| 1093 int parent_script_id = parent_vse->consistency_info.script_id; | |
| 1094 // If script_id is Common, use script id of the parent instead. | |
| 1095 if (consistency_info->script_id == dict_->getUnicharset().common_sid()) { | |
| 1096 consistency_info->script_id = parent_script_id; | |
| 1097 } | |
| 1098 if (consistency_info->script_id != parent_script_id) { | |
| 1099 consistency_info->inconsistent_script = true; | |
| 1100 } | |
| 1101 } | |
| 1102 | |
| 1103 // Update chartype related counters. | |
| 1104 if (unicharset.get_isalpha(unichar_id)) { | |
| 1105 consistency_info->num_alphas++; | |
| 1106 } else if (unicharset.get_isdigit(unichar_id)) { | |
| 1107 consistency_info->num_digits++; | |
| 1108 } else if (!unicharset.get_ispunctuation(unichar_id)) { | |
| 1109 consistency_info->num_other++; | |
| 1110 } | |
| 1111 | |
| 1112 // Check font and spacing consistency. | |
| 1113 if (fontinfo_table_->size() > 0 && parent_b != nullptr) { | |
| 1114 int fontinfo_id = -1; | |
| 1115 if (parent_b->fontinfo_id() == b->fontinfo_id() || | |
| 1116 parent_b->fontinfo_id2() == b->fontinfo_id()) { | |
| 1117 fontinfo_id = b->fontinfo_id(); | |
| 1118 } else if (parent_b->fontinfo_id() == b->fontinfo_id2() || | |
| 1119 parent_b->fontinfo_id2() == b->fontinfo_id2()) { | |
| 1120 fontinfo_id = b->fontinfo_id2(); | |
| 1121 } | |
| 1122 if (language_model_debug_level > 1) { | |
| 1123 tprintf( | |
| 1124 "pfont %s pfont %s font %s font2 %s common %s(%d)\n", | |
| 1125 (parent_b->fontinfo_id() >= 0) ? fontinfo_table_->at(parent_b->fontinfo_id()).name : "", | |
| 1126 (parent_b->fontinfo_id2() >= 0) ? fontinfo_table_->at(parent_b->fontinfo_id2()).name | |
| 1127 : "", | |
| 1128 (b->fontinfo_id() >= 0) ? fontinfo_table_->at(b->fontinfo_id()).name : "", | |
| 1129 (fontinfo_id >= 0) ? fontinfo_table_->at(fontinfo_id).name : "", | |
| 1130 (fontinfo_id >= 0) ? fontinfo_table_->at(fontinfo_id).name : "", fontinfo_id); | |
| 1131 } | |
| 1132 if (!word_res->blob_widths.empty()) { // if we have widths/gaps info | |
| 1133 bool expected_gap_found = false; | |
| 1134 float expected_gap = 0.0f; | |
| 1135 int temp_gap; | |
| 1136 if (fontinfo_id >= 0) { // found a common font | |
| 1137 ASSERT_HOST(fontinfo_id < fontinfo_table_->size()); | |
| 1138 if (fontinfo_table_->at(fontinfo_id) | |
| 1139 .get_spacing(parent_b->unichar_id(), unichar_id, &temp_gap)) { | |
| 1140 expected_gap = temp_gap; | |
| 1141 expected_gap_found = true; | |
| 1142 } | |
| 1143 } else { | |
| 1144 consistency_info->inconsistent_font = true; | |
| 1145 // Get an average of the expected gaps in each font | |
| 1146 int num_addends = 0; | |
| 1147 int temp_fid; | |
| 1148 for (int i = 0; i < 4; ++i) { | |
| 1149 if (i == 0) { | |
| 1150 temp_fid = parent_b->fontinfo_id(); | |
| 1151 } else if (i == 1) { | |
| 1152 temp_fid = parent_b->fontinfo_id2(); | |
| 1153 } else if (i == 2) { | |
| 1154 temp_fid = b->fontinfo_id(); | |
| 1155 } else { | |
| 1156 temp_fid = b->fontinfo_id2(); | |
| 1157 } | |
| 1158 ASSERT_HOST(temp_fid < 0 || fontinfo_table_->size()); | |
| 1159 if (temp_fid >= 0 && fontinfo_table_->at(temp_fid).get_spacing(parent_b->unichar_id(), | |
| 1160 unichar_id, &temp_gap)) { | |
| 1161 expected_gap += temp_gap; | |
| 1162 num_addends++; | |
| 1163 } | |
| 1164 } | |
| 1165 if (num_addends > 0) { | |
| 1166 expected_gap /= static_cast<float>(num_addends); | |
| 1167 expected_gap_found = true; | |
| 1168 } | |
| 1169 } | |
| 1170 if (expected_gap_found) { | |
| 1171 int actual_gap = word_res->GetBlobsGap(curr_col - 1); | |
| 1172 if (actual_gap == 0) { | |
| 1173 consistency_info->num_inconsistent_spaces++; | |
| 1174 } else { | |
| 1175 float gap_ratio = expected_gap / actual_gap; | |
| 1176 // TODO(rays) The gaps seem to be way off most of the time, saved by | |
| 1177 // an old bug here that compared the ratio to 1/2 (integer division, | |
| 1178 // i.e. 0) instead of 0.5f. Find the source of the gaps discrepancy | |
| 1179 // and put 0.5f below in place of 0.0f. | |
| 1180 // Test on 2476595.sj, pages 0 to 6. (In French.) | |
| 1181 if (gap_ratio < 0.0f || gap_ratio > 2.0f) { | |
| 1182 consistency_info->num_inconsistent_spaces++; | |
| 1183 } | |
| 1184 } | |
| 1185 if (language_model_debug_level > 1) { | |
| 1186 tprintf("spacing for %s(%d) %s(%d) col %d: expected %g actual %d\n", | |
| 1187 unicharset.id_to_unichar(parent_b->unichar_id()), parent_b->unichar_id(), | |
| 1188 unicharset.id_to_unichar(unichar_id), unichar_id, curr_col, expected_gap, | |
| 1189 actual_gap); | |
| 1190 } | |
| 1191 } | |
| 1192 } | |
| 1193 } | |
| 1194 } | |
| 1195 | |
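| // ComputeAdjustedPathCost() reduces the statistics of a path to a single | |
| // comparable cost. With a trained ParamsModel the cost is the model output | |
| // scaled by the outline length; otherwise a multiplicative adjustment is | |
| // assembled from dictionary, compound-length, shape and consistency | |
| // penalties, roughly: | |
| //   adjusted = base_cost * (1 + p_non_freq + p_non_dict | |
| //                             + extra_length * increment + shape + consistency) | |
| // where base_cost is ratings_sum, or the ngram-and-classifier cost when | |
| // the ngram model is on. For example, assuming the usual default penalties | |
| // (0.1 non-freq-dict, 0.15 non-dict, 0.01 per unichar beyond the minimum | |
| // compound length of 3), a 10-unichar non-dictionary word with ratings_sum | |
| // 100 and no other penalties costs 100 * (1 + 0.1 + 0.15 + 7 * 0.01) = 132. | |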
| 1196 float LanguageModel::ComputeAdjustedPathCost(ViterbiStateEntry *vse) { | |
| 1197 ASSERT_HOST(vse != nullptr); | |
| 1198 if (params_model_.Initialized()) { | |
| 1199 float features[PTRAIN_NUM_FEATURE_TYPES]; | |
| 1200 ExtractFeaturesFromPath(*vse, features); | |
| 1201 float cost = params_model_.ComputeCost(features); | |
| 1202 if (language_model_debug_level > 3) { | |
| 1203 tprintf("ComputeAdjustedPathCost %g ParamsModel features:\n", cost); | |
| 1204 if (language_model_debug_level >= 5) { | |
| 1205 for (int f = 0; f < PTRAIN_NUM_FEATURE_TYPES; ++f) { | |
| 1206 tprintf("%s=%g\n", kParamsTrainingFeatureTypeName[f], features[f]); | |
| 1207 } | |
| 1208 } | |
| 1209 } | |
| 1210 return cost * vse->outline_length; | |
| 1211 } else { | |
| 1212 float adjustment = 1.0f; | |
| 1213 if (vse->dawg_info == nullptr || vse->dawg_info->permuter != FREQ_DAWG_PERM) { | |
| 1214 adjustment += language_model_penalty_non_freq_dict_word; | |
| 1215 } | |
| 1216 if (vse->dawg_info == nullptr) { | |
| 1217 adjustment += language_model_penalty_non_dict_word; | |
| 1218 if (vse->length > language_model_min_compound_length) { | |
| 1219 adjustment += | |
| 1220 ((vse->length - language_model_min_compound_length) * language_model_penalty_increment); | |
| 1221 } | |
| 1222 } | |
| 1223 if (vse->associate_stats.shape_cost > 0) { | |
| 1224 adjustment += vse->associate_stats.shape_cost / static_cast<float>(vse->length); | |
| 1225 } | |
| 1226 if (language_model_ngram_on) { | |
| 1227 ASSERT_HOST(vse->ngram_info != nullptr); | |
| 1228 return vse->ngram_info->ngram_and_classifier_cost * adjustment; | |
| 1229 } else { | |
| 1230 adjustment += ComputeConsistencyAdjustment(vse->dawg_info, vse->consistency_info); | |
| 1231 return vse->ratings_sum * adjustment; | |
| 1232 } | |
| 1233 } | |
| 1234 } | |
| 1235 | |
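| // UpdateBestChoice() materializes the path in vse as a WERD_CHOICE, | |
| // records a params-training hypothesis when a blamer is attached, updates | |
| // word_res->raw_choice if the new rating beats it, re-rates the word with | |
| // vse->cost and then hands ownership of the word to word_res through | |
| // LogNewCookedChoice(). Hyphen state and best_choice_bundle are touched | |
| // only when the word becomes the new best choice. | |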
| 1236 void LanguageModel::UpdateBestChoice(ViterbiStateEntry *vse, LMPainPoints *pain_points, | |
| 1237 WERD_RES *word_res, BestChoiceBundle *best_choice_bundle, | |
| 1238 BlamerBundle *blamer_bundle) { | |
| 1239 bool truth_path; | |
| 1240 WERD_CHOICE *word = | |
| 1241 ConstructWord(vse, word_res, &best_choice_bundle->fixpt, blamer_bundle, &truth_path); | |
| 1242 ASSERT_HOST(word != nullptr); | |
| 1243 if (dict_->stopper_debug_level >= 1) { | |
| 1244 std::string word_str; | |
| 1245 word->string_and_lengths(&word_str, nullptr); | |
| 1246 vse->Print(word_str.c_str()); | |
| 1247 } | |
| 1248 if (language_model_debug_level > 0) { | |
| 1249 word->print("UpdateBestChoice() constructed word"); | |
| 1250 } | |
| 1251 // Record features from the current path if necessary. | |
| 1252 ParamsTrainingHypothesis curr_hyp; | |
| 1253 if (blamer_bundle != nullptr) { | |
| 1254 if (vse->dawg_info != nullptr) { | |
| 1255 vse->dawg_info->permuter = static_cast<PermuterType>(word->permuter()); | |
| 1256 } | |
| 1257 ExtractFeaturesFromPath(*vse, curr_hyp.features); | |
| 1258 word->string_and_lengths(&(curr_hyp.str), nullptr); | |
| 1259 curr_hyp.cost = vse->cost; // record cost for error rate computations | |
| 1260 if (language_model_debug_level > 0) { | |
| 1261 tprintf("Raw features extracted from %s (cost=%g) [ ", curr_hyp.str.c_str(), curr_hyp.cost); | |
| 1262 for (float feature : curr_hyp.features) { | |
| 1263 tprintf("%g ", feature); | |
| 1264 } | |
| 1265 tprintf("]\n"); | |
| 1266 } | |
| 1267 // Record the current hypothesis in params_training_bundle. | |
| 1268 blamer_bundle->AddHypothesis(curr_hyp); | |
| 1269 if (truth_path) { | |
| 1270 blamer_bundle->UpdateBestRating(word->rating()); | |
| 1271 } | |
| 1272 } | |
| 1273 if (blamer_bundle != nullptr && blamer_bundle->GuidedSegsearchStillGoing()) { | |
| 1274 // The word was constructed solely for blamer_bundle->AddHypothesis, so | |
| 1275 // we no longer need it. | |
| 1276 delete word; | |
| 1277 return; | |
| 1278 } | |
| 1279 if (word_res->chopped_word != nullptr && !word_res->chopped_word->blobs.empty()) { | |
| 1280 word->SetScriptPositions(false, word_res->chopped_word, language_model_debug_level); | |
| 1281 } | |
| 1282 // Update and log new raw_choice if needed. | |
| 1283 if (word_res->raw_choice == nullptr || word->rating() < word_res->raw_choice->rating()) { | |
| 1284 if (word_res->LogNewRawChoice(word) && language_model_debug_level > 0) { | |
| 1285 tprintf("Updated raw choice\n"); | |
| 1286 } | |
| 1287 } | |
| 1288 // Set the modified rating for best choice to vse->cost and log best choice. | |
| 1289 word->set_rating(vse->cost); | |
| 1290 // Call LogNewChoice() for the best choice via Dict::adjust_word(), since | |
| 1291 // adjust_word() computes the adjust_factor that is used by the adaption | |
| 1292 // code (e.g. by ClassifyAdaptableWord() to compute adaption acceptance | |
| 1293 // thresholds). Note: the rating of the word is not adjusted. | |
| 1294 dict_->adjust_word(word, vse->dawg_info == nullptr, vse->consistency_info.xht_decision, 0.0, | |
| 1295 false, language_model_debug_level > 0); | |
| 1296 // Hand ownership of the word over to the word_res. | |
| 1297 if (!word_res->LogNewCookedChoice(dict_->tessedit_truncate_wordchoice_log, | |
| 1298 dict_->stopper_debug_level >= 1, word)) { | |
| 1299 // The word was so bad that it was deleted. | |
| 1300 return; | |
| 1301 } | |
| 1302 if (word_res->best_choice == word) { | |
| 1303 // Word was the new best. | |
| 1304 if (dict_->AcceptableChoice(*word, vse->consistency_info.xht_decision) && | |
| 1305 AcceptablePath(*vse)) { | |
| 1306 acceptable_choice_found_ = true; | |
| 1307 } | |
| 1308 // Update best_choice_bundle. | |
| 1309 best_choice_bundle->updated = true; | |
| 1310 best_choice_bundle->best_vse = vse; | |
| 1311 if (language_model_debug_level > 0) { | |
| 1312 tprintf("Updated best choice\n"); | |
| 1313 word->print_state("New state "); | |
| 1314 } | |
| 1315 // Update hyphen state if we are dealing with a dictionary word. | |
| 1316 if (vse->dawg_info != nullptr) { | |
| 1317 if (dict_->has_hyphen_end(*word)) { | |
| 1318 dict_->set_hyphen_word(*word, *(dawg_args_.active_dawgs)); | |
| 1319 } else { | |
| 1320 dict_->reset_hyphen_vars(true); | |
| 1321 } | |
| 1322 } | |
| 1323 | |
| 1324 if (blamer_bundle != nullptr) { | |
| 1325 blamer_bundle->set_best_choice_is_dict_and_top_choice(vse->dawg_info != nullptr && | |
| 1326 vse->top_choice_flags); | |
| 1327 } | |
| 1328 } | |
| 1329 #ifndef GRAPHICS_DISABLED | |
| 1330 if (wordrec_display_segmentations && word_res->chopped_word != nullptr) { | |
| 1331 word->DisplaySegmentation(word_res->chopped_word); | |
| 1332 } | |
| 1333 #endif | |
| 1334 } | |
| 1335 | |
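| // ExtractFeaturesFromPath() fills features[] (of size | |
| // PTRAIN_NUM_FEATURE_TYPES) with the dictionary, shape, ngram, | |
| // consistency and classifier signals of the given path; the vector is | |
| // consumed by ParamsModel::ComputeCost() and by the params-training | |
| // hypotheses recorded in UpdateBestChoice(). | |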
| 1336 void LanguageModel::ExtractFeaturesFromPath(const ViterbiStateEntry &vse, float features[]) { | |
| 1337 memset(features, 0, sizeof(float) * PTRAIN_NUM_FEATURE_TYPES); | |
| 1338 // Record dictionary match info. | |
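| // The length bucket (0 = short, 1 = medium, 2 = long) selects among the | |
| // short/medium/long variants of each dictionary feature below, which are | |
| // laid out consecutively in the feature-type enum. | |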
| 1339 int len = vse.length <= kMaxSmallWordUnichars ? 0 : vse.length <= kMaxMediumWordUnichars ? 1 : 2; | |
| 1340 if (vse.dawg_info != nullptr) { | |
| 1341 int permuter = vse.dawg_info->permuter; | |
| 1342 if (permuter == NUMBER_PERM || permuter == USER_PATTERN_PERM) { | |
| 1343 if (vse.consistency_info.num_digits == vse.length) { | |
| 1344 features[PTRAIN_DIGITS_SHORT + len] = 1.0f; | |
| 1345 } else { | |
| 1346 features[PTRAIN_NUM_SHORT + len] = 1.0f; | |
| 1347 } | |
| 1348 } else if (permuter == DOC_DAWG_PERM) { | |
| 1349 features[PTRAIN_DOC_SHORT + len] = 1.0f; | |
| 1350 } else if (permuter == SYSTEM_DAWG_PERM || permuter == USER_DAWG_PERM || | |
| 1351 permuter == COMPOUND_PERM) { | |
| 1352 features[PTRAIN_DICT_SHORT + len] = 1.0f; | |
| 1353 } else if (permuter == FREQ_DAWG_PERM) { | |
| 1354 features[PTRAIN_FREQ_SHORT + len] = 1.0f; | |
| 1355 } | |
| 1356 } | |
| 1357 // Record shape cost feature (normalized by path length). | |
| 1358 features[PTRAIN_SHAPE_COST_PER_CHAR] = | |
| 1359 vse.associate_stats.shape_cost / static_cast<float>(vse.length); | |
| 1360 // Record ngram cost (normalized by the path length). | |
| 1361 features[PTRAIN_NGRAM_COST_PER_CHAR] = 0.0f; | |
| 1362 if (vse.ngram_info != nullptr) { | |
| 1363 features[PTRAIN_NGRAM_COST_PER_CHAR] = | |
| 1364 vse.ngram_info->ngram_cost / static_cast<float>(vse.length); | |
| 1365 } | |
| 1366 // Record consistency-related features. | |
| 1367 // Disabled this feature for now due to its poor performance. | |
| 1368 // features[PTRAIN_NUM_BAD_PUNC] = vse.consistency_info.NumInconsistentPunc(); | |
| 1369 features[PTRAIN_NUM_BAD_CASE] = vse.consistency_info.NumInconsistentCase(); | |
| 1370 features[PTRAIN_XHEIGHT_CONSISTENCY] = vse.consistency_info.xht_decision; | |
| 1371 features[PTRAIN_NUM_BAD_CHAR_TYPE] = | |
| 1372 vse.dawg_info == nullptr ? vse.consistency_info.NumInconsistentChartype() : 0.0f; | |
| 1373 features[PTRAIN_NUM_BAD_SPACING] = vse.consistency_info.NumInconsistentSpaces(); | |
| 1374 // Disabled this feature for now due to its poor performance. | |
| 1375 // features[PTRAIN_NUM_BAD_FONT] = vse.consistency_info.inconsistent_font; | |
| 1376 | |
| 1377 // Classifier-related features. | |
| 1378 if (vse.outline_length > 0.0f) { | |
| 1379 features[PTRAIN_RATING_PER_CHAR] = vse.ratings_sum / vse.outline_length; | |
| 1380 } else { | |
| 1381 // Avoid FP division by 0. | |
| 1382 features[PTRAIN_RATING_PER_CHAR] = 0.0f; | |
| 1383 } | |
| 1384 } | |
| 1385 | |
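| // ConstructWord() walks the parent_vse chain backwards from the final | |
| // entry and assembles the corresponding WERD_CHOICE, re-computing the | |
| // width-to-height ratio variance along the way and deriving the permuter | |
| // from the dawg/ngram/top-choice state. The caller owns the returned | |
| // word (UpdateBestChoice() either hands it to word_res or deletes it). | |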
| 1386 WERD_CHOICE *LanguageModel::ConstructWord(ViterbiStateEntry *vse, WERD_RES *word_res, | |
| 1387 DANGERR *fixpt, BlamerBundle *blamer_bundle, | |
| 1388 bool *truth_path) { | |
| 1389 if (truth_path != nullptr) { | |
| 1390 *truth_path = | |
| 1391 (blamer_bundle != nullptr && vse->length == blamer_bundle->correct_segmentation_length()); | |
| 1392 } | |
| 1393 BLOB_CHOICE *curr_b = vse->curr_b; | |
| 1394 ViterbiStateEntry *curr_vse = vse; | |
| 1395 | |
| 1396 int i; | |
| 1397 bool compound = dict_->hyphenated(); // treat hyphenated words as compound | |
| 1398 | |
| 1399 // Re-compute the variance of the width-to-height ratios (since we now | |
| 1400 // can compute the mean over the whole word). | |
| 1401 float full_wh_ratio_mean = 0.0f; | |
| 1402 if (vse->associate_stats.full_wh_ratio_var != 0.0f) { | |
| 1403 vse->associate_stats.shape_cost -= vse->associate_stats.full_wh_ratio_var; | |
| 1404 full_wh_ratio_mean = | |
| 1405 (vse->associate_stats.full_wh_ratio_total / static_cast<float>(vse->length)); | |
| 1406 vse->associate_stats.full_wh_ratio_var = 0.0f; | |
| 1407 } | |
| 1408 | |
| 1409 // Construct a WERD_CHOICE by tracing parent pointers. | |
| 1410 auto *word = new WERD_CHOICE(word_res->uch_set, vse->length); | |
| 1411 word->set_length(vse->length); | |
| 1412 int total_blobs = 0; | |
| 1413 for (i = (vse->length - 1); i >= 0; --i) { | |
| 1414 if (blamer_bundle != nullptr && truth_path != nullptr && *truth_path && | |
| 1415 !blamer_bundle->MatrixPositionCorrect(i, curr_b->matrix_cell())) { | |
| 1416 *truth_path = false; | |
| 1417 } | |
| 1418 // The number of blobs used for this choice is row - col + 1. | |
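| // (E.g. a cell with col 2 and row 4 covers blobs 2..4, i.e. 3 blobs.) | |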
| 1419 int num_blobs = curr_b->matrix_cell().row - curr_b->matrix_cell().col + 1; | |
| 1420 total_blobs += num_blobs; | |
| 1421 word->set_blob_choice(i, num_blobs, curr_b); | |
| 1422 // Update the width-to-height ratio variance. Useful for non-space | |
| 1423 // delimited languages to ensure that the blobs are of uniform width. | |
| 1424 // Skip leading and trailing punctuation when computing the variance. | |
| 1425 if (full_wh_ratio_mean != 0.0f && | |
| 1426 ((curr_vse != vse && curr_vse->parent_vse != nullptr) || | |
| 1427 !dict_->getUnicharset().get_ispunctuation(curr_b->unichar_id()))) { | |
| 1428 vse->associate_stats.full_wh_ratio_var += | |
| 1429 pow(full_wh_ratio_mean - curr_vse->associate_stats.full_wh_ratio, 2); | |
| 1430 if (language_model_debug_level > 2) { | |
| 1431 tprintf("full_wh_ratio_var += (%g-%g)^2\n", full_wh_ratio_mean, | |
| 1432 curr_vse->associate_stats.full_wh_ratio); | |
| 1433 } | |
| 1434 } | |
| 1435 | |
| 1436 // Mark the word as compound if compound permuter was set for any of | |
| 1437 // the unichars on the path (usually this will happen for unichars | |
| 1438 // that are compounding operators, like "-" and "/"). | |
| 1439 if (!compound && curr_vse->dawg_info && curr_vse->dawg_info->permuter == COMPOUND_PERM) { | |
| 1440 compound = true; | |
| 1441 } | |
| 1442 | |
| 1443 // Update curr_* pointers. | |
| 1444 curr_vse = curr_vse->parent_vse; | |
| 1445 if (curr_vse == nullptr) { | |
| 1446 break; | |
| 1447 } | |
| 1448 curr_b = curr_vse->curr_b; | |
| 1449 } | |
| 1450 ASSERT_HOST(i == 0); // check that we recorded all the unichar ids. | |
| 1451 ASSERT_HOST(total_blobs == word_res->ratings->dimension()); | |
| 1452 // Re-adjust shape cost to include the updated width-to-height variance. | |
| 1453 if (full_wh_ratio_mean != 0.0f) { | |
| 1454 vse->associate_stats.shape_cost += vse->associate_stats.full_wh_ratio_var; | |
| 1455 } | |
| 1456 | |
| 1457 word->set_rating(vse->ratings_sum); | |
| 1458 word->set_certainty(vse->min_certainty); | |
| 1459 word->set_x_heights(vse->consistency_info.BodyMinXHeight(), | |
| 1460 vse->consistency_info.BodyMaxXHeight()); | |
| 1461 if (vse->dawg_info != nullptr) { | |
| 1462 word->set_permuter(compound ? COMPOUND_PERM : vse->dawg_info->permuter); | |
| 1463 } else if (language_model_ngram_on && !vse->ngram_info->pruned) { | |
| 1464 word->set_permuter(NGRAM_PERM); | |
| 1465 } else if (vse->top_choice_flags) { | |
| 1466 word->set_permuter(TOP_CHOICE_PERM); | |
| 1467 } else { | |
| 1468 word->set_permuter(NO_PERM); | |
| 1469 } | |
| 1470 word->set_dangerous_ambig_found_(!dict_->NoDangerousAmbig(word, fixpt, true, word_res->ratings)); | |
| 1471 return word; | |
| 1472 } | |
| 1473 | |
| 1474 } // namespace tesseract |
