Mercurial > hgrepos > Python2 > PyMuPDF
comparison mupdf-source/thirdparty/tesseract/src/training/unicharset/lstmtrainer.cpp @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: no version number in the expanded directory now.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 1:1d09e1dec1d9 | 2:b50eed0cc0ef |
|---|---|
| 1 /////////////////////////////////////////////////////////////////////// | |
| 2 // File: lstmtrainer.cpp | |
| 3 // Description: Top-level line trainer class for LSTM-based networks. | |
| 4 // Author: Ray Smith | |
| 5 // | |
| 6 // (C) Copyright 2013, Google Inc. | |
| 7 // Licensed under the Apache License, Version 2.0 (the "License"); | |
| 8 // you may not use this file except in compliance with the License. | |
| 9 // You may obtain a copy of the License at | |
| 10 // http://www.apache.org/licenses/LICENSE-2.0 | |
| 11 // Unless required by applicable law or agreed to in writing, software | |
| 12 // distributed under the License is distributed on an "AS IS" BASIS, | |
| 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 14 // See the License for the specific language governing permissions and | |
| 15 // limitations under the License. | |
| 16 /////////////////////////////////////////////////////////////////////// | |
| 17 | |
| 18 #define _USE_MATH_DEFINES // needed to get definition of M_SQRT1_2 | |
| 19 | |
| 20 // Include automatically generated configuration file if running autoconf. | |
| 21 #ifdef HAVE_CONFIG_H | |
| 22 # include "config_auto.h" | |
| 23 #endif | |
| 24 | |
#include <cmath>
#include <cstring> // for strcmp
#include <iomanip> // for std::setprecision
#include <locale>  // for std::locale::classic
#include <string>

#include "lstmtrainer.h"

