comparison mupdf-source/thirdparty/tesseract/src/training/unicharset/lstmtrainer.cpp @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: no version number in the expanded directory now.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
parents
children
comparison
equal deleted inserted replaced
1:1d09e1dec1d9 2:b50eed0cc0ef
1 ///////////////////////////////////////////////////////////////////////
2 // File: lstmtrainer.cpp
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 ///////////////////////////////////////////////////////////////////////
17
18 #define _USE_MATH_DEFINES // needed to get definition of M_SQRT1_2
19
20 // Include automatically generated configuration file if running autoconf.
21 #ifdef HAVE_CONFIG_H
22 # include "config_auto.h"
23 #endif
24
25 #include <cmath>
26 #include <iomanip> // for std::setprecision
27 #include <locale> // for std::locale::classic
28 #include <string>
29 #include "lstmtrainer.h"
30
31 #include <allheaders.h>
32 #include "boxread.h"
33 #include "ctc.h"
34 #include "imagedata.h"
35 #include "input.h"
36 #include "networkbuilder.h"
37 #include "ratngs.h"
38 #include "recodebeam.h"
39 #include "tprintf.h"
40
41 namespace tesseract {
42
43 // Min actual error rate increase to constitute divergence.
44 const double kMinDivergenceRate = 50.0;
45 // Min iterations since last best before acting on a stall.
46 const int kMinStallIterations = 10000;
47 // Fraction of current char error rate that sub_trainer_ has to be ahead
48 // before we declare the sub_trainer_ a success and switch to it.
49 const double kSubTrainerMarginFraction = 3.0 / 128;
50 // Factor to reduce learning rate on divergence.
51 const double kLearningRateDecay = M_SQRT1_2;
52 // LR adjustment iterations.
53 const int kNumAdjustmentIterations = 100;
54 // How often to add data to the error_graph_.
55 const int kErrorGraphInterval = 1000;
56 // Number of training images to train between calls to MaintainCheckpoints.
57 const int kNumPagesPerBatch = 100;
58 // Min percent error rate to consider start-up phase over.
59 const int kMinStartedErrorRate = 75;
60 // Error rate at which to transition to stage 1.
61 const double kStageTransitionThreshold = 10.0;
62 // Confidence beyond which the truth is more likely wrong than the recognizer.
63 const double kHighConfidence = 0.9375; // 15/16.
64 // Fraction of weight sign-changing total to constitute a definite improvement.
65 const double kImprovementFraction = 15.0 / 16.0;
66 // Fraction of last written best to make it worth writing another.
67 const double kBestCheckpointFraction = 31.0 / 32.0;
68 #ifndef GRAPHICS_DISABLED
69 // Scale factor for display of target activations of CTC.
70 const int kTargetXScale = 5;
71 const int kTargetYScale = 100;
72 #endif // !GRAPHICS_DISABLED
73
LSTMTrainer::LSTMTrainer()
    : randomly_rotate_(false), training_data_(0), sub_trainer_(nullptr) {
  // Default construction: no cache memory budget for training_data_ and
  // debug display disabled.
  EmptyConstructor();
  debug_interval_ = 0;
}
79
// model_base: filename prefix for written model dumps.
// checkpoint_name: file that periodic checkpoints are written to.
// debug_interval: iterations between debug displays (0 = disabled).
// max_memory: byte budget for the in-memory training data cache.
LSTMTrainer::LSTMTrainer(const std::string &model_base, const std::string &checkpoint_name,
                         int debug_interval, int64_t max_memory)
    : randomly_rotate_(false),
      training_data_(max_memory),
      sub_trainer_(nullptr) {
  EmptyConstructor();
  debug_interval_ = debug_interval;
  model_base_ = model_base;
  checkpoint_name_ = checkpoint_name;
}
90
LSTMTrainer::~LSTMTrainer() {
#ifndef GRAPHICS_DISABLED
  // The debug display windows are owned as raw pointers; release them here.
  // delete on nullptr is a no-op, so unopened windows are fine.
  delete align_win_;
  delete target_win_;
  delete ctc_win_;
  delete recon_win_;
#endif
}
99
100 // Tries to deserialize a trainer from the given file and silently returns
101 // false in case of failure.
102 bool LSTMTrainer::TryLoadingCheckpoint(const char *filename,
103 const char *old_traineddata) {
104 std::vector<char> data;
105 if (!LoadDataFromFile(filename, &data)) {
106 return false;
107 }
108 tprintf("Loaded file %s, unpacking...\n", filename);
109 if (!ReadTrainingDump(data, *this)) {
110 return false;
111 }
112 if (IsIntMode()) {
113 tprintf("Error, %s is an integer (fast) model, cannot continue training\n",
114 filename);
115 return false;
116 }
117 if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
118 network_->NumOutputs() == recoder_.code_range()) ||
119 filename == old_traineddata) {
120 return true; // Normal checkpoint load complete.
121 }
122 tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
123 recoder_.code_range());
124 if (old_traineddata == nullptr || *old_traineddata == '\0') {
125 tprintf("Must supply the old traineddata for code conversion!\n");
126 return false;
127 }
128 TessdataManager old_mgr;
129 ASSERT_HOST(old_mgr.Init(old_traineddata));
130 TFile fp;
131 if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
132 return false;
133 }
134 UNICHARSET old_chset;
135 if (!old_chset.load_from_file(&fp, false)) {
136 return false;
137 }
138 if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
139 return false;
140 }
141 UnicharCompress old_recoder;
142 if (!old_recoder.DeSerialize(&fp)) {
143 return false;
144 }
145 std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
146 // Set the null_char_ to the new value.
147 int old_null_char = null_char_;
148 SetNullChar();
149 // Map the softmax(s) in the network.
150 network_->RemapOutputs(old_recoder.code_range(), code_map);
151 tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
152 return true;
153 }
154
155 // Initializes the trainer with a network_spec in the network description
156 // net_flags control network behavior according to the NetworkFlags enum.
157 // There isn't really much difference between them - only where the effects
158 // are implemented.
159 // For other args see NetworkBuilder::InitNetwork.
160 // Note: Be sure to call InitCharSet before InitNetwork!
bool LSTMTrainer::InitNetwork(const char *network_spec, int append_index,
                              int net_flags, float weight_range,
                              float learning_rate, float momentum,
                              float adam_beta) {
  // Record the network spec in the traineddata version string for provenance.
  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec);
  adam_beta_ = adam_beta;
  learning_rate_ = learning_rate;
  momentum_ = momentum;
  SetNullChar();
  // Build (or append to, per append_index) the network from the textual
  // spec. The output width is the recoder's code range, which is why
  // InitCharSet must be called before InitNetwork.
  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
                                   append_index, net_flags, weight_range,
                                   &randomizer_, &network_)) {
    return false;
  }
  network_str_ += network_spec;
  tprintf("Built network:%s from request %s\n", network_->spec().c_str(),
          network_spec);
  tprintf(
      "Training parameters:\n Debug interval = %d,"
      " weights = %g, learning rate = %g, momentum=%g\n",
      debug_interval_, weight_range, learning_rate_, momentum_);
  tprintf("null char=%d\n", null_char_);
  return true;
}
185
// Resets all the iteration counters for fine tuning or training a head,
187 // where we want the error reporting to reset.
188 void LSTMTrainer::InitIterations() {
189 sample_iteration_ = 0;
190 training_iteration_ = 0;
191 learning_iteration_ = 0;
192 prev_sample_iteration_ = 0;
193 best_error_rate_ = 100.0;
194 best_iteration_ = 0;
195 worst_error_rate_ = 0.0;
196 worst_iteration_ = 0;
197 stall_iteration_ = kMinStallIterations;
198 best_error_history_.clear();
199 best_error_iterations_.clear();
200 improvement_steps_ = kMinStallIterations;
201 perfect_delay_ = 0;
202 last_perfect_training_iteration_ = 0;
203 for (int i = 0; i < ET_COUNT; ++i) {
204 best_error_rates_[i] = 100.0;
205 worst_error_rates_[i] = 0.0;
206 error_buffers_[i].clear();
207 error_buffers_[i].resize(kRollingBufferSize_);
208 error_rates_[i] = 100.0;
209 }
210 error_rate_of_last_saved_best_ = kMinStartedErrorRate;
211 }
212
213 // If the training sample is usable, grid searches for the optimal
214 // dict_ratio/cert_offset, and returns the results in a string of space-
215 // separated triplets of ratio,offset=worderr.
Trainability LSTMTrainer::GridSearchDictParams(
    const ImageData *trainingdata, int iteration, double min_dict_ratio,
    double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
    double cert_offset_step, double max_cert_offset, std::string &results) {
  sample_iteration_ = iteration;
  NetworkIO fwd_outputs, targets;
  // Run the forward pass once; the grid search below only re-decodes.
  Trainability result =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr) {
    return result;
  }

  // Encode/decode the truth to get the normalization.
  std::vector<int> truth_labels, ocr_labels, xcoords;
  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
  // NO-dict error.
  // Baseline word error with the dictionary disabled, reported as "0,0=".
  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(),
                               nullptr);
  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
                     nullptr);
  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
  std::string truth_text = DecodeLabels(truth_labels);
  std::string ocr_text = DecodeLabels(ocr_labels);
  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
  results += "0,0=" + std::to_string(baseline_error);

  // Grid-search over (dict_ratio, cert_offset), appending one
  // "ratio,offset=worderr" triplet per combination.
  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
    for (double c = min_cert_offset; c < max_cert_offset;
         c += cert_offset_step) {
      search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty,
                    nullptr);
      search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
      truth_text = DecodeLabels(truth_labels);
      ocr_text = DecodeLabels(ocr_labels);
      // This is destructive on both strings.
      double word_error = ComputeWordError(&truth_text, &ocr_text);
      // Debug-print the first grid cell and any non-finite error.
      if ((r == min_dict_ratio && c == min_cert_offset) ||
          !std::isfinite(word_error)) {
        std::string t = DecodeLabels(truth_labels);
        std::string o = DecodeLabels(ocr_labels);
        tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
                t.c_str(), o.c_str(), word_error, truth_labels[0]);
      }
      results += " " + std::to_string(r);
      results += "," + std::to_string(c);
      results += "=" + std::to_string(word_error);
    }
  }
  return result;
}
267
268 // Provides output on the distribution of weight values.
void LSTMTrainer::DebugNetwork() {
  // Delegates to the network, which prints its weight statistics.
  network_->DebugWeights();
}
272
273 // Loads a set of lstmf files that were created using the lstm.train config to
274 // tesseract into memory ready for training. Returns false if nothing was
275 // loaded.
bool LSTMTrainer::LoadAllTrainingData(const std::vector<std::string> &filenames,
                                      CachingStrategy cache_strategy,
                                      bool randomly_rotate) {
  randomly_rotate_ = randomly_rotate;
  // Drop any previously loaded documents before loading the new set.
  training_data_.Clear();
  return training_data_.LoadDocuments(filenames, cache_strategy,
                                      LoadDataFromFile);
}
284
285 // Keeps track of best and locally worst char error_rate and launches tests
286 // using tester, when a new min or max is reached.
287 // Writes checkpoints at appropriate times and builds and returns a log message
288 // to indicate progress. Returns false if nothing interesting happened.
bool LSTMTrainer::MaintainCheckpoints(const TestCallback &tester,
                                      std::stringstream &log_msg) {
  PrepareLogMsg(log_msg);
  double error_rate = CharError();
  int iteration = learning_iteration();
  // Stalled and a margin worse than the best: start a trial sub-trainer
  // from the saved best model with reduced learning rates.
  if (iteration >= stall_iteration_ &&
      error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
      best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
    // It hasn't got any better in a long while, and is a margin worse than the
    // best, so go back to the best model and try a different learning rate.
    StartSubtrainer(log_msg);
  }
  SubTrainerResult sub_trainer_result = STR_NONE;
  if (sub_trainer_ != nullptr) {
    sub_trainer_result = UpdateSubtrainer(log_msg);
    if (sub_trainer_result == STR_REPLACED) {
      // Reset the inputs, as we have overwritten *this.
      error_rate = CharError();
      iteration = learning_iteration();
      PrepareLogMsg(log_msg);
    }
  }
  bool result = true; // Something interesting happened.
  std::vector<char> rec_model_data;
  if (error_rate < best_error_rate_) {
    // New best: dump the recognizer, update the error graph, and possibly
    // write a best-model file if the improvement is big enough.
    SaveRecognitionDump(&rec_model_data);
    log_msg << " New best BCER = " << error_rate;
    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
    // just overwrote *this. In either case, we have finished with it.
    sub_trainer_.reset();
    stall_iteration_ = learning_iteration() + kMinStallIterations;
    if (TransitionTrainingStage(kStageTransitionThreshold)) {
      log_msg << " Transitioned to stage " << CurrentTrainingStage();
    }
    SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    // Only write a best-model file when the error dropped by at least
    // (1 - kBestCheckpointFraction) relative to the last one written.
    if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
      std::string best_model_name = DumpFilename();
      if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
        log_msg << " failed to write best model:";
      } else {
        log_msg << " wrote best model:";
        error_rate_of_last_saved_best_ = best_error_rate_;
      }
      log_msg << best_model_name;
    }
  } else if (error_rate > worst_error_rate_) {
    // New local worst: update the graph, and on severe divergence revert to
    // the saved best model with reduced learning rates.
    SaveRecognitionDump(&rec_model_data);
    log_msg << " New worst BCER = " << error_rate;
    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate &&
        best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
      // Error rate has ballooned. Go back to the best model.
      log_msg << "\nDivergence! ";
      // Copy best_trainer_ before reading it, as it will get overwritten.
      std::vector<char> revert_data(best_trainer_);
      if (ReadTrainingDump(revert_data, *this)) {
        LogIterations("Reverted to", log_msg);
        ReduceLearningRates(this, log_msg);
      } else {
        LogIterations("Failed to Revert at", log_msg);
      }
      // If it fails again, we will wait twice as long before reverting again.
      stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
      // Re-save the best trainer with the new learning rates and stall
      // iteration.
      SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    }
  } else {
    // Something interesting happened only if the sub_trainer_ was trained.
    result = sub_trainer_result != STR_NONE;
  }
  // Always write the rolling checkpoint if a checkpoint name is configured.
  if (checkpoint_name_.length() > 0) {
    // Write a current checkpoint.
    std::vector<char> checkpoint;
    if (!SaveTrainingDump(FULL, *this, &checkpoint) ||
        !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
      log_msg << " failed to write checkpoint.";
    } else {
      log_msg << " wrote checkpoint.";
    }
  }
  return result;
}
373
374 // Builds a string containing a progress message with current error rates.
void LSTMTrainer::PrepareLogMsg(std::stringstream &log_msg) const {
  LogIterations("At", log_msg);
  // All rates are percentages, printed with fixed 3-decimal precision.
  log_msg << std::fixed << std::setprecision(3)
          << ", mean rms=" << error_rates_[ET_RMS]
          << "%, delta=" << error_rates_[ET_DELTA]
          << "%, BCER train=" << error_rates_[ET_CHAR_ERROR]
          << "%, BWER train=" << error_rates_[ET_WORD_RECERR]
          << "%, skip ratio=" << error_rates_[ET_SKIP_RATIO] << "%,";
}
384
385 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
386 // sample_iteration() to the log_msg.
void LSTMTrainer::LogIterations(const char *intro_str,
                                std::stringstream &log_msg) const {
  // Format: "<intro> iteration <learning>/<training>/<sample>".
  log_msg << intro_str
          << " iteration " << learning_iteration()
          << "/" << training_iteration()
          << "/" << sample_iteration();
}
394
395 // Returns true and increments the training_stage_ if the error rate has just
396 // passed through the given threshold for the first time.
397 bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
398 if (best_error_rate_ < error_threshold &&
399 training_stage_ + 1 < num_training_stages_) {
400 ++training_stage_;
401 return true;
402 }
403 return false;
404 }
405
406 // Writes to the given file. Returns false in case of error.
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
                            const TessdataManager *mgr, TFile *fp) const {
  // The field order below is the on-disk format; DeSerialize must read the
  // same fields in the same order. Base recognizer state goes first.
  if (!LSTMRecognizer::Serialize(mgr, fp)) {
    return false;
  }
  if (!fp->Serialize(&learning_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->Serialize(&last_perfect_training_iteration_)) {
    return false;
  }
  // One rolling error buffer per error type.
  for (const auto &error_buffer : error_buffers_) {
    if (!fp->Serialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&training_stage_)) {
    return false;
  }
  // Record the amount written so DeSerialize knows how much to read back.
  uint8_t amount = serialize_amount;
  if (!fp->Serialize(&amount)) {
    return false;
  }
  if (serialize_amount == LIGHT) {
    return true; // We are done.
  }
  // Full (non-LIGHT) state: best/worst tracking and saved model dumps.
  if (!fp->Serialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&best_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->Serialize(best_model_data_)) {
    return false;
  }
  if (!fp->Serialize(worst_model_data_)) {
    return false;
  }
  // The saved best trainer is skipped for NO_BEST_TRAINER dumps.
  if (serialize_amount != NO_BEST_TRAINER && !fp->Serialize(best_trainer_)) {
    return false;
  }
  // The sub-trainer, if any, is written as a LIGHT dump (empty vector if
  // there is no sub-trainer).
  std::vector<char> sub_data;
  if (sub_trainer_ != nullptr &&
      !SaveTrainingDump(LIGHT, *sub_trainer_, &sub_data)) {
    return false;
  }
  if (!fp->Serialize(sub_data)) {
    return false;
  }
  if (!fp->Serialize(best_error_history_)) {
    return false;
  }
  if (!fp->Serialize(best_error_iterations_)) {
    return false;
  }
  return fp->Serialize(&improvement_steps_);
}
488
489 // Reads from the given file. Returns false in case of error.
490 // NOTE: It is assumed that the trainer is never read cross-endian.
bool LSTMTrainer::DeSerialize(const TessdataManager *mgr, TFile *fp) {
  // Must mirror Serialize exactly: same fields, same order.
  if (!LSTMRecognizer::DeSerialize(mgr, fp)) {
    return false;
  }
  if (!fp->DeSerialize(&learning_iteration_)) {
    // Special case. If we successfully decoded the recognizer, but fail here
    // then it means we were just given a recognizer, so issue a warning and
    // allow it.
    tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
    learning_iteration_ = 0;
    network_->SetEnableTraining(TS_ENABLED);
    return true;
  }
  if (!fp->DeSerialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->DeSerialize(&last_perfect_training_iteration_)) {
    return false;
  }
  // One rolling error buffer per error type.
  for (auto &error_buffer : error_buffers_) {
    if (!fp->DeSerialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&training_stage_)) {
    return false;
  }
  // The amount byte written by Serialize tells us how much more follows.
  uint8_t amount;
  if (!fp->DeSerialize(&amount)) {
    return false;
  }
  if (amount == LIGHT) {
    return true; // Don't read the rest.
  }
  if (!fp->DeSerialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&best_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(best_model_data_)) {
    return false;
  }
  if (!fp->DeSerialize(worst_model_data_)) {
    return false;
  }
  if (amount != NO_BEST_TRAINER && !fp->DeSerialize(best_trainer_)) {
    return false;
  }
  // An empty sub_data vector means no sub-trainer was saved.
  std::vector<char> sub_data;
  if (!fp->DeSerialize(sub_data)) {
    return false;
  }
  if (sub_data.empty()) {
    sub_trainer_ = nullptr;
  } else {
    sub_trainer_ = std::make_unique<LSTMTrainer>();
    if (!ReadTrainingDump(sub_data, *sub_trainer_)) {
      return false;
    }
  }
  if (!fp->DeSerialize(best_error_history_)) {
    return false;
  }
  if (!fp->DeSerialize(best_error_iterations_)) {
    return false;
  }
  return fp->DeSerialize(&improvement_steps_);
}
581
582 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
583 // learning rates (by scaling reduction, or layer specific, according to
584 // NF_LAYER_SPECIFIC_LR).
void LSTMTrainer::StartSubtrainer(std::stringstream &log_msg) {
  sub_trainer_ = std::make_unique<LSTMTrainer>();
  if (!ReadTrainingDump(best_trainer_, *sub_trainer_)) {
    log_msg << " Failed to revert to previous best for trial!";
    sub_trainer_.reset();
  } else {
    log_msg << " Trial sub_trainer_ from iteration "
            << sub_trainer_->training_iteration();
    // Reduce learning rate so it doesn't diverge this time.
    sub_trainer_->ReduceLearningRates(this, log_msg);
    // If it fails again, we will wait twice as long before reverting again.
    int stall_offset =
        learning_iteration() - sub_trainer_->learning_iteration();
    stall_iteration_ = learning_iteration() + 2 * stall_offset;
    sub_trainer_->stall_iteration_ = stall_iteration_;
    // Re-save the best trainer with the new learning rates and stall iteration.
    SaveTrainingDump(NO_BEST_TRAINER, *sub_trainer_, &best_trainer_);
  }
}
604
605 // While the sub_trainer_ is behind the current training iteration and its
606 // training error is at least kSubTrainerMarginFraction better than the
607 // current training error, trains the sub_trainer_, and returns STR_UPDATED if
608 // it did anything. If it catches up, and has a better error rate than the
609 // current best, as well as a margin over the current error rate, then the
610 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
611 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
612 // receive any training iterations.
613 SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::stringstream &log_msg) {
614 double training_error = CharError();
615 double sub_error = sub_trainer_->CharError();
616 double sub_margin = (training_error - sub_error) / sub_error;
617 if (sub_margin >= kSubTrainerMarginFraction) {
618 log_msg << " sub_trainer=" << sub_error
619 << " margin=" << 100.0 * sub_margin << "\n";
620 // Catch up to current iteration.
621 int end_iteration = training_iteration();
622 while (sub_trainer_->training_iteration() < end_iteration &&
623 sub_margin >= kSubTrainerMarginFraction) {
624 int target_iteration =
625 sub_trainer_->training_iteration() + kNumPagesPerBatch;
626 while (sub_trainer_->training_iteration() < target_iteration) {
627 sub_trainer_->TrainOnLine(this, false);
628 }
629 std::stringstream batch_log("Sub:");
630 batch_log.imbue(std::locale::classic());
631 sub_trainer_->PrepareLogMsg(batch_log);
632 batch_log << "\n";
633 tprintf("UpdateSubtrainer:%s", batch_log.str().c_str());
634 log_msg << batch_log.str();
635 sub_error = sub_trainer_->CharError();
636 sub_margin = (training_error - sub_error) / sub_error;
637 }
638 if (sub_error < best_error_rate_ &&
639 sub_margin >= kSubTrainerMarginFraction) {
640 // The sub_trainer_ has won the race to a new best. Switch to it.
641 std::vector<char> updated_trainer;
642 SaveTrainingDump(LIGHT, *sub_trainer_, &updated_trainer);
643 ReadTrainingDump(updated_trainer, *this);
644 log_msg << " Sub trainer wins at iteration "
645 << training_iteration() << "\n";
646 return STR_REPLACED;
647 }
648 return STR_UPDATED;
649 }
650 return STR_NONE;
651 }
652
653 // Reduces network learning rates, either for everything, or for layers
654 // independently, according to NF_LAYER_SPECIFIC_LR.
655 void LSTMTrainer::ReduceLearningRates(LSTMTrainer *samples_trainer,
656 std::stringstream &log_msg) {
657 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
658 int num_reduced = ReduceLayerLearningRates(
659 kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
660 log_msg << "\nReduced learning rate on layers: " << num_reduced;
661 } else {
662 ScaleLearningRate(kLearningRateDecay);
663 log_msg << "\nReduced learning rate to :" << learning_rate_;
664 }
665 log_msg << "\n";
666 }
667
668 // Considers reducing the learning rate independently for each layer down by
669 // factor(<1), or leaving it the same, by double-training the given number of
670 // samples and minimizing the amount of changing of sign of weight updates.
671 // Even if it looks like all weights should remain the same, an adjustment
672 // will be made to guarantee a different result when reverting to an old best.
673 // Returns the number of layer learning rates that were reduced.
int LSTMTrainer::ReduceLayerLearningRates(TFloat factor, int num_samples,
                                          LSTMTrainer *samples_trainer) {
  enum WhichWay {
    LR_DOWN, // Learning rate will go down by factor.
    LR_SAME, // Learning rate will stay the same.
    LR_COUNT // Size of arrays.
  };
  std::vector<std::string> layers = EnumerateLayers();
  int num_layers = layers.size();
  std::vector<int> num_weights(num_layers);
  // Per-(way, layer) accumulators of sign-change counts from
  // CountAlternators: bad = update sign flipped, ok = sign agreed.
  std::vector<TFloat> bad_sums[LR_COUNT];
  std::vector<TFloat> ok_sums[LR_COUNT];
  for (int i = 0; i < LR_COUNT; ++i) {
    bad_sums[i].resize(num_layers, 0.0);
    ok_sums[i].resize(num_layers, 0.0);
  }
  // Fold the momentum contribution into the trial learning rate.
  auto momentum_factor = 1 / (1 - momentum_);
  std::vector<char> orig_trainer;
  samples_trainer->SaveTrainingDump(LIGHT, *this, &orig_trainer);
  for (int i = 0; i < num_layers; ++i) {
    Network *layer = GetLayer(layers[i]);
    num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
  }
  int iteration = sample_iteration();
  for (int s = 0; s < num_samples; ++s) {
    // Which way will we modify the learning rate?
    for (int ww = 0; ww < LR_COUNT; ++ww) {
      // Transfer momentum to learning rate and adjust by the ww factor.
      auto ww_factor = momentum_factor;
      if (ww == LR_DOWN) {
        ww_factor *= factor;
      }
      // Make a copy of *this, so we can mess about without damaging anything.
      LSTMTrainer copy_trainer;
      samples_trainer->ReadTrainingDump(orig_trainer, copy_trainer);
      // Clear the updates, doing nothing else.
      copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
      // Adjust the learning rate in each layer.
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
      }
      copy_trainer.SetIteration(iteration);
      // Train on the sample, but keep the update in updates_ instead of
      // applying to the weights.
      const ImageData *trainingdata =
          copy_trainer.TrainOnLine(samples_trainer, true);
      if (trainingdata == nullptr) {
        continue;
      }
      // We'll now use this trainer again for each layer.
      std::vector<char> updated_trainer;
      samples_trainer->SaveTrainingDump(LIGHT, copy_trainer, &updated_trainer);
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        LSTMTrainer layer_trainer;
        samples_trainer->ReadTrainingDump(updated_trainer, layer_trainer);
        Network *layer = layer_trainer.GetLayer(layers[i]);
        // Update the weights in just the layer, using Adam if enabled.
        layer->Update(0.0, momentum_, adam_beta_,
                      layer_trainer.training_iteration_ + 1);
        // Zero the updates matrix again.
        layer->Update(0.0, 0.0, 0.0, 0);
        // Train again on the same sample, again holding back the updates.
        layer_trainer.TrainOnLine(trainingdata, true);
        // Count the sign changes in the updates in layer vs in copy_trainer.
        float before_bad = bad_sums[ww][i];
        float before_ok = ok_sums[ww][i];
        layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
                                &ok_sums[ww][i], &bad_sums[ww][i]);
        // NOTE(review): bad_frac is computed but never used below --
        // presumably left over from debugging; verify before removing.
        float bad_frac =
            bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
        if (bad_frac > 0.0f) {
          bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
        }
      }
    }
    ++iteration;
  }
  // Reduce a layer's rate iff the DOWN trial produced sufficiently fewer
  // sign flips than the SAME trial (by kImprovementFraction).
  int num_lowered = 0;
  for (int i = 0; i < num_layers; ++i) {
    if (num_weights[i] == 0) {
      continue;
    }
    Network *layer = GetLayer(layers[i]);
    float lr = GetLayerLearningRate(layers[i]);
    TFloat total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
    TFloat total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
    TFloat frac_down = bad_sums[LR_DOWN][i] / total_down;
    TFloat frac_same = bad_sums[LR_SAME][i] / total_same;
    tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(),
            lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
    if (frac_down < frac_same * kImprovementFraction) {
      tprintf(" REDUCED\n");
      ScaleLayerLearningRate(layers[i], factor);
      ++num_lowered;
    } else {
      tprintf(" SAME\n");
    }
  }
  // Guarantee a different result on revert even if no layer won the test.
  if (num_lowered == 0) {
    // Just lower everything to make sure.
    for (int i = 0; i < num_layers; ++i) {
      if (num_weights[i] > 0) {
        ScaleLayerLearningRate(layers[i], factor);
        ++num_lowered;
      }
    }
  }
  return num_lowered;
}
789
790 // Converts the string to integer class labels, with appropriate null_char_s
791 // in between if not in SimpleTextOutput mode. Returns false on failure.
792 /* static */
793 bool LSTMTrainer::EncodeString(const std::string &str,
794 const UNICHARSET &unicharset,
795 const UnicharCompress *recoder, bool simple_text,
796 int null_char, std::vector<int> *labels) {
797 if (str.c_str() == nullptr || str.length() <= 0) {
798 tprintf("Empty truth string!\n");
799 return false;
800 }
801 unsigned err_index;
802 std::vector<int> internal_labels;
803 labels->clear();
804 if (!simple_text) {
805 labels->push_back(null_char);
806 }
807 std::string cleaned = unicharset.CleanupString(str.c_str());
808 if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
809 &err_index)) {
810 bool success = true;
811 for (auto internal_label : internal_labels) {
812 if (recoder != nullptr) {
813 // Re-encode labels via recoder.
814 RecodedCharID code;
815 int len = recoder->EncodeUnichar(internal_label, &code);
816 if (len > 0) {
817 for (int j = 0; j < len; ++j) {
818 labels->push_back(code(j));
819 if (!simple_text) {
820 labels->push_back(null_char);
821 }
822 }
823 } else {
824 success = false;
825 err_index = 0;
826 break;
827 }
828 } else {
829 labels->push_back(internal_label);
830 if (!simple_text) {
831 labels->push_back(null_char);
832 }
833 }
834 }
835 if (success) {
836 return true;
837 }
838 }
839 tprintf("Encoding of string failed! Failure bytes:");
840 while (err_index < cleaned.size()) {
841 tprintf(" %x", cleaned[err_index++] & 0xff);
842 }
843 tprintf("\n");
844 return false;
845 }
846
// Performs forward-backward on the given trainingdata.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::TrainOnLine(const ImageData *trainingdata,
                                      bool batch) {
  NetworkIO fwd_outputs, targets;
  // Forward pass + target construction; also classifies the sample.
  Trainability trainable =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  // Count the sample even if it turns out unusable, so the skip ratio
  // (sample iterations vs training iterations) stays meaningful.
  ++sample_iteration_;
  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
    return trainable; // Sample was unusable.
  }
  bool debug =
      debug_interval_ > 0 && training_iteration() % debug_interval_ == 0;
  // Run backprop on the output.
  // PERFECT samples are only trained on after perfect_delay_ iterations
  // have elapsed since the last perfect one, to avoid over-training on
  // already-correct lines.
  NetworkIO bp_deltas;
  if (network_->IsTraining() &&
      (trainable != PERFECT ||
       training_iteration() >
           last_perfect_training_iteration_ + perfect_delay_)) {
    network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
    // NOTE(review): momentum of -1.0f appears to signal batch-mode update to
    // Network::Update — confirm against its implementation.
    network_->Update(learning_rate_, batch ? -1.0f : momentum_, adam_beta_,
                     training_iteration_ + 1);
  }
