comparison mupdf-source/thirdparty/tesseract/src/training/unicharset/lstmtrainer.h @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: no version number in the expanded directory now.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
parents
children
comparison
equal deleted inserted replaced
1:1d09e1dec1d9 2:b50eed0cc0ef
1 ///////////////////////////////////////////////////////////////////////
2 // File: lstmtrainer.h
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 ///////////////////////////////////////////////////////////////////////
17
18 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_
19 #define TESSERACT_LSTM_LSTMTRAINER_H_
20
21 #include "export.h"
22
23 #include "imagedata.h" // for DocumentCache
24 #include "lstmrecognizer.h"
25 #include "rect.h"
26
27 #include <functional> // for std::function
28 #include <sstream> // for std::stringstream
29
30 namespace tesseract {
31
// Forward declarations (alphabetical).
class LSTM;
class LSTMTester;
class LSTMTrainer;
class Parallel;
class Reversed;
class Series;
class Softmax;
39
// The kinds of error that get counted. ET_COUNT is not an error type; it is
// the number of types and is used for array sizing.
enum ErrorTypes {
  // RMS activation error.
  ET_RMS,
  // Number of big errors in deltas.
  ET_DELTA,
  // Output text string word recall error.
  ET_WORD_RECERR,
  // Output text string total char error.
  ET_CHAR_ERROR,
  // Fraction of samples skipped.
  ET_SKIP_RATIO,
  // For array sizing.
  ET_COUNT
};
49
// Values of the trainability_ flags.
enum Trainability {
  // Non-zero delta error.
  TRAINABLE,
  // Zero delta error.
  PERFECT,
  // Not trainable due to coding/alignment trouble.
  UNENCODABLE,
  // Hi confidence disagreement.
  HI_PRECISION_ERR,
  // Early in training and has no character boxes.
  NOT_BOXED,
};
58
// Selects how much data gets serialized.
enum SerializeAmount {
  // Minimal data for remote training.
  LIGHT,
  // Save an empty vector in place of best_trainer_.
  NO_BEST_TRAINER,
  // All data including best_trainer_.
  FULL,
};
65
// Reports how the sub_trainer_ training went.
enum SubTrainerResult {
  // Did nothing as not good enough.
  STR_NONE,
  // Subtrainer was updated, but didn't replace *this.
  STR_UPDATED,
  // Subtrainer replaced *this.
  STR_REPLACED
};
72
73 class LSTMTrainer;
74 // Function to compute and record error rates on some external test set(s).
75 // Args are: iteration, mean errors, model, training stage.
76 // Returns a string containing logging information about the tests.
77 using TestCallback = std::function<std::string(int, const double *,
78 const TessdataManager &, int)>;
79
80 // Trainer class for LSTM networks. Most of the effort is in creating the
81 // ideal target outputs from the transcription. A box file is used if it is
82 // available, otherwise estimates of the char widths from the unicharset are
83 // used to guide a DP search for the best fit to the transcription.
84 class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
85 public:
86 LSTMTrainer();
87 LSTMTrainer(const std::string &model_base,
88 const std::string &checkpoint_name,
89 int debug_interval, int64_t max_memory);
90 virtual ~LSTMTrainer();
91
92 // Tries to deserialize a trainer from the given file and silently returns
93 // false in case of failure. If old_traineddata is not null, then it is
94 // assumed that the character set is to be re-mapped from old_traineddata to
95 // the new, with consequent change in weight matrices etc.
96 bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata);
97
98 // Initializes the character set encode/decode mechanism directly from a
99 // previously setup traineddata containing dawgs, UNICHARSET and
100 // UnicharCompress. Note: Call before InitNetwork!
101 bool InitCharSet(const std::string &traineddata_path) {
102 bool success = mgr_.Init(traineddata_path.c_str());
103 if (success) {
104 InitCharSet();
105 }
106 return success;
107 }
108 void InitCharSet(const TessdataManager &mgr) {
109 mgr_ = mgr;
110 InitCharSet();
111 }
112
113 // Initializes the trainer with a network_spec in the network description
114 // net_flags control network behavior according to the NetworkFlags enum.
115 // There isn't really much difference between them - only where the effects
116 // are implemented.
117 // For other args see NetworkBuilder::InitNetwork.
118 // Note: Be sure to call InitCharSet before InitNetwork!
119 bool InitNetwork(const char *network_spec, int append_index, int net_flags,
120 float weight_range, float learning_rate, float momentum,
121 float adam_beta);
122 // Resets all the iteration counters for fine tuning or training a head,
123 // where we want the error reporting to reset.
124 void InitIterations();
125
126 // Accessors.
127 double ActivationError() const {
128 return error_rates_[ET_DELTA];
129 }
130 double CharError() const {
131 return error_rates_[ET_CHAR_ERROR];
132 }
133 const double *error_rates() const {
134 return error_rates_;
135 }
136 double best_error_rate() const {
137 return best_error_rate_;
138 }
139 int best_iteration() const {
140 return best_iteration_;
141 }
142 int learning_iteration() const {
143 return learning_iteration_;
144 }
145 int32_t improvement_steps() const {
146 return improvement_steps_;
147 }
148 void set_perfect_delay(int delay) {
149 perfect_delay_ = delay;
150 }
151 const std::vector<char> &best_trainer() const {
152 return best_trainer_;
153 }
154 // Returns the error that was just calculated by PrepareForBackward.
155 double NewSingleError(ErrorTypes type) const {
156 return error_buffers_[type][training_iteration() % kRollingBufferSize_];
157 }
158 // Returns the error that was just calculated by TrainOnLine. Since
159 // TrainOnLine rolls the error buffers, this is one further back than
160 // NewSingleError.
161 double LastSingleError(ErrorTypes type) const {
162 return error_buffers_[type]
163 [(training_iteration() + kRollingBufferSize_ - 1) %
164 kRollingBufferSize_];
165 }
166 const DocumentCache &training_data() const {
167 return training_data_;
168 }
169 DocumentCache *mutable_training_data() {
170 return &training_data_;
171 }
172
173 // If the training sample is usable, grid searches for the optimal
174 // dict_ratio/cert_offset, and returns the results in a string of space-
175 // separated triplets of ratio,offset=worderr.
176 Trainability GridSearchDictParams(
177 const ImageData *trainingdata, int iteration, double min_dict_ratio,
178 double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
179 double cert_offset_step, double max_cert_offset, std::string &results);
180
181 // Provides output on the distribution of weight values.
182 void DebugNetwork();
183
184 // Loads a set of lstmf files that were created using the lstm.train config to
185 // tesseract into memory ready for training. Returns false if nothing was
186 // loaded.
187 bool LoadAllTrainingData(const std::vector<std::string> &filenames,
188 CachingStrategy cache_strategy,
189 bool randomly_rotate);
190
191 // Keeps track of best and locally worst error rate, using internally computed
192 // values. See MaintainCheckpointsSpecific for more detail.
193 bool MaintainCheckpoints(const TestCallback &tester, std::stringstream &log_msg);
194 // Keeps track of best and locally worst error_rate (whatever it is) and
195 // launches tests using rec_model, when a new min or max is reached.
196 // Writes checkpoints using train_model at appropriate times and builds and
197 // returns a log message to indicate progress. Returns false if nothing
198 // interesting happened.
199 bool MaintainCheckpointsSpecific(int iteration,
200 const std::vector<char> *train_model,
201 const std::vector<char> *rec_model,
202 TestCallback tester, std::stringstream &log_msg);
203 // Builds a progress message with current error rates.
204 void PrepareLogMsg(std::stringstream &log_msg) const;
205 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
206 // sample_iteration() to the log_msg.
207 void LogIterations(const char *intro_str, std::stringstream &log_msg) const;
208
209 // TODO(rays) Add curriculum learning.
210 // Returns true and increments the training_stage_ if the error rate has just
211 // passed through the given threshold for the first time.
212 bool TransitionTrainingStage(float error_threshold);
213 // Returns the current training stage.
214 int CurrentTrainingStage() const {
215 return training_stage_;
216 }
217
218 // Writes to the given file. Returns false in case of error.
219 bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr,
220 TFile *fp) const;
221 // Reads from the given file. Returns false in case of error.
222 bool DeSerialize(const TessdataManager *mgr, TFile *fp);
223
224 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
225 // learning rates (by scaling reduction, or layer specific, according to
226 // NF_LAYER_SPECIFIC_LR).
227 void StartSubtrainer(std::stringstream &log_msg);
228 // While the sub_trainer_ is behind the current training iteration and its
229 // training error is at least kSubTrainerMarginFraction better than the
230 // current training error, trains the sub_trainer_, and returns STR_UPDATED if
231 // it did anything. If it catches up, and has a better error rate than the
232 // current best, as well as a margin over the current error rate, then the
233 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
234 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
235 // receive any training iterations.
236 SubTrainerResult UpdateSubtrainer(std::stringstream &log_msg);
237 // Reduces network learning rates, either for everything, or for layers
238 // independently, according to NF_LAYER_SPECIFIC_LR.
239 void ReduceLearningRates(LSTMTrainer *samples_trainer, std::stringstream &log_msg);
240 // Considers reducing the learning rate independently for each layer down by
241 // factor(<1), or leaving it the same, by double-training the given number of
242 // samples and minimizing the amount of changing of sign of weight updates.
243 // Even if it looks like all weights should remain the same, an adjustment
244 // will be made to guarantee a different result when reverting to an old best.
245 // Returns the number of layer learning rates that were reduced.
246 int ReduceLayerLearningRates(TFloat factor, int num_samples,
247 LSTMTrainer *samples_trainer);
248
249 // Converts the string to integer class labels, with appropriate null_char_s
250 // in between if not in SimpleTextOutput mode. Returns false on failure.
251 bool EncodeString(const std::string &str, std::vector<int> *labels) const {
252 return EncodeString(str, GetUnicharset(),
253 IsRecoding() ? &recoder_ : nullptr, SimpleTextOutput(),
254 null_char_, labels);
255 }
256 // Static version operates on supplied unicharset, encoder, simple_text.
257 static bool EncodeString(const std::string &str, const UNICHARSET &unicharset,
258 const UnicharCompress *recoder, bool simple_text,
259 int null_char, std::vector<int> *labels);
260
261 // Performs forward-backward on the given trainingdata.
262 // Returns the sample that was used or nullptr if the next sample was deemed
263 // unusable. samples_trainer could be this or an alternative trainer that
264 // holds the training samples.
265 const ImageData *TrainOnLine(LSTMTrainer *samples_trainer, bool batch) {
266 int sample_index = sample_iteration();
267 const ImageData *image =
268 samples_trainer->training_data_.GetPageBySerial(sample_index);
269 if (image != nullptr) {
270 Trainability trainable = TrainOnLine(image, batch);
271 if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
272 return nullptr; // Sample was unusable.
273 }
274 } else {
275 ++sample_iteration_;
276 }
277 return image;
278 }
279 Trainability TrainOnLine(const ImageData *trainingdata, bool batch);
280
281 // Prepares the ground truth, runs forward, and prepares the targets.
282 // Returns a Trainability enum to indicate the suitability of the sample.
283 Trainability PrepareForBackward(const ImageData *trainingdata,
284 NetworkIO *fwd_outputs, NetworkIO *targets);
285
286 // Writes the trainer to memory, so that the current training state can be
287 // restored. *this must always be the master trainer that retains the only
288 // copy of the training data and language model. trainer is the model that is
289 // actually serialized.
290 bool SaveTrainingDump(SerializeAmount serialize_amount,
291 const LSTMTrainer &trainer,
292 std::vector<char> *data) const;
293
294 // Reads previously saved trainer from memory. *this must always be the
295 // master trainer that retains the only copy of the training data and
296 // language model. trainer is the model that is restored.
297 bool ReadTrainingDump(const std::vector<char> &data,
298 LSTMTrainer &trainer) const {
299 if (data.empty()) {
300 return false;
301 }
302 return ReadSizedTrainingDump(&data[0], data.size(), trainer);
303 }
304 bool ReadSizedTrainingDump(const char *data, int size,
305 LSTMTrainer &trainer) const {
306 return trainer.ReadLocalTrainingDump(&mgr_, data, size);
307 }
308 // Restores the model to *this.
309 bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data,
310 int size);
311
312 // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
313 void SetupCheckpointInfo();
314
315 // Writes the full recognition traineddata to the given filename.
316 bool SaveTraineddata(const char *filename);
317
318 // Writes the recognizer to memory, so that it can be used for testing later.
319 void SaveRecognitionDump(std::vector<char> *data) const;
320
321 // Returns a suitable filename for a training dump, based on the model_base_,
322 // the iteration and the error rates.
323 std::string DumpFilename() const;
324
325 // Fills the whole error buffer of the given type with the given value.
326 void FillErrorBuffer(double new_error, ErrorTypes type);
327 // Helper generates a map from each current recoder_ code (ie softmax index)
328 // to the corresponding old_recoder code, or -1 if there isn't one.
329 std::vector<int> MapRecoder(const UNICHARSET &old_chset,
330 const UnicharCompress &old_recoder) const;
331
332 protected:
333 // Private version of InitCharSet above finishes the job after initializing
334 // the mgr_ data member.
335 void InitCharSet();
336 // Helper computes and sets the null_char_.
337 void SetNullChar();
338
339 // Factored sub-constructor sets up reasonable default values.
340 void EmptyConstructor();
341
342 // Outputs the string and periodically displays the given network inputs
343 // as an image in the given window, and the corresponding labels at the
344 // corresponding x_starts.
345 // Returns false if the truth string is empty.
346 bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata,
347 const NetworkIO &fwd_outputs,
348 const std::vector<int> &truth_labels,
349 const NetworkIO &outputs);
350 // Displays the network targets as line a line graph.
351 void DisplayTargets(const NetworkIO &targets, const char *window_name,
352 ScrollView **window);
353
354 // Builds a no-compromises target where the first positions should be the
355 // truth labels and the rest is padded with the null_char_.
356 bool ComputeTextTargets(const NetworkIO &outputs,
357 const std::vector<int> &truth_labels,
358 NetworkIO *targets);
359
360 // Builds a target using standard CTC. truth_labels should be pre-padded with
361 // nulls wherever desired. They don't have to be between all labels.
362 // outputs is input-output, as it gets clipped to minimum probability.
363 bool ComputeCTCTargets(const std::vector<int> &truth_labels,
364 NetworkIO *outputs, NetworkIO *targets);
365
366 // Computes network errors, and stores the results in the rolling buffers,
367 // along with the supplied text_error.
368 // Returns the delta error of the current sample (not running average.)
369 double ComputeErrorRates(const NetworkIO &deltas, double char_error,
370 double word_error);
371
372 // Computes the network activation RMS error rate.
373 double ComputeRMSError(const NetworkIO &deltas);
374
375 // Computes network activation winner error rate. (Number of values that are
376 // in error by >= 0.5 divided by number of time-steps.) More closely related
377 // to final character error than RMS, but still directly calculable from
378 // just the deltas. Because of the binary nature of the targets, zero winner
379 // error is a sufficient but not necessary condition for zero char error.
380 double ComputeWinnerError(const NetworkIO &deltas);
381
382 // Computes a very simple bag of chars char error rate.
383 double ComputeCharError(const std::vector<int> &truth_str,
384 const std::vector<int> &ocr_str);
385 // Computes a very simple bag of words word recall error rate.
386 // NOTE that this is destructive on both input strings.
387 double ComputeWordError(std::string *truth_str, std::string *ocr_str);
388
389 // Updates the error buffer and corresponding mean of the given type with
390 // the new_error.
391 void UpdateErrorBuffer(double new_error, ErrorTypes type);
392
393 // Rolls error buffers and reports the current means.
394 void RollErrorBuffers();
395
396 // Given that error_rate is either a new min or max, updates the best/worst
397 // error rates, and record of progress.
398 std::string UpdateErrorGraph(int iteration, double error_rate,
399 const std::vector<char> &model_data,
400 const TestCallback &tester);
401
402 protected:
403 #ifndef GRAPHICS_DISABLED
404 // Alignment display window.
405 ScrollView *align_win_;
406 // CTC target display window.
407 ScrollView *target_win_;
408 // CTC output display window.
409 ScrollView *ctc_win_;
410 // Reconstructed image window.
411 ScrollView *recon_win_;
412 #endif
413 // How often to display a debug image.
414 int debug_interval_;
415 // Iteration at which the last checkpoint was dumped.
416 int checkpoint_iteration_;
417 // Basename of files to save best models to.
418 std::string model_base_;
419 // Checkpoint filename.
420 std::string checkpoint_name_;
421 // Training data.
422 bool randomly_rotate_;
423 DocumentCache training_data_;
424 // Name to use when saving best_trainer_.
425 std::string best_model_name_;
426 // Number of available training stages.
427 int num_training_stages_;
428
429 // ===Serialized data to ensure that a restart produces the same results.===
430 // These members are only serialized when serialize_amount != LIGHT.
431 // Best error rate so far.
432 double best_error_rate_;
433 // Snapshot of all error rates at best_iteration_.
434 double best_error_rates_[ET_COUNT];
435 // Iteration of best_error_rate_.
436 int best_iteration_;
437 // Worst error rate since best_error_rate_.
438 double worst_error_rate_;
439 // Snapshot of all error rates at worst_iteration_.
440 double worst_error_rates_[ET_COUNT];
441 // Iteration of worst_error_rate_.
442 int worst_iteration_;
443 // Iteration at which the process will be thought stalled.
444 int stall_iteration_;
445 // Saved recognition models for computing test error for graph points.
446 std::vector<char> best_model_data_;
447 std::vector<char> worst_model_data_;
448 // Saved trainer for reverting back to last known best.
449 std::vector<char> best_trainer_;
450 // A subsidiary trainer running with a different learning rate until either
451 // *this or sub_trainer_ hits a new best.
452 std::unique_ptr<LSTMTrainer> sub_trainer_;
453 // Error rate at which last best model was dumped.
454 float error_rate_of_last_saved_best_;
455 // Current stage of training.
456 int training_stage_;
457 // History of best error rate against iteration. Used for computing the
458 // number of steps to each 2% improvement.
459 std::vector<double> best_error_history_;
460 std::vector<int32_t> best_error_iterations_;
461 // Number of iterations since the best_error_rate_ was 2% more than it is now.
462 int32_t improvement_steps_;
463 // Number of iterations that yielded a non-zero delta error and thus provided
464 // significant learning. learning_iteration_ <= training_iteration_.
465 // learning_iteration_ is used to measure rate of learning progress.
466 int learning_iteration_;
467 // Saved value of sample_iteration_ before looking for the next sample.
468 int prev_sample_iteration_;
469 // How often to include a PERFECT training sample in backprop.
470 // A PERFECT training sample is used if the current
471 // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
472 // so with perfect_delay_ == 0, all samples are used, and with
473 // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
474 int perfect_delay_;
475 // Value of training_iteration_ at which the last PERFECT training sample
476 // was used in back prop.
477 int last_perfect_training_iteration_;
478 // Rolling buffers storing recent training errors are indexed by
479 // training_iteration % kRollingBufferSize_.
480 static const int kRollingBufferSize_ = 1000;
481 std::vector<double> error_buffers_[ET_COUNT];
482 // Rounded mean percent trailing training errors in the buffers.
483 double error_rates_[ET_COUNT]; // RMS training error.
484 // Traineddata file with optional dawgs + UNICHARSET and recoder.
485 TessdataManager mgr_;
486 };
487
488 } // namespace tesseract.
489
490 #endif // TESSERACT_LSTM_LSTMTRAINER_H_