#include <allheaders.h>
#include "boxread.h"
#include "ctc.h"
#include "imagedata.h"
#include "input.h"
#include "networkbuilder.h"
#include "ratngs.h"
#include "recodebeam.h"
#include "tprintf.h"
| 40 | |
namespace tesseract {

// Tuning constants for the training loop. Error-rate thresholds below are
// percentages in [0, 100] unless stated otherwise.

// Min actual error rate increase to constitute divergence.
const double kMinDivergenceRate = 50.0;
// Min iterations since last best before acting on a stall.
const int kMinStallIterations = 10000;
// Fraction of current char error rate that sub_trainer_ has to be ahead
// before we declare the sub_trainer_ a success and switch to it.
const double kSubTrainerMarginFraction = 3.0 / 128;
// Factor to reduce learning rate on divergence. (1/sqrt(2); M_SQRT1_2
// requires the _USE_MATH_DEFINES define at the top of this file on MSVC.)
const double kLearningRateDecay = M_SQRT1_2;
// LR adjustment iterations: samples double-trained per learning-rate trial.
const int kNumAdjustmentIterations = 100;
// How often to add data to the error_graph_.
const int kErrorGraphInterval = 1000;
// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;
// Min percent error rate to consider start-up phase over.
const int kMinStartedErrorRate = 75;
// Error rate at which to transition to stage 1.
const double kStageTransitionThreshold = 10.0;
// Confidence beyond which the truth is more likely wrong than the recognizer.
const double kHighConfidence = 0.9375; // 15/16.
// Fraction of weight sign-changing total to constitute a definite improvement.
const double kImprovementFraction = 15.0 / 16.0;
// Fraction of last written best to make it worth writing another.
const double kBestCheckpointFraction = 31.0 / 32.0;
#ifndef GRAPHICS_DISABLED
// Scale factor for display of target activations of CTC.
const int kTargetXScale = 5;
const int kTargetYScale = 100;
#endif // !GRAPHICS_DISABLED
| 74 LSTMTrainer::LSTMTrainer() | |
| 75 : randomly_rotate_(false), training_data_(0), sub_trainer_(nullptr) { | |
| 76 EmptyConstructor(); | |
| 77 debug_interval_ = 0; | |
| 78 } | |
| 79 | |
| 80 LSTMTrainer::LSTMTrainer(const std::string &model_base, const std::string &checkpoint_name, | |
| 81 int debug_interval, int64_t max_memory) | |
| 82 : randomly_rotate_(false), | |
| 83 training_data_(max_memory), | |
| 84 sub_trainer_(nullptr) { | |
| 85 EmptyConstructor(); | |
| 86 debug_interval_ = debug_interval; | |
| 87 model_base_ = model_base; | |
| 88 checkpoint_name_ = checkpoint_name; | |
| 89 } | |
| 90 | |
| 91 LSTMTrainer::~LSTMTrainer() { | |
| 92 #ifndef GRAPHICS_DISABLED | |
| 93 delete align_win_; | |
| 94 delete target_win_; | |
| 95 delete ctc_win_; | |
| 96 delete recon_win_; | |
| 97 #endif | |
| 98 } | |
| 99 | |
| 100 // Tries to deserialize a trainer from the given file and silently returns | |
| 101 // false in case of failure. | |
| 102 bool LSTMTrainer::TryLoadingCheckpoint(const char *filename, | |
| 103 const char *old_traineddata) { | |
| 104 std::vector<char> data; | |
| 105 if (!LoadDataFromFile(filename, &data)) { | |
| 106 return false; | |
| 107 } | |
| 108 tprintf("Loaded file %s, unpacking...\n", filename); | |
| 109 if (!ReadTrainingDump(data, *this)) { | |
| 110 return false; | |
| 111 } | |
| 112 if (IsIntMode()) { | |
| 113 tprintf("Error, %s is an integer (fast) model, cannot continue training\n", | |
| 114 filename); | |
| 115 return false; | |
| 116 } | |
| 117 if (((old_traineddata == nullptr || *old_traineddata == '\0') && | |
| 118 network_->NumOutputs() == recoder_.code_range()) || | |
| 119 filename == old_traineddata) { | |
| 120 return true; // Normal checkpoint load complete. | |
| 121 } | |
| 122 tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(), | |
| 123 recoder_.code_range()); | |
| 124 if (old_traineddata == nullptr || *old_traineddata == '\0') { | |
| 125 tprintf("Must supply the old traineddata for code conversion!\n"); | |
| 126 return false; | |
| 127 } | |
| 128 TessdataManager old_mgr; | |
| 129 ASSERT_HOST(old_mgr.Init(old_traineddata)); | |
| 130 TFile fp; | |
| 131 if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) { | |
| 132 return false; | |
| 133 } | |
| 134 UNICHARSET old_chset; | |
| 135 if (!old_chset.load_from_file(&fp, false)) { | |
| 136 return false; | |
| 137 } | |
| 138 if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) { | |
| 139 return false; | |
| 140 } | |
| 141 UnicharCompress old_recoder; | |
| 142 if (!old_recoder.DeSerialize(&fp)) { | |
| 143 return false; | |
| 144 } | |
| 145 std::vector<int> code_map = MapRecoder(old_chset, old_recoder); | |
| 146 // Set the null_char_ to the new value. | |
| 147 int old_null_char = null_char_; | |
| 148 SetNullChar(); | |
| 149 // Map the softmax(s) in the network. | |
| 150 network_->RemapOutputs(old_recoder.code_range(), code_map); | |
| 151 tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_); | |
| 152 return true; | |
| 153 } | |
| 154 | |
// Initializes the trainer with a network_spec in the network description
// net_flags control network behavior according to the NetworkFlags enum.
// There isn't really much difference between them - only where the effects
// are implemented.
// For other args see NetworkBuilder::InitNetwork.
// Note: Be sure to call InitCharSet before InitNetwork!
bool LSTMTrainer::InitNetwork(const char *network_spec, int append_index,
                              int net_flags, float weight_range,
                              float learning_rate, float momentum,
                              float adam_beta) {
  // Append the spec to the version string so the provenance of a trained
  // model is recoverable from the traineddata file.
  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec);
  adam_beta_ = adam_beta;
  learning_rate_ = learning_rate;
  momentum_ = momentum;
  // Must precede network construction: the output size depends on recoder_.
  SetNullChar();
  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
                                   append_index, net_flags, weight_range,
                                   &randomizer_, &network_)) {
    return false;
  }
  network_str_ += network_spec;
  tprintf("Built network:%s from request %s\n", network_->spec().c_str(),
          network_spec);
  tprintf(
      "Training parameters:\n Debug interval = %d,"
      " weights = %g, learning rate = %g, momentum=%g\n",
      debug_interval_, weight_range, learning_rate_, momentum_);
  tprintf("null char=%d\n", null_char_);
  return true;
}
| 185 | |
// Resets all the iteration counters for fine tuning or training a head,
// where we want the error reporting to reset.
void LSTMTrainer::InitIterations() {
  sample_iteration_ = 0;
  training_iteration_ = 0;
  learning_iteration_ = 0;
  prev_sample_iteration_ = 0;
  // Error rates are percentages: seed "best" at the worst possible value
  // (100) and "worst" at the best (0) so the first measurement replaces both.
  best_error_rate_ = 100.0;
  best_iteration_ = 0;
  worst_error_rate_ = 0.0;
  worst_iteration_ = 0;
  stall_iteration_ = kMinStallIterations;
  best_error_history_.clear();
  best_error_iterations_.clear();
  improvement_steps_ = kMinStallIterations;
  perfect_delay_ = 0;
  last_perfect_training_iteration_ = 0;
  // Reset the per-error-type rolling buffers and rates.
  for (int i = 0; i < ET_COUNT; ++i) {
    best_error_rates_[i] = 100.0;
    worst_error_rates_[i] = 0.0;
    error_buffers_[i].clear();
    error_buffers_[i].resize(kRollingBufferSize_);
    error_rates_[i] = 100.0;
  }
  error_rate_of_last_saved_best_ = kMinStartedErrorRate;
}
| 212 | |
// If the training sample is usable, grid searches for the optimal
// dict_ratio/cert_offset, and returns the results in a string of space-
// separated triplets of ratio,offset=worderr.
Trainability LSTMTrainer::GridSearchDictParams(
    const ImageData *trainingdata, int iteration, double min_dict_ratio,
    double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
    double cert_offset_step, double max_cert_offset, std::string &results) {
  sample_iteration_ = iteration;
  NetworkIO fwd_outputs, targets;
  Trainability result =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  // A grid search is pointless without a dictionary or a usable sample.
  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr) {
    return result;
  }

  // Encode/decode the truth to get the normalization.
  std::vector<int> truth_labels, ocr_labels, xcoords;
  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
  // NO-dict error.
  // Baseline: decode without a dictionary (ratio=1, offset=0) so the grid
  // results can be compared against the undictionaried word error.
  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(),
                               nullptr);
  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
                     nullptr);
  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
  std::string truth_text = DecodeLabels(truth_labels);
  std::string ocr_text = DecodeLabels(ocr_labels);
  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
  results += "0,0=" + std::to_string(baseline_error);

  // Grid over (ratio, offset), decoding with the dictionary each time.
  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
    for (double c = min_cert_offset; c < max_cert_offset;
         c += cert_offset_step) {
      search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty,
                    nullptr);
      search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
      // Re-decode each time: ComputeWordError below consumes the strings.
      truth_text = DecodeLabels(truth_labels);
      ocr_text = DecodeLabels(ocr_labels);
      // This is destructive on both strings.
      double word_error = ComputeWordError(&truth_text, &ocr_text);
      // Log the first cell and any non-finite error for debugging.
      if ((r == min_dict_ratio && c == min_cert_offset) ||
          !std::isfinite(word_error)) {
        std::string t = DecodeLabels(truth_labels);
        std::string o = DecodeLabels(ocr_labels);
        tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
                t.c_str(), o.c_str(), word_error, truth_labels[0]);
      }
      results += " " + std::to_string(r);
      results += "," + std::to_string(c);
      results += "=" + std::to_string(word_error);
    }
  }
  return result;
}
| 267 | |
// Provides output on the distribution of weight values.
// Thin delegate to the network's own weight-dumping routine.
void LSTMTrainer::DebugNetwork() {
  network_->DebugWeights();
}
| 272 | |
| 273 // Loads a set of lstmf files that were created using the lstm.train config to | |
| 274 // tesseract into memory ready for training. Returns false if nothing was | |
| 275 // loaded. | |
| 276 bool LSTMTrainer::LoadAllTrainingData(const std::vector<std::string> &filenames, | |
| 277 CachingStrategy cache_strategy, | |
| 278 bool randomly_rotate) { | |
| 279 randomly_rotate_ = randomly_rotate; | |
| 280 training_data_.Clear(); | |
| 281 return training_data_.LoadDocuments(filenames, cache_strategy, | |
| 282 LoadDataFromFile); | |
| 283 } | |
| 284 | |
// Keeps track of best and locally worst char error_rate and launches tests
// using tester, when a new min or max is reached.
// Writes checkpoints at appropriate times and builds and returns a log message
// to indicate progress. Returns false if nothing interesting happened.
bool LSTMTrainer::MaintainCheckpoints(const TestCallback &tester,
                                      std::stringstream &log_msg) {
  PrepareLogMsg(log_msg);
  double error_rate = CharError();
  int iteration = learning_iteration();
  // Stalled and noticeably worse than the best: spin up a sub-trainer from
  // the saved best with a reduced learning rate.
  if (iteration >= stall_iteration_ &&
      error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
      best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
    // It hasn't got any better in a long while, and is a margin worse than the
    // best, so go back to the best model and try a different learning rate.
    StartSubtrainer(log_msg);
  }
  SubTrainerResult sub_trainer_result = STR_NONE;
  if (sub_trainer_ != nullptr) {
    sub_trainer_result = UpdateSubtrainer(log_msg);
    if (sub_trainer_result == STR_REPLACED) {
      // Reset the inputs, as we have overwritten *this.
      error_rate = CharError();
      iteration = learning_iteration();
      PrepareLogMsg(log_msg);
    }
  }
  bool result = true; // Something interesting happened.
  std::vector<char> rec_model_data;
  if (error_rate < best_error_rate_) {
    // New best: record it, run the tester, and maybe write a best model.
    SaveRecognitionDump(&rec_model_data);
    log_msg << " New best BCER = " << error_rate;
    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
    // just overwrote *this. In either case, we have finished with it.
    sub_trainer_.reset();
    stall_iteration_ = learning_iteration() + kMinStallIterations;
    if (TransitionTrainingStage(kStageTransitionThreshold)) {
      log_msg << " Transitioned to stage " << CurrentTrainingStage();
    }
    SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    // Only write a best-model file if the improvement over the last written
    // best is big enough to be worth the disk traffic.
    if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
      std::string best_model_name = DumpFilename();
      if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
        log_msg << " failed to write best model:";
      } else {
        log_msg << " wrote best model:";
        error_rate_of_last_saved_best_ = best_error_rate_;
      }
      log_msg << best_model_name;
    }
  } else if (error_rate > worst_error_rate_) {
    // New local worst: record it and check for outright divergence.
    SaveRecognitionDump(&rec_model_data);
    log_msg << " New worst BCER = " << error_rate;
    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate &&
        best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
      // Error rate has ballooned. Go back to the best model.
      log_msg << "\nDivergence! ";
      // Copy best_trainer_ before reading it, as it will get overwritten.
      std::vector<char> revert_data(best_trainer_);
      if (ReadTrainingDump(revert_data, *this)) {
        LogIterations("Reverted to", log_msg);
        ReduceLearningRates(this, log_msg);
      } else {
        LogIterations("Failed to Revert at", log_msg);
      }
      // If it fails again, we will wait twice as long before reverting again.
      stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
      // Re-save the best trainer with the new learning rates and stall
      // iteration.
      SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    }
  } else {
    // Something interesting happened only if the sub_trainer_ was trained.
    result = sub_trainer_result != STR_NONE;
  }
  // Regardless of the outcome above, keep the rolling checkpoint current.
  if (checkpoint_name_.length() > 0) {
    // Write a current checkpoint.
    std::vector<char> checkpoint;
    if (!SaveTrainingDump(FULL, *this, &checkpoint) ||
        !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
      log_msg << " failed to write checkpoint.";
    } else {
      log_msg << " wrote checkpoint.";
    }
  }
  return result;
}
| 373 | |
| 374 // Builds a string containing a progress message with current error rates. | |
| 375 void LSTMTrainer::PrepareLogMsg(std::stringstream &log_msg) const { | |
| 376 LogIterations("At", log_msg); | |
| 377 log_msg << std::fixed << std::setprecision(3) | |
| 378 << ", mean rms=" << error_rates_[ET_RMS] | |
| 379 << "%, delta=" << error_rates_[ET_DELTA] | |
| 380 << "%, BCER train=" << error_rates_[ET_CHAR_ERROR] | |
| 381 << "%, BWER train=" << error_rates_[ET_WORD_RECERR] | |
| 382 << "%, skip ratio=" << error_rates_[ET_SKIP_RATIO] << "%,"; | |
| 383 } | |
| 384 | |
| 385 // Appends <intro_str> iteration learning_iteration()/training_iteration()/ | |
| 386 // sample_iteration() to the log_msg. | |
| 387 void LSTMTrainer::LogIterations(const char *intro_str, | |
| 388 std::stringstream &log_msg) const { | |
| 389 log_msg << intro_str | |
| 390 << " iteration " << learning_iteration() | |
| 391 << "/" << training_iteration() | |
| 392 << "/" << sample_iteration(); | |
| 393 } | |
| 394 | |
| 395 // Returns true and increments the training_stage_ if the error rate has just | |
| 396 // passed through the given threshold for the first time. | |
| 397 bool LSTMTrainer::TransitionTrainingStage(float error_threshold) { | |
| 398 if (best_error_rate_ < error_threshold && | |
| 399 training_stage_ + 1 < num_training_stages_) { | |
| 400 ++training_stage_; | |
| 401 return true; | |
| 402 } | |
| 403 return false; | |
| 404 } | |
| 405 | |
// Writes to the given file. Returns false in case of error.
// NOTE: the field order below IS the on-disk format; it must match
// DeSerialize exactly. serialize_amount controls how much trailer is
// written: LIGHT stops after the amount byte, NO_BEST_TRAINER omits the
// (large) best_trainer_ blob, FULL writes everything.
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
                            const TessdataManager *mgr, TFile *fp) const {
  // The recognizer state comes first so a plain recognizer can still read it.
  if (!LSTMRecognizer::Serialize(mgr, fp)) {
    return false;
  }
  if (!fp->Serialize(&learning_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->Serialize(&last_perfect_training_iteration_)) {
    return false;
  }
  // One rolling error buffer per error type (ET_COUNT of them).
  for (const auto &error_buffer : error_buffers_) {
    if (!fp->Serialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&training_stage_)) {
    return false;
  }
  // The amount byte tells DeSerialize how much trailer follows.
  uint8_t amount = serialize_amount;
  if (!fp->Serialize(&amount)) {
    return false;
  }
  if (serialize_amount == LIGHT) {
    return true; // We are done.
  }
  // Full trailer: best/worst tracking state.
  if (!fp->Serialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&best_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->Serialize(best_model_data_)) {
    return false;
  }
  if (!fp->Serialize(worst_model_data_)) {
    return false;
  }
  // The saved best trainer is large, so NO_BEST_TRAINER skips it.
  if (serialize_amount != NO_BEST_TRAINER && !fp->Serialize(best_trainer_)) {
    return false;
  }
  // The sub-trainer (if any) is nested as a LIGHT dump; an empty vector
  // encodes "no sub-trainer".
  std::vector<char> sub_data;
  if (sub_trainer_ != nullptr &&
      !SaveTrainingDump(LIGHT, *sub_trainer_, &sub_data)) {
    return false;
  }
  if (!fp->Serialize(sub_data)) {
    return false;
  }
  if (!fp->Serialize(best_error_history_)) {
    return false;
  }
  if (!fp->Serialize(best_error_iterations_)) {
    return false;
  }
  return fp->Serialize(&improvement_steps_);
}
| 488 | |
// Reads from the given file. Returns false in case of error.
// NOTE: It is assumed that the trainer is never read cross-endian.
// Field order must mirror Serialize exactly.
bool LSTMTrainer::DeSerialize(const TessdataManager *mgr, TFile *fp) {
  if (!LSTMRecognizer::DeSerialize(mgr, fp)) {
    return false;
  }
  if (!fp->DeSerialize(&learning_iteration_)) {
    // Special case. If we successfully decoded the recognizer, but fail here
    // then it means we were just given a recognizer, so issue a warning and
    // allow it.
    tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
    learning_iteration_ = 0;
    network_->SetEnableTraining(TS_ENABLED);
    return true;
  }
  if (!fp->DeSerialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->DeSerialize(&last_perfect_training_iteration_)) {
    return false;
  }
  // One rolling error buffer per error type, matching Serialize.
  for (auto &error_buffer : error_buffers_) {
    if (!fp->DeSerialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&training_stage_)) {
    return false;
  }
  // The amount byte written by Serialize decides how much trailer follows.
  uint8_t amount;
  if (!fp->DeSerialize(&amount)) {
    return false;
  }
  if (amount == LIGHT) {
    return true; // Don't read the rest.
  }
  if (!fp->DeSerialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&best_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(best_model_data_)) {
    return false;
  }
  if (!fp->DeSerialize(worst_model_data_)) {
    return false;
  }
  if (amount != NO_BEST_TRAINER && !fp->DeSerialize(best_trainer_)) {
    return false;
  }
  // Nested sub-trainer dump; an empty vector means there was none.
  std::vector<char> sub_data;
  if (!fp->DeSerialize(sub_data)) {
    return false;
  }
  if (sub_data.empty()) {
    sub_trainer_ = nullptr;
  } else {
    sub_trainer_ = std::make_unique<LSTMTrainer>();
    if (!ReadTrainingDump(sub_data, *sub_trainer_)) {
      return false;
    }
  }
  if (!fp->DeSerialize(best_error_history_)) {
    return false;
  }
  if (!fp->DeSerialize(best_error_iterations_)) {
    return false;
  }
  return fp->DeSerialize(&improvement_steps_);
}
| 581 | |
// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
// learning rates (by scaling reduction, or layer specific, according to
// NF_LAYER_SPECIFIC_LR).
void LSTMTrainer::StartSubtrainer(std::stringstream &log_msg) {
  sub_trainer_ = std::make_unique<LSTMTrainer>();
  if (!ReadTrainingDump(best_trainer_, *sub_trainer_)) {
    log_msg << " Failed to revert to previous best for trial!";
    sub_trainer_.reset();
  } else {
    log_msg << " Trial sub_trainer_ from iteration "
            << sub_trainer_->training_iteration();
    // Reduce learning rate so it doesn't diverge this time.
    sub_trainer_->ReduceLearningRates(this, log_msg);
    // If it fails again, we will wait twice as long before reverting again.
    int stall_offset =
        learning_iteration() - sub_trainer_->learning_iteration();
    stall_iteration_ = learning_iteration() + 2 * stall_offset;
    sub_trainer_->stall_iteration_ = stall_iteration_;
    // Re-save the best trainer with the new learning rates and stall iteration.
    SaveTrainingDump(NO_BEST_TRAINER, *sub_trainer_, &best_trainer_);
  }
}
| 604 | |
// While the sub_trainer_ is behind the current training iteration and its
// training error is at least kSubTrainerMarginFraction better than the
// current training error, trains the sub_trainer_, and returns STR_UPDATED if
// it did anything. If it catches up, and has a better error rate than the
// current best, as well as a margin over the current error rate, then the
// trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
// returned. STR_NONE is returned if the subtrainer wasn't good enough to
// receive any training iterations.
SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::stringstream &log_msg) {
  double training_error = CharError();
  double sub_error = sub_trainer_->CharError();
  // Relative lead of the sub-trainer over the main trainer.
  double sub_margin = (training_error - sub_error) / sub_error;
  if (sub_margin >= kSubTrainerMarginFraction) {
    log_msg << " sub_trainer=" << sub_error
            << " margin=" << 100.0 * sub_margin << "\n";
    // Catch up to current iteration.
    int end_iteration = training_iteration();
    // Train in batches of kNumPagesPerBatch, re-checking the margin after
    // each batch so a fading sub-trainer is abandoned early.
    while (sub_trainer_->training_iteration() < end_iteration &&
           sub_margin >= kSubTrainerMarginFraction) {
      int target_iteration =
          sub_trainer_->training_iteration() + kNumPagesPerBatch;
      while (sub_trainer_->training_iteration() < target_iteration) {
        sub_trainer_->TrainOnLine(this, false);
      }
      // NOTE(review): the "Sub:" seed is written at position 0 of the
      // stringstream, so PrepareLogMsg's output overwrites it unless the
      // stream were opened with std::ios::ate -- confirm intended.
      std::stringstream batch_log("Sub:");
      batch_log.imbue(std::locale::classic());
      sub_trainer_->PrepareLogMsg(batch_log);
      batch_log << "\n";
      tprintf("UpdateSubtrainer:%s", batch_log.str().c_str());
      log_msg << batch_log.str();
      sub_error = sub_trainer_->CharError();
      sub_margin = (training_error - sub_error) / sub_error;
    }
    if (sub_error < best_error_rate_ &&
        sub_margin >= kSubTrainerMarginFraction) {
      // The sub_trainer_ has won the race to a new best. Switch to it.
      std::vector<char> updated_trainer;
      SaveTrainingDump(LIGHT, *sub_trainer_, &updated_trainer);
      ReadTrainingDump(updated_trainer, *this);
      log_msg << " Sub trainer wins at iteration "
              << training_iteration() << "\n";
      return STR_REPLACED;
    }
    return STR_UPDATED;
  }
  return STR_NONE;
}
| 652 | |
| 653 // Reduces network learning rates, either for everything, or for layers | |
| 654 // independently, according to NF_LAYER_SPECIFIC_LR. | |
| 655 void LSTMTrainer::ReduceLearningRates(LSTMTrainer *samples_trainer, | |
| 656 std::stringstream &log_msg) { | |
| 657 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { | |
| 658 int num_reduced = ReduceLayerLearningRates( | |
| 659 kLearningRateDecay, kNumAdjustmentIterations, samples_trainer); | |
| 660 log_msg << "\nReduced learning rate on layers: " << num_reduced; | |
| 661 } else { | |
| 662 ScaleLearningRate(kLearningRateDecay); | |
| 663 log_msg << "\nReduced learning rate to :" << learning_rate_; | |
| 664 } | |
| 665 log_msg << "\n"; | |
| 666 } | |
| 667 | |
| 668 // Considers reducing the learning rate independently for each layer down by | |
| 669 // factor(<1), or leaving it the same, by double-training the given number of | |
| 670 // samples and minimizing the amount of changing of sign of weight updates. | |
// Even if it looks like all weights should remain the same, an adjustment
// will be made to guarantee a different result when reverting to an old best.
// Returns the number of layer learning rates that were reduced.
//
// Method: for each of num_samples training samples, each trainable layer is
// trained twice from the same starting state — once with the learning rate
// scaled down by `factor`, once unchanged — and the sign-alternation counts
// of the resulting weight updates are accumulated per layer. A layer whose
// "down" run alternates sufficiently less than its "same" run gets its
// learning rate reduced.
int LSTMTrainer::ReduceLayerLearningRates(TFloat factor, int num_samples,
                                          LSTMTrainer *samples_trainer) {
  enum WhichWay {
    LR_DOWN,  // Learning rate will go down by factor.
    LR_SAME,  // Learning rate will stay the same.
    LR_COUNT  // Size of arrays.
  };
  std::vector<std::string> layers = EnumerateLayers();
  int num_layers = layers.size();
  std::vector<int> num_weights(num_layers);
  // Per-direction, per-layer accumulators of CountAlternators results.
  std::vector<TFloat> bad_sums[LR_COUNT];
  std::vector<TFloat> ok_sums[LR_COUNT];
  for (int i = 0; i < LR_COUNT; ++i) {
    bad_sums[i].resize(num_layers, 0.0);
    ok_sums[i].resize(num_layers, 0.0);
  }
  // Converts a learning rate into the effective rate once momentum has been
  // folded in (geometric series 1/(1-momentum)).
  auto momentum_factor = 1 / (1 - momentum_);
  // Snapshot of the current trainer state, restored for every trial run.
  std::vector<char> orig_trainer;
  samples_trainer->SaveTrainingDump(LIGHT, *this, &orig_trainer);
  for (int i = 0; i < num_layers; ++i) {
    Network *layer = GetLayer(layers[i]);
    // Layers that are not training contribute no weights and are skipped.
    num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
  }
  int iteration = sample_iteration();
  for (int s = 0; s < num_samples; ++s) {
    // Which way will we modify the learning rate?
    for (int ww = 0; ww < LR_COUNT; ++ww) {
      // Transfer momentum to learning rate and adjust by the ww factor.
      auto ww_factor = momentum_factor;
      if (ww == LR_DOWN) {
        ww_factor *= factor;
      }
      // Make a copy of *this, so we can mess about without damaging anything.
      LSTMTrainer copy_trainer;
      samples_trainer->ReadTrainingDump(orig_trainer, copy_trainer);
      // Clear the updates, doing nothing else.
      copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
      // Adjust the learning rate in each layer.
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
      }
      // Same iteration for both ww passes so both see the same sample.
      copy_trainer.SetIteration(iteration);
      // Train on the sample, but keep the update in updates_ instead of
      // applying to the weights.
      const ImageData *trainingdata =
          copy_trainer.TrainOnLine(samples_trainer, true);
      if (trainingdata == nullptr) {
        continue;
      }
      // We'll now use this trainer again for each layer.
      std::vector<char> updated_trainer;
      samples_trainer->SaveTrainingDump(LIGHT, copy_trainer, &updated_trainer);
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        LSTMTrainer layer_trainer;
        samples_trainer->ReadTrainingDump(updated_trainer, layer_trainer);
        Network *layer = layer_trainer.GetLayer(layers[i]);
        // Update the weights in just the layer, using Adam if enabled.
        layer->Update(0.0, momentum_, adam_beta_,
                      layer_trainer.training_iteration_ + 1);
        // Zero the updates matrix again.
        layer->Update(0.0, 0.0, 0.0, 0);
        // Train again on the same sample, again holding back the updates.
        layer_trainer.TrainOnLine(trainingdata, true);
        // Count the sign changes in the updates in layer vs in copy_trainer.
        float before_bad = bad_sums[ww][i];
        float before_ok = ok_sums[ww][i];
        layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
                                &ok_sums[ww][i], &bad_sums[ww][i]);
        // NOTE(review): bad_frac below is computed but never used — it looks
        // like a leftover from removed debug logging; confirm before removing.
        float bad_frac =
            bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
        if (bad_frac > 0.0f) {
          bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
        }
      }
    }
    ++iteration;
  }
  // Decide per layer: reduce if the reduced-rate run alternated markedly less.
  int num_lowered = 0;
  for (int i = 0; i < num_layers; ++i) {
    if (num_weights[i] == 0) {
      continue;
    }
    Network *layer = GetLayer(layers[i]);
    float lr = GetLayerLearningRate(layers[i]);
    TFloat total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
    TFloat total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
    TFloat frac_down = bad_sums[LR_DOWN][i] / total_down;
    TFloat frac_same = bad_sums[LR_SAME][i] / total_same;
    tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(),
            lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
    if (frac_down < frac_same * kImprovementFraction) {
      tprintf(" REDUCED\n");
      ScaleLayerLearningRate(layers[i], factor);
      ++num_lowered;
    } else {
      tprintf(" SAME\n");
    }
  }
  if (num_lowered == 0) {
    // Just lower everything to make sure.
    for (int i = 0; i < num_layers; ++i) {
      if (num_weights[i] > 0) {
        ScaleLayerLearningRate(layers[i], factor);
        ++num_lowered;
      }
    }
  }
  return num_lowered;
}
| 789 | |
| 790 // Converts the string to integer class labels, with appropriate null_char_s | |
| 791 // in between if not in SimpleTextOutput mode. Returns false on failure. | |
| 792 /* static */ | |
| 793 bool LSTMTrainer::EncodeString(const std::string &str, | |
| 794 const UNICHARSET &unicharset, | |
| 795 const UnicharCompress *recoder, bool simple_text, | |
| 796 int null_char, std::vector<int> *labels) { | |
| 797 if (str.c_str() == nullptr || str.length() <= 0) { | |
| 798 tprintf("Empty truth string!\n"); | |
| 799 return false; | |
| 800 } | |
| 801 unsigned err_index; | |
| 802 std::vector<int> internal_labels; | |
| 803 labels->clear(); | |
| 804 if (!simple_text) { | |
| 805 labels->push_back(null_char); | |
| 806 } | |
| 807 std::string cleaned = unicharset.CleanupString(str.c_str()); | |
| 808 if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr, | |
| 809 &err_index)) { | |
| 810 bool success = true; | |
| 811 for (auto internal_label : internal_labels) { | |
| 812 if (recoder != nullptr) { | |
| 813 // Re-encode labels via recoder. | |
| 814 RecodedCharID code; | |
| 815 int len = recoder->EncodeUnichar(internal_label, &code); | |
| 816 if (len > 0) { | |
| 817 for (int j = 0; j < len; ++j) { | |
| 818 labels->push_back(code(j)); | |
| 819 if (!simple_text) { | |
| 820 labels->push_back(null_char); | |
| 821 } | |
| 822 } | |
| 823 } else { | |
| 824 success = false; | |
| 825 err_index = 0; | |
| 826 break; | |
| 827 } | |
| 828 } else { | |
| 829 labels->push_back(internal_label); | |
| 830 if (!simple_text) { | |
| 831 labels->push_back(null_char); | |
| 832 } | |
| 833 } | |
| 834 } | |
| 835 if (success) { | |
| 836 return true; | |
| 837 } | |
| 838 } | |
| 839 tprintf("Encoding of string failed! Failure bytes:"); | |
| 840 while (err_index < cleaned.size()) { | |
| 841 tprintf(" %x", cleaned[err_index++] & 0xff); | |
| 842 } | |
| 843 tprintf("\n"); | |
| 844 return false; | |
| 845 } | |
| 846 | |
// Performs forward-backward on the given trainingdata.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::TrainOnLine(const ImageData *trainingdata,
                                      bool batch) {
  NetworkIO fwd_outputs, targets;
  // Forward pass, target construction, and error-rate bookkeeping.
  Trainability trainable =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  ++sample_iteration_;
  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
    return trainable;  // Sample was unusable.
  }
  bool debug =
      debug_interval_ > 0 && training_iteration() % debug_interval_ == 0;
  // Run backprop on the output. Perfect samples are skipped unless enough
  // iterations have passed since the last perfect one (perfect_delay_),
  // which stops easy lines from dominating the updates.
  NetworkIO bp_deltas;
  if (network_->IsTraining() &&
      (trainable != PERFECT ||
       training_iteration() >
           last_perfect_training_iteration_ + perfect_delay_)) {
    network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
    // NOTE(review): a negative momentum appears to signal batch mode
    // (accumulate in updates_ rather than apply) to Update() — confirm
    // against the Network::Update implementation.
    network_->Update(learning_rate_, batch ? -1.0f : momentum_, adam_beta_,
                     training_iteration_ + 1);
  }