#ifndef GRAPHICS_DISABLED
  // In single-step debug mode, wait for a click before continuing.
  if (debug_interval_ == 1 && debug_win_ != nullptr) {
    debug_win_->AwaitEvent(SVET_CLICK);
  }
#endif // !GRAPHICS_DISABLED
  // Roll the memory of past means.
  RollErrorBuffers();
  return trainable;
}
879
880 // Prepares the ground truth, runs forward, and prepares the targets.
881 // Returns a Trainability enum to indicate the suitability of the sample.
882 Trainability LSTMTrainer::PrepareForBackward(const ImageData *trainingdata,
883 NetworkIO *fwd_outputs,
884 NetworkIO *targets) {
885 if (trainingdata == nullptr) {
886 tprintf("Null trainingdata.\n");
887 return UNENCODABLE;
888 }
889 // Ensure repeatability of random elements even across checkpoints.
890 bool debug =
891 debug_interval_ > 0 && training_iteration() % debug_interval_ == 0;
892 std::vector<int> truth_labels;
893 if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
894 tprintf("Can't encode transcription: '%s' in language '%s'\n",
895 trainingdata->transcription().c_str(),
896 trainingdata->language().c_str());
897 return UNENCODABLE;
898 }
899 bool upside_down = false;
900 if (randomly_rotate_) {
901 // This ensures consistent training results.
902 SetRandomSeed();
903 upside_down = randomizer_.SignedRand(1.0) > 0.0;
904 if (upside_down) {
905 // Modify the truth labels to match the rotation:
906 // Apart from space and null, increment the label. This changes the
907 // script-id to the same script-id but upside-down.
908 // The labels need to be reversed in order, as the first is now the last.
909 for (auto truth_label : truth_labels) {
910 if (truth_label != UNICHAR_SPACE && truth_label != null_char_) {
911 ++truth_label;
912 }
913 }
914 std::reverse(truth_labels.begin(), truth_labels.end());
915 }
916 }
917 unsigned w = 0;
918 while (w < truth_labels.size() &&
919 (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_)) {
920 ++w;
921 }
922 if (w == truth_labels.size()) {
923 tprintf("Blank transcription: %s\n", trainingdata->transcription().c_str());
924 return UNENCODABLE;
925 }
926 float image_scale;
927 NetworkIO inputs;
928 bool invert = trainingdata->boxes().empty();
929 if (!RecognizeLine(*trainingdata, invert ? 0.5f : 0.0f, debug, invert, upside_down,
930 &image_scale, &inputs, fwd_outputs)) {
931 tprintf("Image %s not trainable\n", trainingdata->imagefilename().c_str());
932 return UNENCODABLE;
933 }
934 targets->Resize(*fwd_outputs, network_->NumOutputs());
935 LossType loss_type = OutputLossType();
936 if (loss_type == LT_SOFTMAX) {
937 if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
938 tprintf("Compute simple targets failed for %s!\n",
939 trainingdata->imagefilename().c_str());
940 return UNENCODABLE;
941 }
942 } else if (loss_type == LT_CTC) {
943 if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
944 tprintf("Compute CTC targets failed for %s!\n",
945 trainingdata->imagefilename().c_str());
946 return UNENCODABLE;
947 }
948 } else {
949 tprintf("Logistic outputs not implemented yet!\n");
950 return UNENCODABLE;
951 }
952 std::vector<int> ocr_labels;
953 std::vector<int> xcoords;
954 LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
955 // CTC does not produce correct target labels to begin with.
956 if (loss_type != LT_CTC) {
957 LabelsFromOutputs(*targets, &truth_labels, &xcoords);
958 }
959 if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
960 *targets)) {
961 tprintf("Input width was %d\n", inputs.Width());
962 return UNENCODABLE;
963 }
964 std::string ocr_text = DecodeLabels(ocr_labels);
965 std::string truth_text = DecodeLabels(truth_labels);
966 targets->SubtractAllFromFloat(*fwd_outputs);
967 if (debug_interval_ != 0) {
968 if (truth_text != ocr_text) {
969 tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(),
970 ocr_text.c_str());
971 }
972 }
973 double char_error = ComputeCharError(truth_labels, ocr_labels);
974 double word_error = ComputeWordError(&truth_text, &ocr_text);
975 double delta_error = ComputeErrorRates(*targets, char_error, word_error);
976 if (debug_interval_ != 0) {
977 tprintf("File %s line %d %s:\n", trainingdata->imagefilename().c_str(),
978 trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
979 }
980 if (delta_error == 0.0) {
981 return PERFECT;
982 }
983 if (targets->AnySuspiciousTruth(kHighConfidence)) {
984 return HI_PRECISION_ERR;
985 }
986 return TRAINABLE;
987 }
988
989 // Writes the trainer to memory, so that the current training state can be
990 // restored. *this must always be the master trainer that retains the only
991 // copy of the training data and language model. trainer is the model that is
992 // actually serialized.
993 bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount,
994 const LSTMTrainer &trainer,
995 std::vector<char> *data) const {
996 TFile fp;
997 fp.OpenWrite(data);
998 return trainer.Serialize(serialize_amount, &mgr_, &fp);
999 }
1000
1001 // Restores the model to *this.
1002 bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager *mgr,
1003 const char *data, int size) {
1004 if (size == 0) {
1005 tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
1006 return false;
1007 }
1008 TFile fp;
1009 fp.Open(data, size);
1010 return DeSerialize(mgr, &fp);
1011 }
1012
1013 // Writes the full recognition traineddata to the given filename.
1014 bool LSTMTrainer::SaveTraineddata(const char *filename) {
1015 std::vector<char> recognizer_data;
1016 SaveRecognitionDump(&recognizer_data);
1017 mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
1018 recognizer_data.size());
1019 return mgr_.SaveFile(filename, SaveDataToFile);
1020 }
1021
1022 // Writes the recognizer to memory, so that it can be used for testing later.
1023 void LSTMTrainer::SaveRecognitionDump(std::vector<char> *data) const {
1024 TFile fp;
1025 fp.OpenWrite(data);
1026 network_->SetEnableTraining(TS_TEMP_DISABLE);
1027 ASSERT_HOST(LSTMRecognizer::Serialize(&mgr_, &fp));
1028 network_->SetEnableTraining(TS_RE_ENABLE);
1029 }
1030
1031 // Returns a suitable filename for a training dump, based on the model_base_,
1032 // best_error_rate_, best_iteration_ and training_iteration_.
1033 std::string LSTMTrainer::DumpFilename() const {
1034 std::stringstream filename;
1035 filename.imbue(std::locale::classic());
1036 filename << model_base_ << std::fixed << std::setprecision(3)
1037 << "_" << best_error_rate_
1038 << "_" << best_iteration_
1039 << "_" << training_iteration_
1040 << ".checkpoint";
1041 return filename.str();
1042 }
1043
1044 // Fills the whole error buffer of the given type with the given value.
1045 void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
1046 for (int i = 0; i < kRollingBufferSize_; ++i) {
1047 error_buffers_[type][i] = new_error;
1048 }
1049 error_rates_[type] = 100.0 * new_error;
1050 }
1051
// Helper generates a map from each current recoder_ code (ie softmax index)
// to the corresponding old_recoder code, or -1 if there isn't one.
// Used when transferring weights from a model trained with a different
// unicharset/recoder.
std::vector<int> LSTMTrainer::MapRecoder(
    const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const {
  int num_new_codes = recoder_.code_range();
  int num_new_unichars = GetUnicharset().size();
  // Default every new code to -1 (no equivalent in the old recoder).
  std::vector<int> code_map(num_new_codes, -1);
  for (int c = 0; c < num_new_codes; ++c) {
    int old_code = -1;
    // Find all new unichar_ids that recode to something that includes c.
    // The <= is to include the null char, which may be beyond the unicharset.
    for (int uid = 0; uid <= num_new_unichars; ++uid) {
      RecodedCharID codes;
      int length = recoder_.EncodeUnichar(uid, &codes);
      // Locate c within this unichar's code sequence, if present.
      int code_index = 0;
      while (code_index < length && codes(code_index) != c) {
        ++code_index;
      }
      if (code_index == length) {
        continue; // This unichar's code sequence does not contain c.
      }
      // The old unicharset must have the same unichar.
      int old_uid =
          uid < num_new_unichars
              ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
              : old_chset.size() - 1;
      if (old_uid == INVALID_UNICHAR_ID) {
        continue;
      }
      // The encoding of old_uid at the same code_index is the old code.
      RecodedCharID old_codes;
      if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
        old_code = old_codes(code_index);
        break;
      }
    }
    code_map[c] = old_code;
  }
  return code_map;
}
1092
// Private version of InitCharSet above finishes the job after initializing
// the mgr_ data member.
void LSTMTrainer::InitCharSet() {
  EmptyConstructor();
  training_flags_ = TF_COMPRESS_UNICHARSET;
  // Initialize the unicharset and recoder.
  if (!LoadCharsets(&mgr_)) {
    // NOTE(review): a string literal compared against nullptr is always
    // true, so this ASSERT_HOST can never fire — the failure message is
    // effectively unreachable and a LoadCharsets failure is silently
    // ignored. Confirm intent; ASSERT_HOST_MSG(false, ...) may have been
    // meant here.
    ASSERT_HOST(
        "Must provide a traineddata containing lstm_unicharset and"
        " lstm_recoder!\n" != nullptr);
  }
  SetNullChar();
}
1106
1107 // Helper computes and sets the null_char_.
1108 void LSTMTrainer::SetNullChar() {
1109 null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
1110 : GetUnicharset().size();
1111 RecodedCharID code;
1112 recoder_.EncodeUnichar(null_char_, &code);
1113 null_char_ = code(0);
1114 }
1115
1116 // Factored sub-constructor sets up reasonable default values.
1117 void LSTMTrainer::EmptyConstructor() {
1118 #ifndef GRAPHICS_DISABLED
1119 align_win_ = nullptr;
1120 target_win_ = nullptr;
1121 ctc_win_ = nullptr;
1122 recon_win_ = nullptr;
1123 #endif
1124 checkpoint_iteration_ = 0;
1125 training_stage_ = 0;
1126 num_training_stages_ = 2;
1127 InitIterations();
1128 }
1129
1130 // Outputs the string and periodically displays the given network inputs
1131 // as an image in the given window, and the corresponding labels at the
1132 // corresponding x_starts.
1133 // Returns false if the truth string is empty.
1134 bool LSTMTrainer::DebugLSTMTraining(const NetworkIO &inputs,
1135 const ImageData &trainingdata,
1136 const NetworkIO &fwd_outputs,
1137 const std::vector<int> &truth_labels,
1138 const NetworkIO &outputs) {
1139 const std::string &truth_text = DecodeLabels(truth_labels);
1140 if (truth_text.c_str() == nullptr || truth_text.length() <= 0) {
1141 tprintf("Empty truth string at decode time!\n");
1142 return false;
1143 }
1144 if (debug_interval_ != 0) {
1145 // Get class labels, xcoords and string.
1146 std::vector<int> labels;
1147 std::vector<int> xcoords;
1148 LabelsFromOutputs(outputs, &labels, &xcoords);
1149 std::string text = DecodeLabels(labels);
1150 tprintf("Iteration %d: GROUND TRUTH : %s\n", training_iteration(),
1151 truth_text.c_str());
1152 if (truth_text != text) {
1153 tprintf("Iteration %d: ALIGNED TRUTH : %s\n", training_iteration(),
1154 text.c_str());
1155 }
1156 if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1157 tprintf("TRAINING activation path for truth string %s\n",
1158 truth_text.c_str());
1159 DebugActivationPath(outputs, labels, xcoords);
1160 #ifndef GRAPHICS_DISABLED
1161 DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1162 if (OutputLossType() == LT_CTC) {
1163 DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1164 DisplayTargets(outputs, "CTC Targets", &target_win_);
1165 }
1166 #endif
1167 }
1168 }
1169 return true;
1170 }
1171
1172 #ifndef GRAPHICS_DISABLED
1173
// Displays the network targets as a line graph, one colored polyline per
// feature (class), in the given (possibly newly created) window.
void LSTMTrainer::DisplayTargets(const NetworkIO &targets,
                                 const char *window_name, ScrollView **window) {
  int width = targets.Width();
  int num_features = targets.NumFeatures();
  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
                       window);
  for (int c = 0; c < num_features; ++c) {
    // Cycle through the palette, skipping the first two reserved colors.
    int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
    (*window)->Pen(static_cast<ScrollView::Color>(color));
    // start_t marks the first timestep of an open run of significant
    // activations, or -1 when not currently inside a run.
    int start_t = -1;
    for (int t = 0; t < width; ++t) {
      double target = targets.f(t)[c];
      target *= kTargetYScale;
      if (target >= 1) {
        // Entering or continuing a run: draw up to the scaled target value.
        if (start_t < 0) {
          (*window)->SetCursor(t - 1, 0);
          start_t = t;
        }
        (*window)->DrawTo(t, target);
      } else if (start_t >= 0) {
        // Leaving a run: close the polyline back down to the baseline.
        (*window)->DrawTo(t, 0);
        (*window)->DrawTo(start_t - 1, 0);
        start_t = -1;
      }
    }
    // Close any run still open at the right edge of the window.
    if (start_t >= 0) {
      (*window)->DrawTo(width, 0);
      (*window)->DrawTo(start_t - 1, 0);
    }
  }
  (*window)->Update();
}
1207
1208 #endif // !GRAPHICS_DISABLED
1209
1210 // Builds a no-compromises target where the first positions should be the
1211 // truth labels and the rest is padded with the null_char_.
1212 bool LSTMTrainer::ComputeTextTargets(const NetworkIO &outputs,
1213 const std::vector<int> &truth_labels,
1214 NetworkIO *targets) {
1215 if (truth_labels.size() > targets->Width()) {
1216 tprintf("Error: transcription %s too long to fit into target of width %d\n",
1217 DecodeLabels(truth_labels).c_str(), targets->Width());
1218 return false;
1219 }
1220 int i = 0;
1221 for (auto truth_label : truth_labels) {
1222 targets->SetActivations(i, truth_label, 1.0);
1223 ++i;
1224 }
1225 for (i = truth_labels.size(); i < targets->Width(); ++i) {
1226 targets->SetActivations(i, null_char_, 1.0);
1227 }
1228 return true;
1229 }
1230
// Builds a target using standard CTC. truth_labels should be pre-padded with
// nulls wherever desired. They don't have to be between all labels.
// outputs is input-output, as it gets clipped to minimum probability.
// Returns the result of the underlying CTC target computation.
bool LSTMTrainer::ComputeCTCTargets(const std::vector<int> &truth_labels,
                                    NetworkIO *outputs, NetworkIO *targets) {
  // Bottom-clip outputs to a minimum probability.
  CTC::NormalizeProbs(outputs);
  return CTC::ComputeCTCTargets(truth_labels, null_char_,
                                outputs->float_array(), targets);
}
1241
// Computes network errors, and stores the results in the rolling buffers,
// along with the supplied text_error.
// Returns the delta error of the current sample (not running average.)
double LSTMTrainer::ComputeErrorRates(const NetworkIO &deltas,
                                      double char_error, double word_error) {
  // RMS of the raw activation deltas over all timesteps and classes.
  UpdateErrorBuffer(ComputeRMSError(deltas), ET_RMS);
  // Delta error is the fraction of timesteps with >0.5 error in the top choice
  // score. If zero, then the top choice characters are guaranteed correct,
  // even when there is residue in the RMS error.
  double delta_error = ComputeWinnerError(deltas);
  UpdateErrorBuffer(delta_error, ET_DELTA);
  // Caller-supplied text-level error rates.
  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
  // Skip ratio measures the difference between sample_iteration_ and
  // training_iteration_, which reflects the number of unusable samples,
  // usually due to unencodable truth text, or the text not fitting in the
  // space for the output.
  double skip_count = sample_iteration_ - prev_sample_iteration_;
  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
  return delta_error;
}
1263
1264 // Computes the network activation RMS error rate.
1265 double LSTMTrainer::ComputeRMSError(const NetworkIO &deltas) {
1266 double total_error = 0.0;
1267 int width = deltas.Width();
1268 int num_classes = deltas.NumFeatures();
1269 for (int t = 0; t < width; ++t) {
1270 const float *class_errs = deltas.f(t);
1271 for (int c = 0; c < num_classes; ++c) {
1272 double error = class_errs[c];
1273 total_error += error * error;
1274 }
1275 }
1276 return sqrt(total_error / (width * num_classes));
1277 }
1278
1279 // Computes network activation winner error rate. (Number of values that are
1280 // in error by >= 0.5 divided by number of time-steps.) More closely related
1281 // to final character error than RMS, but still directly calculable from
1282 // just the deltas. Because of the binary nature of the targets, zero winner
1283 // error is a sufficient but not necessary condition for zero char error.
1284 double LSTMTrainer::ComputeWinnerError(const NetworkIO &deltas) {
1285 int num_errors = 0;
1286 int width = deltas.Width();
1287 int num_classes = deltas.NumFeatures();
1288 for (int t = 0; t < width; ++t) {
1289 const float *class_errs = deltas.f(t);
1290 for (int c = 0; c < num_classes; ++c) {
1291 float abs_delta = std::fabs(class_errs[c]);
1292 // TODO(rays) Filtering cases where the delta is very large to cut out
1293 // GT errors doesn't work. Find a better way or get better truth.
1294 if (0.5 <= abs_delta) {
1295 ++num_errors;
1296 }
1297 }
1298 }
1299 return static_cast<double>(num_errors) / width;
1300 }
1301
1302 // Computes a very simple bag of chars char error rate.
1303 double LSTMTrainer::ComputeCharError(const std::vector<int> &truth_str,
1304 const std::vector<int> &ocr_str) {
1305 std::vector<int> label_counts(NumOutputs());
1306 unsigned truth_size = 0;
1307 for (auto ch : truth_str) {
1308 if (ch != null_char_) {
1309 ++label_counts[ch];
1310 ++truth_size;
1311 }
1312 }
1313 for (auto ch : ocr_str) {
1314 if (ch != null_char_) {
1315 --label_counts[ch];
1316 }
1317 }
1318 unsigned char_errors = 0;
1319 for (auto label_count : label_counts) {
1320 char_errors += abs(label_count);
1321 }
1322 // Limit BCER to interval [0,1] and avoid division by zero.
1323 if (truth_size <= char_errors) {
1324 return (char_errors == 0) ? 0.0 : 1.0;
1325 }
1326 return static_cast<double>(char_errors) / truth_size;
1327 }
1328
1329 // Computes word recall error rate using a very simple bag of words algorithm.
1330 // NOTE that this is destructive on both input strings.
1331 double LSTMTrainer::ComputeWordError(std::string *truth_str,
1332 std::string *ocr_str) {
1333 using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
1334 std::vector<std::string> truth_words = split(*truth_str, ' ');
1335 if (truth_words.empty()) {
1336 return 0.0;
1337 }
1338 std::vector<std::string> ocr_words = split(*ocr_str, ' ');
1339 StrMap word_counts;
1340 for (const auto &truth_word : truth_words) {
1341 std::string truth_word_string(truth_word.c_str());
1342 auto it = word_counts.find(truth_word_string);
1343 if (it == word_counts.end()) {
1344 word_counts.insert(std::make_pair(truth_word_string, 1));
1345 } else {
1346 ++it->second;
1347 }
1348 }
1349 for (const auto &ocr_word : ocr_words) {
1350 std::string ocr_word_string(ocr_word.c_str());
1351 auto it = word_counts.find(ocr_word_string);
1352 if (it == word_counts.end()) {
1353 word_counts.insert(std::make_pair(ocr_word_string, -1));
1354 } else {
1355 --it->second;
1356 }
1357 }
1358 int word_recall_errs = 0;
1359 for (const auto &word_count : word_counts) {
1360 if (word_count.second > 0) {
1361 word_recall_errs += word_count.second;
1362 }
1363 }
1364 return static_cast<double>(word_recall_errs) / truth_words.size();
1365 }
1366
// Updates the error buffer and corresponding mean of the given type with
// the new_error.
void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) {
  // Overwrite this iteration's slot in the circular buffer.
  int index = training_iteration_ % kRollingBufferSize_;
  error_buffers_[type][index] = new_error;
  // Compute the mean error.
  // Until the buffer has filled once, average only over the entries written
  // so far, to avoid diluting the mean with unwritten zeros.
  int mean_count =
      std::min<int>(training_iteration_ + 1, error_buffers_[type].size());
  double buffer_sum = 0.0;
  for (int i = 0; i < mean_count; ++i) {
    buffer_sum += error_buffers_[type][i];
  }
  double mean = buffer_sum / mean_count;
  // Trim precision to 1/1000 of 1%.
  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
}
1383
// Rolls error buffers and reports the current means.
void LSTMTrainer::RollErrorBuffers() {
  prev_sample_iteration_ = sample_iteration_;
  // Only samples that produced a non-zero delta error count as learning
  // iterations; perfect samples just record when they last occurred (used
  // by TrainOnLine's perfect_delay_ gating).
  if (NewSingleError(ET_DELTA) > 0.0) {
    ++learning_iteration_;
  } else {
    last_perfect_training_iteration_ = training_iteration_;
  }
  ++training_iteration_;
  if (debug_interval_ != 0) {
    tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
            error_rates_[ET_RMS], error_rates_[ET_DELTA],
            error_rates_[ET_CHAR_ERROR], error_rates_[ET_WORD_RECERR],
            error_rates_[ET_SKIP_RATIO]);
  }
}
1400
// Given that error_rate is either a new min or max, updates the best/worst
// error rates, and record of progress.
// Tester is an externally supplied callback function that tests on some
// data set with a given model and records the error rates in a graph.
// Returns the tester's report string, or "" when nothing was tested.
std::string LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
                                          const std::vector<char> &model_data,
                                          const TestCallback &tester) {
  if (error_rate > best_error_rate_ &&
      iteration < best_iteration_ + kErrorGraphInterval) {
    // Too soon to record a new point.
    // Still give the tester a chance to catch up on the pending worst model.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      return tester(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
    } else {
      return "";
    }
  }
  std::string result;
  // NOTE: there are 2 asymmetries here:
  // 1. We are computing the global minimum, but the local maximum in between.
  // 2. If the tester returns an empty string, indicating that it is busy,
  //    call it repeatedly on new local maxima to test the previous min, but
  //    not the other way around, as there is little point testing the maxima
  //    between very frequent minima.
  if (error_rate < best_error_rate_) {
    // This is a new (global) minimum.
    if (tester != nullptr && !worst_model_data_.empty()) {
      // Test the previous worst model before replacing the stored dumps.
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      result = tester(worst_iteration_, worst_error_rates_, mgr_,
                      CurrentTrainingStage());
      worst_model_data_.clear();
      best_model_data_ = model_data;
    }
    best_error_rate_ = error_rate;
    memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
    best_iteration_ = iteration;
    best_error_history_.push_back(error_rate);
    best_error_iterations_.push_back(iteration);
    // Compute 2% decay time.
    // Walk backwards to the most recent recorded best error that was at
    // least 2% worse than the new one; the iteration gap is the time taken
    // to improve by 2%.
    double two_percent_more = error_rate + 2.0;
    int i;
    for (i = best_error_history_.size() - 1;
         i >= 0 && best_error_history_[i] < two_percent_more; --i) {
    }
    int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
    improvement_steps_ = iteration - old_iteration;
    tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
            improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
            old_iteration);
  } else if (error_rate > best_error_rate_) {
    // This is a new (local) maximum.
    if (tester != nullptr) {
      if (!best_model_data_.empty()) {
        // Prefer testing the pending best model first.
        mgr_.OverwriteEntry(TESSDATA_LSTM, &best_model_data_[0],
                            best_model_data_.size());
        result = tester(best_iteration_, best_error_rates_, mgr_,
                        CurrentTrainingStage());
      } else if (!worst_model_data_.empty()) {
        // Allow for multiple data points with "worst" error rate.
        mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                            worst_model_data_.size());
        result = tester(worst_iteration_, worst_error_rates_, mgr_,
                        CurrentTrainingStage());
      }
      // A non-empty result means the tester accepted the job, so the best
      // dump is no longer pending.
      if (result.length() > 0) {
        best_model_data_.clear();
      }
      worst_model_data_ = model_data;
    }
  }
  worst_error_rate_ = error_rate;
  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
  worst_iteration_ = iteration;
  return result;
}
1478
1479 } // namespace tesseract.