diff mupdf-source/thirdparty/tesseract/src/training/unicharset/lstmtrainer.h @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: no version number in the expanded directory now.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mupdf-source/thirdparty/tesseract/src/training/unicharset/lstmtrainer.h	Mon Sep 15 11:43:07 2025 +0200
@@ -0,0 +1,490 @@
+///////////////////////////////////////////////////////////////////////
+// File:        lstmtrainer.h
+// Description: Top-level line trainer class for LSTM-based networks.
+// Author:      Ray Smith
+//
+// (C) Copyright 2013, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
+#define TESSERACT_LSTM_LSTMTRAINER_H_
+
+#include "export.h"
+
+#include "imagedata.h" // for DocumentCache
+#include "lstmrecognizer.h"
+#include "rect.h"
+
+#include <functional> // for std::function
+#include <sstream>    // for std::stringstream
+
+namespace tesseract {
+
+class LSTM;
+class LSTMTester;
+class LSTMTrainer;
+class Parallel;
+class Reversed;
+class Softmax;
+class Series;
+
+// Enum for the types of errors that are counted.
+enum ErrorTypes {
+  ET_RMS,         // RMS activation error.
+  ET_DELTA,       // Number of big errors in deltas.
+  ET_WORD_RECERR, // Output text string word recall error.
+  ET_CHAR_ERROR,  // Output text string total char error.
+  ET_SKIP_RATIO,  // Fraction of samples skipped.
+  ET_COUNT        // For array sizing.
+};
+
+// Enum for the trainability_ flags.
+enum Trainability {
+  TRAINABLE,        // Non-zero delta error.
+  PERFECT,          // Zero delta error.
+  UNENCODABLE,      // Not trainable due to coding/alignment trouble.
+  HI_PRECISION_ERR, // Hi confidence disagreement.
+  NOT_BOXED,        // Early in training and has no character boxes.
+};
+
+// Enum to define the amount of data to get serialized.
+enum SerializeAmount {
+  LIGHT,           // Minimal data for remote training.
+  NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
+  FULL,            // All data including best_trainer_.
+};
+
+// Enum to indicate how the sub_trainer_ training went.
+enum SubTrainerResult {
+  STR_NONE,    // Did nothing as not good enough.
+  STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
+  STR_REPLACED // Subtrainer replaced *this.
+};
+
+class LSTMTrainer;
+// Function to compute and record error rates on some external test set(s).
+// Args are: iteration, mean errors, model, training stage.
+// Returns a string containing logging information about the tests.
+using TestCallback = std::function<std::string(int, const double *,
+                                               const TessdataManager &, int)>;
+
+// Trainer class for LSTM networks. Most of the effort is in creating the
+// ideal target outputs from the transcription. A box file is used if it is
+// available, otherwise estimates of the char widths from the unicharset are
+// used to guide a DP search for the best fit to the transcription.
+class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
+public:
+  LSTMTrainer();
+  LSTMTrainer(const std::string &model_base,
+              const std::string &checkpoint_name,
+              int debug_interval, int64_t max_memory);
+  virtual ~LSTMTrainer();
+
+  // Tries to deserialize a trainer from the given file and silently returns
+  // false in case of failure. If old_traineddata is not null, then it is
+  // assumed that the character set is to be re-mapped from old_traineddata to
+  // the new, with consequent change in weight matrices etc.
+  bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata);
+
+  // Initializes the character set encode/decode mechanism directly from a
+  // previously setup traineddata containing dawgs, UNICHARSET and
+  // UnicharCompress. Note: Call before InitNetwork!
+  bool InitCharSet(const std::string &traineddata_path) {
+    bool success = mgr_.Init(traineddata_path.c_str());
+    if (success) {
+      InitCharSet();
+    }
+    return success;
+  }
+  void InitCharSet(const TessdataManager &mgr) {
+    mgr_ = mgr;
+    InitCharSet();
+  }
+
+  // Initializes the trainer with a network_spec in the network description
+  // net_flags control network behavior according to the NetworkFlags enum.
+  // There isn't really much difference between them - only where the effects
+  // are implemented.
+  // For other args see NetworkBuilder::InitNetwork.
+  // Note: Be sure to call InitCharSet before InitNetwork!
+  bool InitNetwork(const char *network_spec, int append_index, int net_flags,
+                   float weight_range, float learning_rate, float momentum,
+                   float adam_beta);
+  // Resets all the iteration counters for fine tuning or training a head,
+  // where we want the error reporting to reset.
+  void InitIterations();
+
+  // Accessors.
+  double ActivationError() const {
+    return error_rates_[ET_DELTA];
+  }
+  double CharError() const {
+    return error_rates_[ET_CHAR_ERROR];
+  }
+  const double *error_rates() const {
+    return error_rates_;
+  }
+  double best_error_rate() const {
+    return best_error_rate_;
+  }
+  int best_iteration() const {
+    return best_iteration_;
+  }
+  int learning_iteration() const {
+    return learning_iteration_;
+  }
+  int32_t improvement_steps() const {
+    return improvement_steps_;
+  }
+  void set_perfect_delay(int delay) {
+    perfect_delay_ = delay;
+  }
+  const std::vector<char> &best_trainer() const {
+    return best_trainer_;
+  }
+  // Returns the error that was just calculated by PrepareForBackward.
+  double NewSingleError(ErrorTypes type) const {
+    return error_buffers_[type][training_iteration() % kRollingBufferSize_];
+  }
+  // Returns the error that was just calculated by TrainOnLine. Since
+  // TrainOnLine rolls the error buffers, this is one further back than
+  // NewSingleError.
+  double LastSingleError(ErrorTypes type) const {
+    return error_buffers_[type]
+                         [(training_iteration() + kRollingBufferSize_ - 1) %
+                          kRollingBufferSize_];
+  }
+  const DocumentCache &training_data() const {
+    return training_data_;
+  }
+  DocumentCache *mutable_training_data() {
+    return &training_data_;
+  }
+
+  // If the training sample is usable, grid searches for the optimal
+  // dict_ratio/cert_offset, and returns the results in a string of space-
+  // separated triplets of ratio,offset=worderr.
+  Trainability GridSearchDictParams(
+      const ImageData *trainingdata, int iteration, double min_dict_ratio,
+      double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
+      double cert_offset_step, double max_cert_offset, std::string &results);
+
+  // Provides output on the distribution of weight values.
+  void DebugNetwork();
+
+  // Loads a set of lstmf files that were created using the lstm.train config to
+  // tesseract into memory ready for training. Returns false if nothing was
+  // loaded.
+  bool LoadAllTrainingData(const std::vector<std::string> &filenames,
+                           CachingStrategy cache_strategy,
+                           bool randomly_rotate);
+
+  // Keeps track of best and locally worst error rate, using internally computed
+  // values. See MaintainCheckpointsSpecific for more detail.
+  bool MaintainCheckpoints(const TestCallback &tester, std::stringstream &log_msg);
+  // Keeps track of best and locally worst error_rate (whatever it is) and
+  // launches tests using rec_model, when a new min or max is reached.
+  // Writes checkpoints using train_model at appropriate times and builds and
+  // returns a log message to indicate progress. Returns false if nothing
+  // interesting happened.
+  bool MaintainCheckpointsSpecific(int iteration,
+                                   const std::vector<char> *train_model,
+                                   const std::vector<char> *rec_model,
+                                   TestCallback tester, std::stringstream &log_msg);
+  // Builds a progress message with current error rates.
+  void PrepareLogMsg(std::stringstream &log_msg) const;
+  // Appends <intro_str> iteration learning_iteration()/training_iteration()/
+  // sample_iteration() to the log_msg.
+  void LogIterations(const char *intro_str, std::stringstream &log_msg) const;
+
+  // TODO(rays) Add curriculum learning.
+  // Returns true and increments the training_stage_ if the error rate has just
+  // passed through the given threshold for the first time.
+  bool TransitionTrainingStage(float error_threshold);
+  // Returns the current training stage.
+  int CurrentTrainingStage() const {
+    return training_stage_;
+  }
+
+  // Writes to the given file. Returns false in case of error.
+  bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr,
+                 TFile *fp) const;
+  // Reads from the given file. Returns false in case of error.
+  bool DeSerialize(const TessdataManager *mgr, TFile *fp);
+
+  // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
+  // learning rates (by scaling reduction, or layer specific, according to
+  // NF_LAYER_SPECIFIC_LR).
+  void StartSubtrainer(std::stringstream &log_msg);
+  // While the sub_trainer_ is behind the current training iteration and its
+  // training error is at least kSubTrainerMarginFraction better than the
+  // current training error, trains the sub_trainer_, and returns STR_UPDATED if
+  // it did anything. If it catches up, and has a better error rate than the
+  // current best, as well as a margin over the current error rate, then the
+  // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
+  // returned. STR_NONE is returned if the subtrainer wasn't good enough to
+  // receive any training iterations.
+  SubTrainerResult UpdateSubtrainer(std::stringstream &log_msg);
+  // Reduces network learning rates, either for everything, or for layers
+  // independently, according to NF_LAYER_SPECIFIC_LR.
+  void ReduceLearningRates(LSTMTrainer *samples_trainer, std::stringstream &log_msg);
+  // Considers reducing the learning rate independently for each layer down by
+  // factor(<1), or leaving it the same, by double-training the given number of
+  // samples and minimizing the amount of changing of sign of weight updates.
+  // Even if it looks like all weights should remain the same, an adjustment
+  // will be made to guarantee a different result when reverting to an old best.
+  // Returns the number of layer learning rates that were reduced.
+  int ReduceLayerLearningRates(TFloat factor, int num_samples,
+                               LSTMTrainer *samples_trainer);
+
+  // Converts the string to integer class labels, with appropriate null_char_s
+  // in between if not in SimpleTextOutput mode. Returns false on failure.
+  bool EncodeString(const std::string &str, std::vector<int> *labels) const {
+    return EncodeString(str, GetUnicharset(),
+                        IsRecoding() ? &recoder_ : nullptr, SimpleTextOutput(),
+                        null_char_, labels);
+  }
+  // Static version operates on supplied unicharset, encoder, simple_text.
+  static bool EncodeString(const std::string &str, const UNICHARSET &unicharset,
+                           const UnicharCompress *recoder, bool simple_text,
+                           int null_char, std::vector<int> *labels);
+
+  // Performs forward-backward on the given trainingdata.
+  // Returns the sample that was used or nullptr if the next sample was deemed
+  // unusable. samples_trainer could be this or an alternative trainer that
+  // holds the training samples.
+  const ImageData *TrainOnLine(LSTMTrainer *samples_trainer, bool batch) {
+    int sample_index = sample_iteration();
+    const ImageData *image =
+        samples_trainer->training_data_.GetPageBySerial(sample_index);
+    if (image != nullptr) {
+      Trainability trainable = TrainOnLine(image, batch);
+      if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
+        return nullptr; // Sample was unusable.
+      }
+    } else {
+      ++sample_iteration_;
+    }
+    return image;
+  }
+  Trainability TrainOnLine(const ImageData *trainingdata, bool batch);
+
+  // Prepares the ground truth, runs forward, and prepares the targets.
+  // Returns a Trainability enum to indicate the suitability of the sample.
+  Trainability PrepareForBackward(const ImageData *trainingdata,
+                                  NetworkIO *fwd_outputs, NetworkIO *targets);
+
+  // Writes the trainer to memory, so that the current training state can be
+  // restored.  *this must always be the master trainer that retains the only
+  // copy of the training data and language model. trainer is the model that is
+  // actually serialized.
+  bool SaveTrainingDump(SerializeAmount serialize_amount,
+                        const LSTMTrainer &trainer,
+                        std::vector<char> *data) const;
+
+  // Reads previously saved trainer from memory. *this must always be the
+  // master trainer that retains the only copy of the training data and
+  // language model. trainer is the model that is restored.
+  bool ReadTrainingDump(const std::vector<char> &data,
+                        LSTMTrainer &trainer) const {
+    if (data.empty()) {
+      return false;
+    }
+    return ReadSizedTrainingDump(&data[0], data.size(), trainer);
+  }
+  bool ReadSizedTrainingDump(const char *data, int size,
+                             LSTMTrainer &trainer) const {
+    return trainer.ReadLocalTrainingDump(&mgr_, data, size);
+  }
+  // Restores the model to *this.
+  bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data,
+                             int size);
+
+  // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
+  void SetupCheckpointInfo();
+
+  // Writes the full recognition traineddata to the given filename.
+  bool SaveTraineddata(const char *filename);
+
+  // Writes the recognizer to memory, so that it can be used for testing later.
+  void SaveRecognitionDump(std::vector<char> *data) const;
+
+  // Returns a suitable filename for a training dump, based on the model_base_,
+  // the iteration and the error rates.
+  std::string DumpFilename() const;
+
+  // Fills the whole error buffer of the given type with the given value.
+  void FillErrorBuffer(double new_error, ErrorTypes type);
+  // Helper generates a map from each current recoder_ code (ie softmax index)
+  // to the corresponding old_recoder code, or -1 if there isn't one.
+  std::vector<int> MapRecoder(const UNICHARSET &old_chset,
+                              const UnicharCompress &old_recoder) const;
+
+protected:
+  // Private version of InitCharSet above finishes the job after initializing
+  // the mgr_ data member.
+  void InitCharSet();
+  // Helper computes and sets the null_char_.
+  void SetNullChar();
+
+  // Factored sub-constructor sets up reasonable default values.
+  void EmptyConstructor();
+
+  // Outputs the string and periodically displays the given network inputs
+  // as an image in the given window, and the corresponding labels at the
+  // corresponding x_starts.
+  // Returns false if the truth string is empty.
+  bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata,
+                         const NetworkIO &fwd_outputs,
+                         const std::vector<int> &truth_labels,
+                         const NetworkIO &outputs);
+  // Displays the network targets as line a line graph.
+  void DisplayTargets(const NetworkIO &targets, const char *window_name,
+                      ScrollView **window);
+
+  // Builds a no-compromises target where the first positions should be the
+  // truth labels and the rest is padded with the null_char_.
+  bool ComputeTextTargets(const NetworkIO &outputs,
+                          const std::vector<int> &truth_labels,
+                          NetworkIO *targets);
+
+  // Builds a target using standard CTC. truth_labels should be pre-padded with
+  // nulls wherever desired. They don't have to be between all labels.
+  // outputs is input-output, as it gets clipped to minimum probability.
+  bool ComputeCTCTargets(const std::vector<int> &truth_labels,
+                         NetworkIO *outputs, NetworkIO *targets);
+
+  // Computes network errors, and stores the results in the rolling buffers,
+  // along with the supplied text_error.
+  // Returns the delta error of the current sample (not running average.)
+  double ComputeErrorRates(const NetworkIO &deltas, double char_error,
+                           double word_error);
+
+  // Computes the network activation RMS error rate.
+  double ComputeRMSError(const NetworkIO &deltas);
+
+  // Computes network activation winner error rate. (Number of values that are
+  // in error by >= 0.5 divided by number of time-steps.) More closely related
+  // to final character error than RMS, but still directly calculable from
+  // just the deltas. Because of the binary nature of the targets, zero winner
+  // error is a sufficient but not necessary condition for zero char error.
+  double ComputeWinnerError(const NetworkIO &deltas);
+
+  // Computes a very simple bag of chars char error rate.
+  double ComputeCharError(const std::vector<int> &truth_str,
+                          const std::vector<int> &ocr_str);
+  // Computes a very simple bag of words word recall error rate.
+  // NOTE that this is destructive on both input strings.
+  double ComputeWordError(std::string *truth_str, std::string *ocr_str);
+
+  // Updates the error buffer and corresponding mean of the given type with
+  // the new_error.
+  void UpdateErrorBuffer(double new_error, ErrorTypes type);
+
+  // Rolls error buffers and reports the current means.
+  void RollErrorBuffers();
+
+  // Given that error_rate is either a new min or max, updates the best/worst
+  // error rates, and record of progress.
+  std::string UpdateErrorGraph(int iteration, double error_rate,
+                               const std::vector<char> &model_data,
+                               const TestCallback &tester);
+
+protected:
+#ifndef GRAPHICS_DISABLED
+  // Alignment display window.
+  ScrollView *align_win_;
+  // CTC target display window.
+  ScrollView *target_win_;
+  // CTC output display window.
+  ScrollView *ctc_win_;
+  // Reconstructed image window.
+  ScrollView *recon_win_;
+#endif
+  // How often to display a debug image.
+  int debug_interval_;
+  // Iteration at which the last checkpoint was dumped.
+  int checkpoint_iteration_;
+  // Basename of files to save best models to.
+  std::string model_base_;
+  // Checkpoint filename.
+  std::string checkpoint_name_;
+  // Training data.
+  bool randomly_rotate_;
+  DocumentCache training_data_;
+  // Name to use when saving best_trainer_.
+  std::string best_model_name_;
+  // Number of available training stages.
+  int num_training_stages_;
+
+  // ===Serialized data to ensure that a restart produces the same results.===
+  // These members are only serialized when serialize_amount != LIGHT.
+  // Best error rate so far.
+  double best_error_rate_;
+  // Snapshot of all error rates at best_iteration_.
+  double best_error_rates_[ET_COUNT];
+  // Iteration of best_error_rate_.
+  int best_iteration_;
+  // Worst error rate since best_error_rate_.
+  double worst_error_rate_;
+  // Snapshot of all error rates at worst_iteration_.
+  double worst_error_rates_[ET_COUNT];
+  // Iteration of worst_error_rate_.
+  int worst_iteration_;
+  // Iteration at which the process will be thought stalled.
+  int stall_iteration_;
+  // Saved recognition models for computing test error for graph points.
+  std::vector<char> best_model_data_;
+  std::vector<char> worst_model_data_;
+  // Saved trainer for reverting back to last known best.
+  std::vector<char> best_trainer_;
+  // A subsidiary trainer running with a different learning rate until either
+  // *this or sub_trainer_ hits a new best.
+  std::unique_ptr<LSTMTrainer> sub_trainer_;
+  // Error rate at which last best model was dumped.
+  float error_rate_of_last_saved_best_;
+  // Current stage of training.
+  int training_stage_;
+  // History of best error rate against iteration. Used for computing the
+  // number of steps to each 2% improvement.
+  std::vector<double> best_error_history_;
+  std::vector<int32_t> best_error_iterations_;
+  // Number of iterations since the best_error_rate_ was 2% more than it is now.
+  int32_t improvement_steps_;
+  // Number of iterations that yielded a non-zero delta error and thus provided
+  // significant learning. learning_iteration_ <= training_iteration_.
+  // learning_iteration_ is used to measure rate of learning progress.
+  int learning_iteration_;
+  // Saved value of sample_iteration_ before looking for the next sample.
+  int prev_sample_iteration_;
+  // How often to include a PERFECT training sample in backprop.
+  // A PERFECT training sample is used if the current
+  // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
+  // so with perfect_delay_ == 0, all samples are used, and with
+  // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
+  int perfect_delay_;
+  // Value of training_iteration_ at which the last PERFECT training sample
+  // was used in back prop.
+  int last_perfect_training_iteration_;
+  // Rolling buffers storing recent training errors are indexed by
+  // training_iteration % kRollingBufferSize_.
+  static const int kRollingBufferSize_ = 1000;
+  std::vector<double> error_buffers_[ET_COUNT];
+  // Rounded mean percent trailing training errors in the buffers.
+  double error_rates_[ET_COUNT]; // RMS training error.
+  // Traineddata file with optional dawgs + UNICHARSET and recoder.
+  TessdataManager mgr_;
+};
+
+} // namespace tesseract.
+
+#endif // TESSERACT_LSTM_LSTMTRAINER_H_