#ifndef GRAPHICS_DISABLED
  // In single-step debug mode, wait for a click before continuing.
  if (debug_interval_ == 1 && debug_win_ != nullptr) {
    debug_win_->AwaitEvent(SVET_CLICK);
  }
#endif  // !GRAPHICS_DISABLED
  // Roll the memory of past means.
  RollErrorBuffers();
  return trainable;
}
| 879 | |
| 880 // Prepares the ground truth, runs forward, and prepares the targets. | |
| 881 // Returns a Trainability enum to indicate the suitability of the sample. | |
| 882 Trainability LSTMTrainer::PrepareForBackward(const ImageData *trainingdata, | |
| 883 NetworkIO *fwd_outputs, | |
| 884 NetworkIO *targets) { | |
| 885 if (trainingdata == nullptr) { | |
| 886 tprintf("Null trainingdata.\n"); | |
| 887 return UNENCODABLE; | |
| 888 } | |
| 889 // Ensure repeatability of random elements even across checkpoints. | |
| 890 bool debug = | |
| 891 debug_interval_ > 0 && training_iteration() % debug_interval_ == 0; | |
| 892 std::vector<int> truth_labels; | |
| 893 if (!EncodeString(trainingdata->transcription(), &truth_labels)) { | |
| 894 tprintf("Can't encode transcription: '%s' in language '%s'\n", | |
| 895 trainingdata->transcription().c_str(), | |
| 896 trainingdata->language().c_str()); | |
| 897 return UNENCODABLE; | |
| 898 } | |
| 899 bool upside_down = false; | |
| 900 if (randomly_rotate_) { | |
| 901 // This ensures consistent training results. | |
| 902 SetRandomSeed(); | |
| 903 upside_down = randomizer_.SignedRand(1.0) > 0.0; | |
| 904 if (upside_down) { | |
| 905 // Modify the truth labels to match the rotation: | |
| 906 // Apart from space and null, increment the label. This changes the | |
| 907 // script-id to the same script-id but upside-down. | |
| 908 // The labels need to be reversed in order, as the first is now the last. | |
| 909 for (auto truth_label : truth_labels) { | |
| 910 if (truth_label != UNICHAR_SPACE && truth_label != null_char_) { | |
| 911 ++truth_label; | |
| 912 } | |
| 913 } | |
| 914 std::reverse(truth_labels.begin(), truth_labels.end()); | |
| 915 } | |
| 916 } | |
| 917 unsigned w = 0; | |
| 918 while (w < truth_labels.size() && | |
| 919 (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_)) { | |
| 920 ++w; | |
| 921 } | |
| 922 if (w == truth_labels.size()) { | |
| 923 tprintf("Blank transcription: %s\n", trainingdata->transcription().c_str()); | |
| 924 return UNENCODABLE; | |
| 925 } | |
| 926 float image_scale; | |
| 927 NetworkIO inputs; | |
| 928 bool invert = trainingdata->boxes().empty(); | |
| 929 if (!RecognizeLine(*trainingdata, invert ? 0.5f : 0.0f, debug, invert, upside_down, | |
| 930 &image_scale, &inputs, fwd_outputs)) { | |
| 931 tprintf("Image %s not trainable\n", trainingdata->imagefilename().c_str()); | |
| 932 return UNENCODABLE; | |
| 933 } | |
| 934 targets->Resize(*fwd_outputs, network_->NumOutputs()); | |
| 935 LossType loss_type = OutputLossType(); | |
| 936 if (loss_type == LT_SOFTMAX) { | |
| 937 if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) { | |
| 938 tprintf("Compute simple targets failed for %s!\n", | |
| 939 trainingdata->imagefilename().c_str()); | |
| 940 return UNENCODABLE; | |
| 941 } | |
| 942 } else if (loss_type == LT_CTC) { | |
| 943 if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) { | |
| 944 tprintf("Compute CTC targets failed for %s!\n", | |
| 945 trainingdata->imagefilename().c_str()); | |
| 946 return UNENCODABLE; | |
| 947 } | |
| 948 } else { | |
| 949 tprintf("Logistic outputs not implemented yet!\n"); | |
| 950 return UNENCODABLE; | |
| 951 } | |
| 952 std::vector<int> ocr_labels; | |
| 953 std::vector<int> xcoords; | |
| 954 LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords); | |
| 955 // CTC does not produce correct target labels to begin with. | |
| 956 if (loss_type != LT_CTC) { | |
| 957 LabelsFromOutputs(*targets, &truth_labels, &xcoords); | |
| 958 } | |
| 959 if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels, | |
| 960 *targets)) { | |
| 961 tprintf("Input width was %d\n", inputs.Width()); | |
| 962 return UNENCODABLE; | |
| 963 } | |
| 964 std::string ocr_text = DecodeLabels(ocr_labels); | |
| 965 std::string truth_text = DecodeLabels(truth_labels); | |
| 966 targets->SubtractAllFromFloat(*fwd_outputs); | |
| 967 if (debug_interval_ != 0) { | |
| 968 if (truth_text != ocr_text) { | |
| 969 tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(), | |
| 970 ocr_text.c_str()); | |
| 971 } | |
| 972 } | |
| 973 double char_error = ComputeCharError(truth_labels, ocr_labels); | |
| 974 double word_error = ComputeWordError(&truth_text, &ocr_text); | |
| 975 double delta_error = ComputeErrorRates(*targets, char_error, word_error); | |
| 976 if (debug_interval_ != 0) { | |
| 977 tprintf("File %s line %d %s:\n", trainingdata->imagefilename().c_str(), | |
| 978 trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : ""); | |
| 979 } | |
| 980 if (delta_error == 0.0) { | |
| 981 return PERFECT; | |
| 982 } | |
| 983 if (targets->AnySuspiciousTruth(kHighConfidence)) { | |
| 984 return HI_PRECISION_ERR; | |
| 985 } | |
| 986 return TRAINABLE; | |
| 987 } | |
| 988 | |
| 989 // Writes the trainer to memory, so that the current training state can be | |
| 990 // restored. *this must always be the master trainer that retains the only | |
| 991 // copy of the training data and language model. trainer is the model that is | |
| 992 // actually serialized. | |
| 993 bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount, | |
| 994 const LSTMTrainer &trainer, | |
| 995 std::vector<char> *data) const { | |
| 996 TFile fp; | |
| 997 fp.OpenWrite(data); | |
| 998 return trainer.Serialize(serialize_amount, &mgr_, &fp); | |
| 999 } | |
| 1000 | |
| 1001 // Restores the model to *this. | |
| 1002 bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager *mgr, | |
| 1003 const char *data, int size) { | |
| 1004 if (size == 0) { | |
| 1005 tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n"); | |
| 1006 return false; | |
| 1007 } | |
| 1008 TFile fp; | |
| 1009 fp.Open(data, size); | |
| 1010 return DeSerialize(mgr, &fp); | |
| 1011 } | |
| 1012 | |
| 1013 // Writes the full recognition traineddata to the given filename. | |
| 1014 bool LSTMTrainer::SaveTraineddata(const char *filename) { | |
| 1015 std::vector<char> recognizer_data; | |
| 1016 SaveRecognitionDump(&recognizer_data); | |
| 1017 mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0], | |
| 1018 recognizer_data.size()); | |
| 1019 return mgr_.SaveFile(filename, SaveDataToFile); | |
| 1020 } | |
| 1021 | |
| 1022 // Writes the recognizer to memory, so that it can be used for testing later. | |
| 1023 void LSTMTrainer::SaveRecognitionDump(std::vector<char> *data) const { | |
| 1024 TFile fp; | |
| 1025 fp.OpenWrite(data); | |
| 1026 network_->SetEnableTraining(TS_TEMP_DISABLE); | |
| 1027 ASSERT_HOST(LSTMRecognizer::Serialize(&mgr_, &fp)); | |
| 1028 network_->SetEnableTraining(TS_RE_ENABLE); | |
| 1029 } | |
| 1030 | |
| 1031 // Returns a suitable filename for a training dump, based on the model_base_, | |
| 1032 // best_error_rate_, best_iteration_ and training_iteration_. | |
| 1033 std::string LSTMTrainer::DumpFilename() const { | |
| 1034 std::stringstream filename; | |
| 1035 filename.imbue(std::locale::classic()); | |
| 1036 filename << model_base_ << std::fixed << std::setprecision(3) | |
| 1037 << "_" << best_error_rate_ | |
| 1038 << "_" << best_iteration_ | |
| 1039 << "_" << training_iteration_ | |
| 1040 << ".checkpoint"; | |
| 1041 return filename.str(); | |
| 1042 } | |
| 1043 | |
| 1044 // Fills the whole error buffer of the given type with the given value. | |
| 1045 void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) { | |
| 1046 for (int i = 0; i < kRollingBufferSize_; ++i) { | |
| 1047 error_buffers_[type][i] = new_error; | |
| 1048 } | |
| 1049 error_rates_[type] = 100.0 * new_error; | |
| 1050 } | |
| 1051 | |
// Helper generates a map from each current recoder_ code (ie softmax index)
// to the corresponding old_recoder code, or -1 if there isn't one.
// Used when transplanting weights from a model trained with a different
// unicharset/recoder pair.
std::vector<int> LSTMTrainer::MapRecoder(
    const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const {
  int num_new_codes = recoder_.code_range();
  int num_new_unichars = GetUnicharset().size();
  std::vector<int> code_map(num_new_codes, -1);
  for (int c = 0; c < num_new_codes; ++c) {
    int old_code = -1;
    // Find all new unichar_ids that recode to something that includes c.
    // The <= is to include the null char, which may be beyond the unicharset.
    for (int uid = 0; uid <= num_new_unichars; ++uid) {
      RecodedCharID codes;
      int length = recoder_.EncodeUnichar(uid, &codes);
      // Scan uid's code sequence for an occurrence of c.
      int code_index = 0;
      while (code_index < length && codes(code_index) != c) {
        ++code_index;
      }
      if (code_index == length) {
        continue;  // uid's encoding does not contain c.
      }
      // The old unicharset must have the same unichar.
      // uid == num_new_unichars is the null char, mapped to the last old id.
      int old_uid =
          uid < num_new_unichars
              ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
              : old_chset.size() - 1;
      if (old_uid == INVALID_UNICHAR_ID) {
        continue;
      }
      // The encoding of old_uid at the same code_index is the old code.
      RecodedCharID old_codes;
      if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
        old_code = old_codes(code_index);
        break;  // First matching unichar decides the mapping for c.
      }
    }
    code_map[c] = old_code;
  }
  return code_map;
}
| 1092 | |
| 1093 // Private version of InitCharSet above finishes the job after initializing | |
| 1094 // the mgr_ data member. | |
| 1095 void LSTMTrainer::InitCharSet() { | |
| 1096 EmptyConstructor(); | |
| 1097 training_flags_ = TF_COMPRESS_UNICHARSET; | |
| 1098 // Initialize the unicharset and recoder. | |
| 1099 if (!LoadCharsets(&mgr_)) { | |
| 1100 ASSERT_HOST( | |
| 1101 "Must provide a traineddata containing lstm_unicharset and" | |
| 1102 " lstm_recoder!\n" != nullptr); | |
| 1103 } | |
| 1104 SetNullChar(); | |
| 1105 } | |
| 1106 | |
| 1107 // Helper computes and sets the null_char_. | |
| 1108 void LSTMTrainer::SetNullChar() { | |
| 1109 null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN | |
| 1110 : GetUnicharset().size(); | |
| 1111 RecodedCharID code; | |
| 1112 recoder_.EncodeUnichar(null_char_, &code); | |
| 1113 null_char_ = code(0); | |
| 1114 } | |
| 1115 | |
// Factored sub-constructor sets up reasonable default values.
// Shared by the real constructors and InitCharSet.
void LSTMTrainer::EmptyConstructor() {
#ifndef GRAPHICS_DISABLED
  // Debug windows are created lazily on first use.
  align_win_ = nullptr;
  target_win_ = nullptr;
  ctc_win_ = nullptr;
  recon_win_ = nullptr;
#endif
  checkpoint_iteration_ = 0;
  training_stage_ = 0;
  num_training_stages_ = 2;
  // Resets all iteration counters and error buffers.
  InitIterations();
}
| 1129 | |
| 1130 // Outputs the string and periodically displays the given network inputs | |
| 1131 // as an image in the given window, and the corresponding labels at the | |
| 1132 // corresponding x_starts. | |
| 1133 // Returns false if the truth string is empty. | |
| 1134 bool LSTMTrainer::DebugLSTMTraining(const NetworkIO &inputs, | |
| 1135 const ImageData &trainingdata, | |
| 1136 const NetworkIO &fwd_outputs, | |
| 1137 const std::vector<int> &truth_labels, | |
| 1138 const NetworkIO &outputs) { | |
| 1139 const std::string &truth_text = DecodeLabels(truth_labels); | |
| 1140 if (truth_text.c_str() == nullptr || truth_text.length() <= 0) { | |
| 1141 tprintf("Empty truth string at decode time!\n"); | |
| 1142 return false; | |
| 1143 } | |
| 1144 if (debug_interval_ != 0) { | |
| 1145 // Get class labels, xcoords and string. | |
| 1146 std::vector<int> labels; | |
| 1147 std::vector<int> xcoords; | |
| 1148 LabelsFromOutputs(outputs, &labels, &xcoords); | |
| 1149 std::string text = DecodeLabels(labels); | |
| 1150 tprintf("Iteration %d: GROUND TRUTH : %s\n", training_iteration(), | |
| 1151 truth_text.c_str()); | |
| 1152 if (truth_text != text) { | |
| 1153 tprintf("Iteration %d: ALIGNED TRUTH : %s\n", training_iteration(), | |
| 1154 text.c_str()); | |
| 1155 } | |
| 1156 if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) { | |
| 1157 tprintf("TRAINING activation path for truth string %s\n", | |
| 1158 truth_text.c_str()); | |
| 1159 DebugActivationPath(outputs, labels, xcoords); | |
| 1160 #ifndef GRAPHICS_DISABLED | |
| 1161 DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_); | |
| 1162 if (OutputLossType() == LT_CTC) { | |
| 1163 DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_); | |
| 1164 DisplayTargets(outputs, "CTC Targets", &target_win_); | |
| 1165 } | |
| 1166 #endif | |
| 1167 } | |
| 1168 } | |
| 1169 return true; | |
| 1170 } | |
| 1171 | |
#ifndef GRAPHICS_DISABLED

