Mercurial > hgrepos > Python2 > PyMuPDF
comparison mupdf-source/thirdparty/tesseract/src/lstm/lstmrecognizer.cpp @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: no version number in the expanded directory now.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 1:1d09e1dec1d9 | 2:b50eed0cc0ef |
|---|---|
| 1 /////////////////////////////////////////////////////////////////////// | |
| 2 // File: lstmrecognizer.cpp | |
| 3 // Description: Top-level line recognizer class for LSTM-based networks. | |
| 4 // Author: Ray Smith | |
| 5 // | |
| 6 // (C) Copyright 2013, Google Inc. | |
| 7 // Licensed under the Apache License, Version 2.0 (the "License"); | |
| 8 // you may not use this file except in compliance with the License. | |
| 9 // You may obtain a copy of the License at | |
| 10 // http://www.apache.org/licenses/LICENSE-2.0 | |
| 11 // Unless required by applicable law or agreed to in writing, software | |
| 12 // distributed under the License is distributed on an "AS IS" BASIS, | |
| 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 14 // See the License for the specific language governing permissions and | |
| 15 // limitations under the License. | |
| 16 /////////////////////////////////////////////////////////////////////// | |
| 17 | |
| 18 // Include automatically generated configuration file if running autoconf. | |
| 19 #ifdef HAVE_CONFIG_H | |
| 20 # include "config_auto.h" | |
| 21 #endif | |
| 22 | |
| 23 #include "lstmrecognizer.h" | |
| 24 | |
| 25 #include <allheaders.h> | |
| 26 #include "dict.h" | |
| 27 #include "genericheap.h" | |
| 28 #include "helpers.h" | |
| 29 #include "imagedata.h" | |
| 30 #include "input.h" | |
| 31 #include "lstm.h" | |
| 32 #include "normalis.h" | |
| 33 #include "pageres.h" | |
| 34 #include "ratngs.h" | |
| 35 #include "recodebeam.h" | |
| 36 #include "scrollview.h" | |
| 37 #include "statistc.h" | |
| 38 #include "tprintf.h" | |
| 39 | |
| 40 #include <unordered_set> | |
| 41 #include <vector> | |
| 42 | |
| 43 namespace tesseract { | |
| 44 | |
// Default ratio between dictionary and non-dictionary words; handed to the
// beam search when decoding a line into words.
constexpr double kDictRatio = 2.25;
// Default certainty offset that gives the dictionary a chance; handed to the
// beam search together with kDictRatio.
constexpr double kCertOffset = -0.085;
| 49 | |
| 50 LSTMRecognizer::LSTMRecognizer(const std::string &language_data_path_prefix) | |
| 51 : LSTMRecognizer::LSTMRecognizer() { | |
| 52 ccutil_.language_data_path_prefix = language_data_path_prefix; | |
| 53 } | |
| 54 | |
| 55 LSTMRecognizer::LSTMRecognizer() | |
| 56 : network_(nullptr) | |
| 57 , training_flags_(0) | |
| 58 , training_iteration_(0) | |
| 59 , sample_iteration_(0) | |
| 60 , null_char_(UNICHAR_BROKEN) | |
| 61 , learning_rate_(0.0f) | |
| 62 , momentum_(0.0f) | |
| 63 , adam_beta_(0.0f) | |
| 64 , dict_(nullptr) | |
| 65 , search_(nullptr) | |
| 66 , debug_win_(nullptr) {} | |
| 67 | |
| 68 LSTMRecognizer::~LSTMRecognizer() { | |
| 69 delete network_; | |
| 70 delete dict_; | |
| 71 delete search_; | |
| 72 } | |
| 73 | |
| 74 // Loads a model from mgr, including the dictionary only if lang is not null. | |
| 75 bool LSTMRecognizer::Load(const ParamsVectors *params, const std::string &lang, | |
| 76 TessdataManager *mgr) { | |
| 77 TFile fp; | |
| 78 if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) { | |
| 79 return false; | |
| 80 } | |
| 81 if (!DeSerialize(mgr, &fp)) { | |
| 82 return false; | |
| 83 } | |
| 84 if (lang.empty()) { | |
| 85 return true; | |
| 86 } | |
| 87 // Allow it to run without a dictionary. | |
| 88 LoadDictionary(params, lang, mgr); | |
| 89 return true; | |
| 90 } | |
| 91 | |
| 92 // Writes to the given file. Returns false in case of error. | |
| 93 bool LSTMRecognizer::Serialize(const TessdataManager *mgr, TFile *fp) const { | |
| 94 bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) || | |
| 95 !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET); | |
| 96 if (!network_->Serialize(fp)) { | |
| 97 return false; | |
| 98 } | |
| 99 if (include_charsets && !GetUnicharset().save_to_file(fp)) { | |
| 100 return false; | |
| 101 } | |
| 102 if (!fp->Serialize(network_str_)) { | |
| 103 return false; | |
| 104 } | |
| 105 if (!fp->Serialize(&training_flags_)) { | |
| 106 return false; | |
| 107 } | |
| 108 if (!fp->Serialize(&training_iteration_)) { | |
| 109 return false; | |
| 110 } | |
| 111 if (!fp->Serialize(&sample_iteration_)) { | |
| 112 return false; | |
| 113 } | |
| 114 if (!fp->Serialize(&null_char_)) { | |
| 115 return false; | |
| 116 } | |
| 117 if (!fp->Serialize(&adam_beta_)) { | |
| 118 return false; | |
| 119 } | |
| 120 if (!fp->Serialize(&learning_rate_)) { | |
| 121 return false; | |
| 122 } | |
| 123 if (!fp->Serialize(&momentum_)) { | |
| 124 return false; | |
| 125 } | |
| 126 if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) { | |
| 127 return false; | |
| 128 } | |
| 129 return true; | |
| 130 } | |
| 131 | |
| 132 // Reads from the given file. Returns false in case of error. | |
| 133 bool LSTMRecognizer::DeSerialize(const TessdataManager *mgr, TFile *fp) { | |
| 134 delete network_; | |
| 135 network_ = Network::CreateFromFile(fp); | |
| 136 if (network_ == nullptr) { | |
| 137 return false; | |
| 138 } | |
| 139 bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) || | |
| 140 !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET); | |
| 141 if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) { | |
| 142 return false; | |
| 143 } | |
| 144 if (!fp->DeSerialize(network_str_)) { | |
| 145 return false; | |
| 146 } | |
| 147 if (!fp->DeSerialize(&training_flags_)) { | |
| 148 return false; | |
| 149 } | |
| 150 if (!fp->DeSerialize(&training_iteration_)) { | |
| 151 return false; | |
| 152 } | |
| 153 if (!fp->DeSerialize(&sample_iteration_)) { | |
| 154 return false; | |
| 155 } | |
| 156 if (!fp->DeSerialize(&null_char_)) { | |
| 157 return false; | |
| 158 } | |
| 159 if (!fp->DeSerialize(&adam_beta_)) { | |
| 160 return false; | |
| 161 } | |
| 162 if (!fp->DeSerialize(&learning_rate_)) { | |
| 163 return false; | |
| 164 } | |
| 165 if (!fp->DeSerialize(&momentum_)) { | |
| 166 return false; | |
| 167 } | |
| 168 if (include_charsets && !LoadRecoder(fp)) { | |
| 169 return false; | |
| 170 } | |
| 171 if (!include_charsets && !LoadCharsets(mgr)) { | |
| 172 return false; | |
| 173 } | |
| 174 network_->SetRandomizer(&randomizer_); | |
| 175 network_->CacheXScaleFactor(network_->XScaleFactor()); | |
| 176 return true; | |
| 177 } | |
| 178 | |
| 179 // Loads the charsets from mgr. | |
| 180 bool LSTMRecognizer::LoadCharsets(const TessdataManager *mgr) { | |
| 181 TFile fp; | |
| 182 if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) { | |
| 183 return false; | |
| 184 } | |
| 185 if (!ccutil_.unicharset.load_from_file(&fp, false)) { | |
| 186 return false; | |
| 187 } | |
| 188 if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) { | |
| 189 return false; | |
| 190 } | |
| 191 if (!LoadRecoder(&fp)) { | |
| 192 return false; | |
| 193 } | |
| 194 return true; | |
| 195 } | |
| 196 | |
| 197 // Loads the Recoder. | |
| 198 bool LSTMRecognizer::LoadRecoder(TFile *fp) { | |
| 199 if (IsRecoding()) { | |
| 200 if (!recoder_.DeSerialize(fp)) { | |
| 201 return false; | |
| 202 } | |
| 203 RecodedCharID code; | |
| 204 recoder_.EncodeUnichar(UNICHAR_SPACE, &code); | |
| 205 if (code(0) != UNICHAR_SPACE) { | |
| 206 tprintf("Space was garbled in recoding!!\n"); | |
| 207 return false; | |
| 208 } | |
| 209 } else { | |
| 210 recoder_.SetupPassThrough(GetUnicharset()); | |
| 211 training_flags_ |= TF_COMPRESS_UNICHARSET; | |
| 212 } | |
| 213 return true; | |
| 214 } | |
| 215 | |
| 216 // Loads the dictionary if possible from the traineddata file. | |
| 217 // Prints a warning message, and returns false but otherwise fails silently | |
| 218 // and continues to work without it if loading fails. | |
| 219 // Note that dictionary load is independent from DeSerialize, but dependent | |
| 220 // on the unicharset matching. This enables training to deserialize a model | |
| 221 // from checkpoint or restore without having to go back and reload the | |
| 222 // dictionary. | |
| 223 // Some parameters have to be passed in (from langdata/config/api via Tesseract) | |
| 224 bool LSTMRecognizer::LoadDictionary(const ParamsVectors *params, const std::string &lang, | |
| 225 TessdataManager *mgr) { | |
| 226 delete dict_; | |
| 227 dict_ = new Dict(&ccutil_); | |
| 228 dict_->user_words_file.ResetFrom(params); | |
| 229 dict_->user_words_suffix.ResetFrom(params); | |
| 230 dict_->user_patterns_file.ResetFrom(params); | |
| 231 dict_->user_patterns_suffix.ResetFrom(params); | |
| 232 dict_->SetupForLoad(Dict::GlobalDawgCache()); | |
| 233 dict_->LoadLSTM(lang, mgr); | |
| 234 if (dict_->FinishLoad()) { | |
| 235 return true; // Success. | |
| 236 } | |
| 237 if (log_level <= 0) { | |
| 238 tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n", lang.c_str()); | |
| 239 } | |
| 240 delete dict_; | |
| 241 dict_ = nullptr; | |
| 242 return false; | |
| 243 } | |
| 244 | |
| 245 // Recognizes the line image, contained within image_data, returning the | |
| 246 // ratings matrix and matching box_word for each WERD_RES in the output. | |
| 247 void LSTMRecognizer::RecognizeLine(const ImageData &image_data, | |
| 248 float invert_threshold, bool debug, | |
| 249 double worst_dict_cert, const TBOX &line_box, | |
| 250 PointerVector<WERD_RES> *words, int lstm_choice_mode, | |
| 251 int lstm_choice_amount) { | |
| 252 NetworkIO outputs; | |
| 253 float scale_factor; | |
| 254 NetworkIO inputs; | |
| 255 if (!RecognizeLine(image_data, invert_threshold, debug, false, false, &scale_factor, &inputs, &outputs)) { | |
| 256 return; | |
| 257 } | |
| 258 if (search_ == nullptr) { | |
| 259 search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_); | |
| 260 } | |
| 261 search_->excludedUnichars.clear(); | |
| 262 search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, &GetUnicharset(), | |
| 263 lstm_choice_mode); | |
| 264 search_->ExtractBestPathAsWords(line_box, scale_factor, debug, &GetUnicharset(), words, | |
| 265 lstm_choice_mode); | |
| 266 if (lstm_choice_mode) { | |
| 267 search_->extractSymbolChoices(&GetUnicharset()); | |
| 268 for (int i = 0; i < lstm_choice_amount; ++i) { | |
| 269 search_->DecodeSecondaryBeams(outputs, kDictRatio, kCertOffset, worst_dict_cert, | |
| 270 &GetUnicharset(), lstm_choice_mode); | |
| 271 search_->extractSymbolChoices(&GetUnicharset()); | |
| 272 } | |
| 273 search_->segmentTimestepsByCharacters(); | |
| 274 unsigned char_it = 0; | |
| 275 for (size_t i = 0; i < words->size(); ++i) { | |
| 276 for (int j = 0; j < words->at(i)->end; ++j) { | |
| 277 if (char_it < search_->ctc_choices.size()) { | |
| 278 words->at(i)->CTC_symbol_choices.push_back(search_->ctc_choices[char_it]); | |
| 279 } | |
| 280 if (char_it < search_->segmentedTimesteps.size()) { | |
| 281 words->at(i)->segmented_timesteps.push_back(search_->segmentedTimesteps[char_it]); | |
| 282 } | |
| 283 ++char_it; | |
| 284 } | |
| 285 words->at(i)->timesteps = | |
| 286 search_->combineSegmentedTimesteps(&words->at(i)->segmented_timesteps); | |
| 287 } | |
| 288 search_->segmentedTimesteps.clear(); | |
| 289 search_->ctc_choices.clear(); | |
| 290 search_->excludedUnichars.clear(); | |
| 291 } | |
| 292 } | |
| 293 | |
| 294 // Helper computes min and mean best results in the output. | |
| 295 void LSTMRecognizer::OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, | |
| 296 float *sd) { | |
| 297 const int kOutputScale = INT8_MAX; | |
| 298 STATS stats(0, kOutputScale); | |
| 299 for (int t = 0; t < outputs.Width(); ++t) { | |
| 300 int best_label = outputs.BestLabel(t, nullptr); | |
| 301 if (best_label != null_char_) { | |
| 302 float best_output = outputs.f(t)[best_label]; | |
| 303 stats.add(static_cast<int>(kOutputScale * best_output), 1); | |
| 304 } | |
| 305 } | |
| 306 // If the output is all nulls it could be that the photometric interpretation | |
| 307 // is wrong, so make it look bad, so the other way can win, even if not great. | |
| 308 if (stats.get_total() == 0) { | |
| 309 *min_output = 0.0f; | |
| 310 *mean_output = 0.0f; | |
| 311 *sd = 1.0f; | |
| 312 } else { | |
| 313 *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale; | |
| 314 *mean_output = stats.mean() / kOutputScale; | |
| 315 *sd = stats.sd() / kOutputScale; | |
| 316 } | |
| 317 } | |
| 318 | |
| 319 // Recognizes the image_data, returning the labels, | |
| 320 // scores, and corresponding pairs of start, end x-coords in coords. | |
| 321 bool LSTMRecognizer::RecognizeLine(const ImageData &image_data, | |
| 322 float invert_threshold, bool debug, | |
| 323 bool re_invert, bool upside_down, float *scale_factor, | |
| 324 NetworkIO *inputs, NetworkIO *outputs) { | |
| 325 // This ensures consistent recognition results. | |
| 326 SetRandomSeed(); | |
| 327 int min_width = network_->XScaleFactor(); | |
| 328 Image pix = Input::PrepareLSTMInputs(image_data, network_, min_width, &randomizer_, scale_factor); | |
| 329 if (pix == nullptr) { | |
| 330 tprintf("Line cannot be recognized!!\n"); | |
| 331 return false; | |
| 332 } | |
| 333 // Maximum width of image to train on. | |
| 334 const int kMaxImageWidth = 128 * pixGetHeight(pix); | |
| 335 if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) { | |
| 336 tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), pixGetHeight(pix)); | |
| 337 pix.destroy(); | |
| 338 return false; | |
| 339 } | |
| 340 if (upside_down) { | |
| 341 pixRotate180(pix, pix); | |
| 342 } | |
| 343 // Reduction factor from image to coords. | |
| 344 *scale_factor = min_width / *scale_factor; | |
| 345 inputs->set_int_mode(IsIntMode()); | |
| 346 SetRandomSeed(); | |
| 347 Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs); | |
| 348 network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs); | |
| 349 // Check for auto inversion. | |
| 350 if (invert_threshold > 0.0f) { | |
| 351 float pos_min, pos_mean, pos_sd; | |
| 352 OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd); | |
| 353 if (pos_mean < invert_threshold) { | |
| 354 // Run again inverted and see if it is any better. | |
| 355 NetworkIO inv_inputs, inv_outputs; | |
| 356 inv_inputs.set_int_mode(IsIntMode()); | |
| 357 SetRandomSeed(); | |
| 358 pixInvert(pix, pix); | |
| 359 Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, &inv_inputs); | |
| 360 network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs); | |
| 361 float inv_min, inv_mean, inv_sd; | |
| 362 OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd); | |
| 363 if (inv_mean > pos_mean) { | |
| 364 // Inverted did better. Use inverted data. | |
| 365 if (debug) { | |
| 366 tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean, | |
| 367 pos_sd, inv_min, inv_mean, inv_sd); | |
| 368 } | |
| 369 *outputs = std::move(inv_outputs); | |
| 370 *inputs = std::move(inv_inputs); | |
| 371 } else if (re_invert) { | |
| 372 // Inverting was not an improvement, so undo and run again, so the | |
| 373 // outputs match the best forward result. | |
| 374 SetRandomSeed(); | |
| 375 network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs); | |
| 376 } | |
| 377 } | |
| 378 } | |
| 379 | |
| 380 pix.destroy(); | |
| 381 if (debug) { | |
| 382 std::vector<int> labels, coords; | |
| 383 LabelsFromOutputs(*outputs, &labels, &coords); | |
| 384 #ifndef GRAPHICS_DISABLED | |
| 385 DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_); | |
| 386 #endif | |
| 387 DebugActivationPath(*outputs, labels, coords); | |
| 388 } | |
| 389 return true; | |
| 390 } | |
| 391 | |
| 392 // Converts an array of labels to utf-8, whether or not the labels are | |
| 393 // augmented with character boundaries. | |
| 394 std::string LSTMRecognizer::DecodeLabels(const std::vector<int> &labels) { | |
| 395 std::string result; | |
| 396 unsigned end = 1; | |
| 397 for (unsigned start = 0; start < labels.size(); start = end) { | |
| 398 if (labels[start] == null_char_) { | |
| 399 end = start + 1; | |
| 400 } else { | |
| 401 result += DecodeLabel(labels, start, &end, nullptr); | |
| 402 } | |
| 403 } | |
| 404 return result; | |
| 405 } | |
| 406 | |
| 407 #ifndef GRAPHICS_DISABLED | |
| 408 | |
| 409 // Displays the forward results in a window with the characters and | |
| 410 // boundaries as determined by the labels and label_coords. | |
| 411 void LSTMRecognizer::DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels, | |
| 412 const std::vector<int> &label_coords, const char *window_name, | |
| 413 ScrollView **window) { | |
| 414 Image input_pix = inputs.ToPix(); | |
| 415 Network::ClearWindow(false, window_name, pixGetWidth(input_pix), pixGetHeight(input_pix), window); | |
| 416 int line_height = Network::DisplayImage(input_pix, *window); | |
| 417 DisplayLSTMOutput(labels, label_coords, line_height, *window); | |
| 418 } | |
| 419 | |
| 420 // Displays the labels and cuts at the corresponding xcoords. | |
| 421 // Size of labels should match xcoords. | |
| 422 void LSTMRecognizer::DisplayLSTMOutput(const std::vector<int> &labels, | |
| 423 const std::vector<int> &xcoords, int height, | |
| 424 ScrollView *window) { | |
| 425 int x_scale = network_->XScaleFactor(); | |
| 426 window->TextAttributes("Arial", height / 4, false, false, false); | |
| 427 unsigned end = 1; | |
| 428 for (unsigned start = 0; start < labels.size(); start = end) { | |
| 429 int xpos = xcoords[start] * x_scale; | |
| 430 if (labels[start] == null_char_) { | |
| 431 end = start + 1; | |
| 432 window->Pen(ScrollView::RED); | |
| 433 } else { | |
| 434 window->Pen(ScrollView::GREEN); | |
| 435 const char *str = DecodeLabel(labels, start, &end, nullptr); | |
| 436 if (*str == '\\') { | |
| 437 str = "\\\\"; | |
| 438 } | |
| 439 xpos = xcoords[(start + end) / 2] * x_scale; | |
| 440 window->Text(xpos, height, str); | |
| 441 } | |
| 442 window->Line(xpos, 0, xpos, height * 3 / 2); | |
| 443 } | |
| 444 window->Update(); | |
| 445 } | |
| 446 | |
| 447 #endif // !GRAPHICS_DISABLED | |
| 448 | |
| 449 // Prints debug output detailing the activation path that is implied by the | |
| 450 // label_coords. | |
| 451 void LSTMRecognizer::DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels, | |
| 452 const std::vector<int> &xcoords) { | |
| 453 if (xcoords[0] > 0) { | |
| 454 DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]); | |
| 455 } | |
| 456 unsigned end = 1; | |
| 457 for (unsigned start = 0; start < labels.size(); start = end) { | |
| 458 if (labels[start] == null_char_) { | |
| 459 end = start + 1; | |
| 460 DebugActivationRange(outputs, "<null>", null_char_, xcoords[start], xcoords[end]); | |
| 461 continue; | |
| 462 } else { | |
| 463 int decoded; | |
| 464 const char *label = DecodeLabel(labels, start, &end, &decoded); | |
| 465 DebugActivationRange(outputs, label, labels[start], xcoords[start], xcoords[start + 1]); | |
| 466 for (unsigned i = start + 1; i < end; ++i) { | |
| 467 DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i], xcoords[i], | |
| 468 xcoords[i + 1]); | |
| 469 } | |
| 470 } | |
| 471 } | |
| 472 } | |
| 473 | |
| 474 // Prints debug output detailing activations and 2nd choice over a range | |
| 475 // of positions. | |
| 476 void LSTMRecognizer::DebugActivationRange(const NetworkIO &outputs, const char *label, | |
| 477 int best_choice, int x_start, int x_end) { | |
| 478 tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end); | |
| 479 double max_score = 0.0; | |
| 480 double mean_score = 0.0; | |
| 481 const int width = x_end - x_start; | |
| 482 for (int x = x_start; x < x_end; ++x) { | |
| 483 const float *line = outputs.f(x); | |
| 484 const double score = line[best_choice] * 100.0; | |
| 485 if (score > max_score) { | |
| 486 max_score = score; | |
| 487 } | |
| 488 mean_score += score / width; | |
| 489 int best_c = 0; | |
| 490 double best_score = 0.0; | |
| 491 for (int c = 0; c < outputs.NumFeatures(); ++c) { | |
| 492 if (c != best_choice && line[c] > best_score) { | |
| 493 best_c = c; | |
| 494 best_score = line[c]; | |
| 495 } | |
| 496 } | |
| 497 tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c, best_score * 100.0); | |
| 498 } | |
| 499 tprintf(", Mean=%g, max=%g\n", mean_score, max_score); | |
| 500 } | |
| 501 | |
// Helper returns true if the null_char activation at t reaches null_thr, or
// if the best label other than null_char is space and null_char scores higher
// than space, in which case the null is used anyway.
#if 0 // TODO: unused, remove if still unused after 2020.
static bool NullIsBest(const NetworkIO& output, float null_thr,
                       int null_char, int t) {
  if (output.f(t)[null_char] >= null_thr) {
    return true;
  }
  const bool next_is_space =
      output.BestLabel(t, null_char, null_char, nullptr) == UNICHAR_SPACE;
  return next_is_space && output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
}
#endif
| 514 | |
| 515 // Converts the network output to a sequence of labels. Outputs labels, scores | |
| 516 // and start xcoords of each char, and each null_char_, with an additional | |
| 517 // final xcoord for the end of the output. | |
| 518 // The conversion method is determined by internal state. | |
| 519 void LSTMRecognizer::LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels, | |
| 520 std::vector<int> *xcoords) { | |
| 521 if (SimpleTextOutput()) { | |
| 522 LabelsViaSimpleText(outputs, labels, xcoords); | |
| 523 } else { | |
| 524 LabelsViaReEncode(outputs, labels, xcoords); | |
| 525 } | |
| 526 } | |
| 527 | |
| 528 // As LabelsViaCTC except that this function constructs the best path that | |
| 529 // contains only legal sequences of subcodes for CJK. | |
| 530 void LSTMRecognizer::LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels, | |
| 531 std::vector<int> *xcoords) { | |
| 532 if (search_ == nullptr) { | |
| 533 search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_); | |
| 534 } | |
| 535 search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, nullptr); | |
| 536 search_->ExtractBestPathAsLabels(labels, xcoords); | |
| 537 } | |
| 538 | |
| 539 // Converts the network output to a sequence of labels, with scores, using | |
| 540 // the simple character model (each position is a char, and the null_char_ is | |
| 541 // mainly intended for tail padding.) | |
| 542 void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels, | |
| 543 std::vector<int> *xcoords) { | |
| 544 labels->clear(); | |
| 545 xcoords->clear(); | |
| 546 const int width = output.Width(); | |
| 547 for (int t = 0; t < width; ++t) { | |
| 548 float score = 0.0f; | |
| 549 const int label = output.BestLabel(t, &score); | |
| 550 if (label != null_char_) { | |
| 551 labels->push_back(label); | |
| 552 xcoords->push_back(t); | |
| 553 } | |
| 554 } | |
| 555 xcoords->push_back(width); | |
| 556 } | |
| 557 | |
| 558 // Returns a string corresponding to the label starting at start. Sets *end | |
| 559 // to the next start and if non-null, *decoded to the unichar id. | |
| 560 const char *LSTMRecognizer::DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end, | |
| 561 int *decoded) { | |
| 562 *end = start + 1; | |
| 563 if (IsRecoding()) { | |
| 564 // Decode labels via recoder_. | |
| 565 RecodedCharID code; | |
| 566 if (labels[start] == null_char_) { | |
| 567 if (decoded != nullptr) { | |
| 568 code.Set(0, null_char_); | |
| 569 *decoded = recoder_.DecodeUnichar(code); | |
| 570 } | |
| 571 return "<null>"; | |
| 572 } | |
| 573 unsigned index = start; | |
| 574 while (index < labels.size() && code.length() < RecodedCharID::kMaxCodeLen) { | |
| 575 code.Set(code.length(), labels[index++]); | |
| 576 while (index < labels.size() && labels[index] == null_char_) { | |
| 577 ++index; | |
| 578 } | |
| 579 int uni_id = recoder_.DecodeUnichar(code); | |
| 580 // If the next label isn't a valid first code, then we need to continue | |
| 581 // extending even if we have a valid uni_id from this prefix. | |
| 582 if (uni_id != INVALID_UNICHAR_ID && | |
| 583 (index == labels.size() || code.length() == RecodedCharID::kMaxCodeLen || | |
| 584 recoder_.IsValidFirstCode(labels[index]))) { | |
| 585 *end = index; | |
| 586 if (decoded != nullptr) { | |
| 587 *decoded = uni_id; | |
| 588 } | |
| 589 if (uni_id == UNICHAR_SPACE) { | |
| 590 return " "; | |
| 591 } | |
| 592 return GetUnicharset().get_normed_unichar(uni_id); | |
| 593 } | |
| 594 } | |
| 595 return "<Undecodable>"; | |
| 596 } else { | |
| 597 if (decoded != nullptr) { | |
| 598 *decoded = labels[start]; | |
| 599 } | |
| 600 if (labels[start] == null_char_) { | |
| 601 return "<null>"; | |
| 602 } | |
| 603 if (labels[start] == UNICHAR_SPACE) { | |
| 604 return " "; | |
| 605 } | |
| 606 return GetUnicharset().get_normed_unichar(labels[start]); | |
| 607 } | |
| 608 } | |
| 609 | |
| 610 // Returns a string corresponding to a given single label id, falling back to | |
| 611 // a default of ".." for part of a multi-label unichar-id. | |
| 612 const char *LSTMRecognizer::DecodeSingleLabel(int label) { | |
| 613 if (label == null_char_) { | |
| 614 return "<null>"; | |
| 615 } | |
| 616 if (IsRecoding()) { | |
| 617 // Decode label via recoder_. | |
| 618 RecodedCharID code; | |
| 619 code.Set(0, label); | |
| 620 label = recoder_.DecodeUnichar(code); | |
| 621 if (label == INVALID_UNICHAR_ID) { | |
| 622 return ".."; // Part of a bigger code. | |
| 623 } | |
| 624 } | |
| 625 if (label == UNICHAR_SPACE) { | |
| 626 return " "; | |
| 627 } | |
| 628 return GetUnicharset().get_normed_unichar(label); | |
| 629 } | |
| 630 | |
| 631 } // namespace tesseract. |