// Displays the network targets as line a line graph.
// Each feature (class) is drawn as a separate colored polyline; segments are
// only drawn while the (scaled) target activation is >= 1, and each active
// run is closed back down to the baseline.
void LSTMTrainer::DisplayTargets(const NetworkIO &targets,
                                 const char *window_name, ScrollView **window) {
  int width = targets.Width();
  int num_features = targets.NumFeatures();
  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
                       window);
  for (int c = 0; c < num_features; ++c) {
    // Cycle through distinguishable pen colors per class.
    int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
    (*window)->Pen(static_cast<ScrollView::Color>(color));
    // start_t tracks the first timestep of the current active run (-1 = none).
    int start_t = -1;
    for (int t = 0; t < width; ++t) {
      double target = targets.f(t)[c];
      target *= kTargetYScale;
      if (target >= 1) {
        if (start_t < 0) {
          // Begin a new run: move the cursor to the baseline just before t.
          (*window)->SetCursor(t - 1, 0);
          start_t = t;
        }
        (*window)->DrawTo(t, target);
      } else if (start_t >= 0) {
        // Run ended: drop to the baseline and close the outline.
        (*window)->DrawTo(t, 0);
        (*window)->DrawTo(start_t - 1, 0);
        start_t = -1;
      }
    }
    // Close a run that extends to the right edge.
    if (start_t >= 0) {
      (*window)->DrawTo(width, 0);
      (*window)->DrawTo(start_t - 1, 0);
    }
  }
  (*window)->Update();
}

#endif  // !GRAPHICS_DISABLED
| 1209 | |
| 1210 // Builds a no-compromises target where the first positions should be the | |
| 1211 // truth labels and the rest is padded with the null_char_. | |
| 1212 bool LSTMTrainer::ComputeTextTargets(const NetworkIO &outputs, | |
| 1213 const std::vector<int> &truth_labels, | |
| 1214 NetworkIO *targets) { | |
| 1215 if (truth_labels.size() > targets->Width()) { | |
| 1216 tprintf("Error: transcription %s too long to fit into target of width %d\n", | |
| 1217 DecodeLabels(truth_labels).c_str(), targets->Width()); | |
| 1218 return false; | |
| 1219 } | |
| 1220 int i = 0; | |
| 1221 for (auto truth_label : truth_labels) { | |
| 1222 targets->SetActivations(i, truth_label, 1.0); | |
| 1223 ++i; | |
| 1224 } | |
| 1225 for (i = truth_labels.size(); i < targets->Width(); ++i) { | |
| 1226 targets->SetActivations(i, null_char_, 1.0); | |
| 1227 } | |
| 1228 return true; | |
| 1229 } | |
| 1230 | |
| 1231 // Builds a target using standard CTC. truth_labels should be pre-padded with | |
| 1232 // nulls wherever desired. They don't have to be between all labels. | |
| 1233 // outputs is input-output, as it gets clipped to minimum probability. | |
| 1234 bool LSTMTrainer::ComputeCTCTargets(const std::vector<int> &truth_labels, | |
| 1235 NetworkIO *outputs, NetworkIO *targets) { | |
| 1236 // Bottom-clip outputs to a minimum probability. | |
| 1237 CTC::NormalizeProbs(outputs); | |
| 1238 return CTC::ComputeCTCTargets(truth_labels, null_char_, | |
| 1239 outputs->float_array(), targets); | |
| 1240 } | |
| 1241 | |
// Computes network errors, and stores the results in the rolling buffers,
// along with the supplied text_error.
// Returns the delta error of the current sample (not running average.)
double LSTMTrainer::ComputeErrorRates(const NetworkIO &deltas,
                                      double char_error, double word_error) {
  UpdateErrorBuffer(ComputeRMSError(deltas), ET_RMS);
  // Delta error is the fraction of timesteps with >0.5 error in the top choice
  // score. If zero, then the top choice characters are guaranteed correct,
  // even when there is residue in the RMS error.
  double delta_error = ComputeWinnerError(deltas);
  UpdateErrorBuffer(delta_error, ET_DELTA);
  // Externally supplied text-level errors go into their own buffers.
  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
  // Skip ratio measures the difference between sample_iteration_ and
  // training_iteration_, which reflects the number of unusable samples,
  // usually due to unencodable truth text, or the text not fitting in the
  // space for the output.
  double skip_count = sample_iteration_ - prev_sample_iteration_;
  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
  return delta_error;
}
| 1263 | |
| 1264 // Computes the network activation RMS error rate. | |
| 1265 double LSTMTrainer::ComputeRMSError(const NetworkIO &deltas) { | |
| 1266 double total_error = 0.0; | |
| 1267 int width = deltas.Width(); | |
| 1268 int num_classes = deltas.NumFeatures(); | |
| 1269 for (int t = 0; t < width; ++t) { | |
| 1270 const float *class_errs = deltas.f(t); | |
| 1271 for (int c = 0; c < num_classes; ++c) { | |
| 1272 double error = class_errs[c]; | |
| 1273 total_error += error * error; | |
| 1274 } | |
| 1275 } | |
| 1276 return sqrt(total_error / (width * num_classes)); | |
| 1277 } | |
| 1278 | |
| 1279 // Computes network activation winner error rate. (Number of values that are | |
| 1280 // in error by >= 0.5 divided by number of time-steps.) More closely related | |
| 1281 // to final character error than RMS, but still directly calculable from | |
| 1282 // just the deltas. Because of the binary nature of the targets, zero winner | |
| 1283 // error is a sufficient but not necessary condition for zero char error. | |
| 1284 double LSTMTrainer::ComputeWinnerError(const NetworkIO &deltas) { | |
| 1285 int num_errors = 0; | |
| 1286 int width = deltas.Width(); | |
| 1287 int num_classes = deltas.NumFeatures(); | |
| 1288 for (int t = 0; t < width; ++t) { | |
| 1289 const float *class_errs = deltas.f(t); | |
| 1290 for (int c = 0; c < num_classes; ++c) { | |
| 1291 float abs_delta = std::fabs(class_errs[c]); | |
| 1292 // TODO(rays) Filtering cases where the delta is very large to cut out | |
| 1293 // GT errors doesn't work. Find a better way or get better truth. | |
| 1294 if (0.5 <= abs_delta) { | |
| 1295 ++num_errors; | |
| 1296 } | |
| 1297 } | |
| 1298 } | |
| 1299 return static_cast<double>(num_errors) / width; | |
| 1300 } | |
| 1301 | |
| 1302 // Computes a very simple bag of chars char error rate. | |
| 1303 double LSTMTrainer::ComputeCharError(const std::vector<int> &truth_str, | |
| 1304 const std::vector<int> &ocr_str) { | |
| 1305 std::vector<int> label_counts(NumOutputs()); | |
| 1306 unsigned truth_size = 0; | |
| 1307 for (auto ch : truth_str) { | |
| 1308 if (ch != null_char_) { | |
| 1309 ++label_counts[ch]; | |
| 1310 ++truth_size; | |
| 1311 } | |
| 1312 } | |
| 1313 for (auto ch : ocr_str) { | |
| 1314 if (ch != null_char_) { | |
| 1315 --label_counts[ch]; | |
| 1316 } | |
| 1317 } | |
| 1318 unsigned char_errors = 0; | |
| 1319 for (auto label_count : label_counts) { | |
| 1320 char_errors += abs(label_count); | |
| 1321 } | |
| 1322 // Limit BCER to interval [0,1] and avoid division by zero. | |
| 1323 if (truth_size <= char_errors) { | |
| 1324 return (char_errors == 0) ? 0.0 : 1.0; | |
| 1325 } | |
| 1326 return static_cast<double>(char_errors) / truth_size; | |
| 1327 } | |
| 1328 | |
| 1329 // Computes word recall error rate using a very simple bag of words algorithm. | |
| 1330 // NOTE that this is destructive on both input strings. | |
| 1331 double LSTMTrainer::ComputeWordError(std::string *truth_str, | |
| 1332 std::string *ocr_str) { | |
| 1333 using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>; | |
| 1334 std::vector<std::string> truth_words = split(*truth_str, ' '); | |
| 1335 if (truth_words.empty()) { | |
| 1336 return 0.0; | |
| 1337 } | |
| 1338 std::vector<std::string> ocr_words = split(*ocr_str, ' '); | |
| 1339 StrMap word_counts; | |
| 1340 for (const auto &truth_word : truth_words) { | |
| 1341 std::string truth_word_string(truth_word.c_str()); | |
| 1342 auto it = word_counts.find(truth_word_string); | |
| 1343 if (it == word_counts.end()) { | |
| 1344 word_counts.insert(std::make_pair(truth_word_string, 1)); | |
| 1345 } else { | |
| 1346 ++it->second; | |
| 1347 } | |
| 1348 } | |
| 1349 for (const auto &ocr_word : ocr_words) { | |
| 1350 std::string ocr_word_string(ocr_word.c_str()); | |
| 1351 auto it = word_counts.find(ocr_word_string); | |
| 1352 if (it == word_counts.end()) { | |
| 1353 word_counts.insert(std::make_pair(ocr_word_string, -1)); | |
| 1354 } else { | |
| 1355 --it->second; | |
| 1356 } | |
| 1357 } | |
| 1358 int word_recall_errs = 0; | |
| 1359 for (const auto &word_count : word_counts) { | |
| 1360 if (word_count.second > 0) { | |
| 1361 word_recall_errs += word_count.second; | |
| 1362 } | |
| 1363 } | |
| 1364 return static_cast<double>(word_recall_errs) / truth_words.size(); | |
| 1365 } | |
| 1366 | |
| 1367 // Updates the error buffer and corresponding mean of the given type with | |
| 1368 // the new_error. | |
| 1369 void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) { | |
| 1370 int index = training_iteration_ % kRollingBufferSize_; | |
| 1371 error_buffers_[type][index] = new_error; | |
| 1372 // Compute the mean error. | |
| 1373 int mean_count = | |
| 1374 std::min<int>(training_iteration_ + 1, error_buffers_[type].size()); | |
| 1375 double buffer_sum = 0.0; | |
| 1376 for (int i = 0; i < mean_count; ++i) { | |
| 1377 buffer_sum += error_buffers_[type][i]; | |
| 1378 } | |
| 1379 double mean = buffer_sum / mean_count; | |
| 1380 // Trim precision to 1/1000 of 1%. | |
| 1381 error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0; | |
| 1382 } | |
| 1383 | |
| 1384 // Rolls error buffers and reports the current means. | |
| 1385 void LSTMTrainer::RollErrorBuffers() { | |
| 1386 prev_sample_iteration_ = sample_iteration_; | |
| 1387 if (NewSingleError(ET_DELTA) > 0.0) { | |
| 1388 ++learning_iteration_; | |
| 1389 } else { | |
| 1390 last_perfect_training_iteration_ = training_iteration_; | |
| 1391 } | |
| 1392 ++training_iteration_; | |
| 1393 if (debug_interval_ != 0) { | |
| 1394 tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n", | |
| 1395 error_rates_[ET_RMS], error_rates_[ET_DELTA], | |
| 1396 error_rates_[ET_CHAR_ERROR], error_rates_[ET_WORD_RECERR], | |
| 1397 error_rates_[ET_SKIP_RATIO]); | |
| 1398 } | |
| 1399 } | |
| 1400 | |
// Given that error_rate is either a new min or max, updates the best/worst
// error rates, and record of progress.
// Tester is an externally supplied callback function that tests on some
// data set with a given model and records the error rates in a graph.
// Returns the tester's result string (or "" when nothing was tested).
std::string LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
                                          const std::vector<char> &model_data,
                                          const TestCallback &tester) {
  if (error_rate > best_error_rate_ &&
      iteration < best_iteration_ + kErrorGraphInterval) {
    // Too soon to record a new point.
    if (tester != nullptr && !worst_model_data_.empty()) {
      // Retry the pending worst model (tester may have been busy before).
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      return tester(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
    } else {
      return "";
    }
  }
  std::string result;
  // NOTE: there are 2 asymmetries here:
  // 1. We are computing the global minimum, but the local maximum in between.
  // 2. If the tester returns an empty string, indicating that it is busy,
  //    call it repeatedly on new local maxima to test the previous min, but
  //    not the other way around, as there is little point testing the maxima
  //    between very frequent minima.
  if (error_rate < best_error_rate_) {
    // This is a new (global) minimum.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      result = tester(worst_iteration_, worst_error_rates_, mgr_,
                      CurrentTrainingStage());
      worst_model_data_.clear();
      best_model_data_ = model_data;
    }
    // Record the new best and its full error-rate vector.
    best_error_rate_ = error_rate;
    memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
    best_iteration_ = iteration;
    best_error_history_.push_back(error_rate);
    best_error_iterations_.push_back(iteration);
    // Compute 2% decay time: iterations since the error was last 2% higher.
    double two_percent_more = error_rate + 2.0;
    int i;
    for (i = best_error_history_.size() - 1;
         i >= 0 && best_error_history_[i] < two_percent_more; --i) {
    }
    int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
    improvement_steps_ = iteration - old_iteration;
    tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
            improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
            old_iteration);
  } else if (error_rate > best_error_rate_) {
    // This is a new (local) maximum.
    if (tester != nullptr) {
      if (!best_model_data_.empty()) {
        // Test the previously recorded best model first.
        mgr_.OverwriteEntry(TESSDATA_LSTM, &best_model_data_[0],
                            best_model_data_.size());
        result = tester(best_iteration_, best_error_rates_, mgr_,
                        CurrentTrainingStage());
      } else if (!worst_model_data_.empty()) {
        // Allow for multiple data points with "worst" error rate.
        mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                            worst_model_data_.size());
        result = tester(worst_iteration_, worst_error_rates_, mgr_,
                        CurrentTrainingStage());
      }
      if (result.length() > 0) {
        // Tester accepted the data; the best model no longer needs keeping.
        best_model_data_.clear();
      }
      worst_model_data_ = model_data;
    }
  }
  // Every recorded point (min or max) becomes the new "worst" reference.
  worst_error_rate_ = error_rate;
  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
  worst_iteration_ = iteration;
  return result;
}
| 1478 | |
| 1479 } // namespace tesseract. |
