Mercurial > hgrepos > Python2 > PyMuPDF
comparison mupdf-source/thirdparty/tesseract/src/training/unicharset/lstmtrainer.h @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: no version number in the expanded directory now.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 1:1d09e1dec1d9 | 2:b50eed0cc0ef |
|---|---|
| 1 /////////////////////////////////////////////////////////////////////// | |
| 2 // File: lstmtrainer.h | |
| 3 // Description: Top-level line trainer class for LSTM-based networks. | |
| 4 // Author: Ray Smith | |
| 5 // | |
| 6 // (C) Copyright 2013, Google Inc. | |
| 7 // Licensed under the Apache License, Version 2.0 (the "License"); | |
| 8 // you may not use this file except in compliance with the License. | |
| 9 // You may obtain a copy of the License at | |
| 10 // http://www.apache.org/licenses/LICENSE-2.0 | |
| 11 // Unless required by applicable law or agreed to in writing, software | |
| 12 // distributed under the License is distributed on an "AS IS" BASIS, | |
| 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 14 // See the License for the specific language governing permissions and | |
| 15 // limitations under the License. | |
| 16 /////////////////////////////////////////////////////////////////////// | |
| 17 | |
| 18 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_ | |
| 19 #define TESSERACT_LSTM_LSTMTRAINER_H_ | |
| 20 | |
| 21 #include "export.h" | |
| 22 | |
| 23 #include "imagedata.h" // for DocumentCache | |
| 24 #include "lstmrecognizer.h" | |
| 25 #include "rect.h" | |
| 26 | |
| 27 #include <functional> // for std::function | |
| 28 #include <sstream> // for std::stringstream | |
| 29 | |
| 30 namespace tesseract { | |
| 31 | |
// Forward declarations (kept in alphabetical order).
class LSTM;
class LSTMTester;
class LSTMTrainer;
class Parallel;
class Reversed;
class Series;
class Softmax;
| 39 | |
// Kinds of training error that the trainer counts. The enumerators double as
// array indices (e.g. error_rates_[ET_DELTA]), and ET_COUNT is the number of
// real error kinds, used to size such arrays.
enum ErrorTypes {
  ET_RMS = 0,         // RMS activation error.
  ET_DELTA = 1,       // Number of big errors in deltas.
  ET_WORD_RECERR = 2, // Output text string word recall error.
  ET_CHAR_ERROR = 3,  // Output text string total char error.
  ET_SKIP_RATIO = 4,  // Fraction of samples skipped.
  ET_COUNT = 5        // For array sizing.
};
| 49 | |
// Classification of how usable a single training sample is (the
// trainability_ flags). PrepareForBackward reports one of these per sample;
// UNENCODABLE and NOT_BOXED samples are treated as unusable by TrainOnLine.
enum Trainability {
  TRAINABLE = 0,        // Non-zero delta error.
  PERFECT = 1,          // Zero delta error.
  UNENCODABLE = 2,      // Not trainable due to coding/alignment trouble.
  HI_PRECISION_ERR = 3, // High-confidence disagreement.
  NOT_BOXED = 4,        // Early in training and has no character boxes.
};
| 58 | |
// Selects how much trainer state is written when serializing (see the
// serialize_amount parameter of Serialize and SaveTrainingDump).
enum SerializeAmount {
  LIGHT = 0,           // Minimal data for remote training.
  NO_BEST_TRAINER = 1, // Save an empty vector in place of best_trainer_.
  FULL = 2,            // All data including best_trainer_.
};
| 65 | |
// Outcome reported by UpdateSubtrainer for one round of sub_trainer_ work.
enum SubTrainerResult {
  STR_NONE = 0,     // Did nothing as not good enough.
  STR_UPDATED = 1,  // Subtrainer was updated, but didn't replace *this.
  STR_REPLACED = 2  // Subtrainer replaced *this.
};
| 72 | |
| 73 class LSTMTrainer; | |
| 74 // Function to compute and record error rates on some external test set(s). | |
| 75 // Args are: iteration, mean errors, model, training stage. | |
| 76 // Returns a string containing logging information about the tests. | |
| 77 using TestCallback = std::function<std::string(int, const double *, | |
| 78 const TessdataManager &, int)>; | |
| 79 | |
| 80 // Trainer class for LSTM networks. Most of the effort is in creating the | |
| 81 // ideal target outputs from the transcription. A box file is used if it is | |
| 82 // available, otherwise estimates of the char widths from the unicharset are | |
| 83 // used to guide a DP search for the best fit to the transcription. | |
| 84 class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer { | |
| 85 public: | |
| 86 LSTMTrainer(); | |
| 87 LSTMTrainer(const std::string &model_base, | |
| 88 const std::string &checkpoint_name, | |
| 89 int debug_interval, int64_t max_memory); | |
| 90 virtual ~LSTMTrainer(); | |
| 91 | |
| 92 // Tries to deserialize a trainer from the given file and silently returns | |
| 93 // false in case of failure. If old_traineddata is not null, then it is | |
| 94 // assumed that the character set is to be re-mapped from old_traineddata to | |
| 95 // the new, with consequent change in weight matrices etc. | |
| 96 bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata); | |
| 97 | |
| 98 // Initializes the character set encode/decode mechanism directly from a | |
| 99 // previously setup traineddata containing dawgs, UNICHARSET and | |
| 100 // UnicharCompress. Note: Call before InitNetwork! | |
| 101 bool InitCharSet(const std::string &traineddata_path) { | |
| 102 bool success = mgr_.Init(traineddata_path.c_str()); | |
| 103 if (success) { | |
| 104 InitCharSet(); | |
| 105 } | |
| 106 return success; | |
| 107 } | |
| 108 void InitCharSet(const TessdataManager &mgr) { | |
| 109 mgr_ = mgr; | |
| 110 InitCharSet(); | |
| 111 } | |
| 112 | |
| 113 // Initializes the trainer with a network_spec in the network description | |
| 114 // net_flags control network behavior according to the NetworkFlags enum. | |
| 115 // There isn't really much difference between them - only where the effects | |
| 116 // are implemented. | |
| 117 // For other args see NetworkBuilder::InitNetwork. | |
| 118 // Note: Be sure to call InitCharSet before InitNetwork! | |
| 119 bool InitNetwork(const char *network_spec, int append_index, int net_flags, | |
| 120 float weight_range, float learning_rate, float momentum, | |
| 121 float adam_beta); | |
| 122 // Resets all the iteration counters for fine tuning or training a head, | |
| 123 // where we want the error reporting to reset. | |
| 124 void InitIterations(); | |
| 125 | |
| 126 // Accessors. | |
| 127 double ActivationError() const { | |
| 128 return error_rates_[ET_DELTA]; | |
| 129 } | |
| 130 double CharError() const { | |
| 131 return error_rates_[ET_CHAR_ERROR]; | |
| 132 } | |
| 133 const double *error_rates() const { | |
| 134 return error_rates_; | |
| 135 } | |
| 136 double best_error_rate() const { | |
| 137 return best_error_rate_; | |
| 138 } | |
| 139 int best_iteration() const { | |
| 140 return best_iteration_; | |
| 141 } | |
| 142 int learning_iteration() const { | |
| 143 return learning_iteration_; | |
| 144 } | |
| 145 int32_t improvement_steps() const { | |
| 146 return improvement_steps_; | |
| 147 } | |
| 148 void set_perfect_delay(int delay) { | |
| 149 perfect_delay_ = delay; | |
| 150 } | |
| 151 const std::vector<char> &best_trainer() const { | |
| 152 return best_trainer_; | |
| 153 } | |
| 154 // Returns the error that was just calculated by PrepareForBackward. | |
| 155 double NewSingleError(ErrorTypes type) const { | |
| 156 return error_buffers_[type][training_iteration() % kRollingBufferSize_]; | |
| 157 } | |
| 158 // Returns the error that was just calculated by TrainOnLine. Since | |
| 159 // TrainOnLine rolls the error buffers, this is one further back than | |
| 160 // NewSingleError. | |
| 161 double LastSingleError(ErrorTypes type) const { | |
| 162 return error_buffers_[type] | |
| 163 [(training_iteration() + kRollingBufferSize_ - 1) % | |
| 164 kRollingBufferSize_]; | |
| 165 } | |
| 166 const DocumentCache &training_data() const { | |
| 167 return training_data_; | |
| 168 } | |
| 169 DocumentCache *mutable_training_data() { | |
| 170 return &training_data_; | |
| 171 } | |
| 172 | |
| 173 // If the training sample is usable, grid searches for the optimal | |
| 174 // dict_ratio/cert_offset, and returns the results in a string of space- | |
| 175 // separated triplets of ratio,offset=worderr. | |
| 176 Trainability GridSearchDictParams( | |
| 177 const ImageData *trainingdata, int iteration, double min_dict_ratio, | |
| 178 double dict_ratio_step, double max_dict_ratio, double min_cert_offset, | |
| 179 double cert_offset_step, double max_cert_offset, std::string &results); | |
| 180 | |
| 181 // Provides output on the distribution of weight values. | |
| 182 void DebugNetwork(); | |
| 183 | |
| 184 // Loads a set of lstmf files that were created using the lstm.train config to | |
| 185 // tesseract into memory ready for training. Returns false if nothing was | |
| 186 // loaded. | |
| 187 bool LoadAllTrainingData(const std::vector<std::string> &filenames, | |
| 188 CachingStrategy cache_strategy, | |
| 189 bool randomly_rotate); | |
| 190 | |
| 191 // Keeps track of best and locally worst error rate, using internally computed | |
| 192 // values. See MaintainCheckpointsSpecific for more detail. | |
| 193 bool MaintainCheckpoints(const TestCallback &tester, std::stringstream &log_msg); | |
| 194 // Keeps track of best and locally worst error_rate (whatever it is) and | |
| 195 // launches tests using rec_model, when a new min or max is reached. | |
| 196 // Writes checkpoints using train_model at appropriate times and builds and | |
| 197 // returns a log message to indicate progress. Returns false if nothing | |
| 198 // interesting happened. | |
| 199 bool MaintainCheckpointsSpecific(int iteration, | |
| 200 const std::vector<char> *train_model, | |
| 201 const std::vector<char> *rec_model, | |
| 202 TestCallback tester, std::stringstream &log_msg); | |
| 203 // Builds a progress message with current error rates. | |
| 204 void PrepareLogMsg(std::stringstream &log_msg) const; | |
| 205 // Appends <intro_str> iteration learning_iteration()/training_iteration()/ | |
| 206 // sample_iteration() to the log_msg. | |
| 207 void LogIterations(const char *intro_str, std::stringstream &log_msg) const; | |
| 208 | |
| 209 // TODO(rays) Add curriculum learning. | |
| 210 // Returns true and increments the training_stage_ if the error rate has just | |
| 211 // passed through the given threshold for the first time. | |
| 212 bool TransitionTrainingStage(float error_threshold); | |
| 213 // Returns the current training stage. | |
| 214 int CurrentTrainingStage() const { | |
| 215 return training_stage_; | |
| 216 } | |
| 217 | |
| 218 // Writes to the given file. Returns false in case of error. | |
| 219 bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, | |
| 220 TFile *fp) const; | |
| 221 // Reads from the given file. Returns false in case of error. | |
| 222 bool DeSerialize(const TessdataManager *mgr, TFile *fp); | |
| 223 | |
| 224 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the | |
| 225 // learning rates (by scaling reduction, or layer specific, according to | |
| 226 // NF_LAYER_SPECIFIC_LR). | |
| 227 void StartSubtrainer(std::stringstream &log_msg); | |
| 228 // While the sub_trainer_ is behind the current training iteration and its | |
| 229 // training error is at least kSubTrainerMarginFraction better than the | |
| 230 // current training error, trains the sub_trainer_, and returns STR_UPDATED if | |
| 231 // it did anything. If it catches up, and has a better error rate than the | |
| 232 // current best, as well as a margin over the current error rate, then the | |
| 233 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is | |
| 234 // returned. STR_NONE is returned if the subtrainer wasn't good enough to | |
| 235 // receive any training iterations. | |
| 236 SubTrainerResult UpdateSubtrainer(std::stringstream &log_msg); | |
| 237 // Reduces network learning rates, either for everything, or for layers | |
| 238 // independently, according to NF_LAYER_SPECIFIC_LR. | |
| 239 void ReduceLearningRates(LSTMTrainer *samples_trainer, std::stringstream &log_msg); | |
| 240 // Considers reducing the learning rate independently for each layer down by | |
| 241 // factor(<1), or leaving it the same, by double-training the given number of | |
| 242 // samples and minimizing the amount of changing of sign of weight updates. | |
| 243 // Even if it looks like all weights should remain the same, an adjustment | |
| 244 // will be made to guarantee a different result when reverting to an old best. | |
| 245 // Returns the number of layer learning rates that were reduced. | |
| 246 int ReduceLayerLearningRates(TFloat factor, int num_samples, | |
| 247 LSTMTrainer *samples_trainer); | |
| 248 | |
| 249 // Converts the string to integer class labels, with appropriate null_char_s | |
| 250 // in between if not in SimpleTextOutput mode. Returns false on failure. | |
| 251 bool EncodeString(const std::string &str, std::vector<int> *labels) const { | |
| 252 return EncodeString(str, GetUnicharset(), | |
| 253 IsRecoding() ? &recoder_ : nullptr, SimpleTextOutput(), | |
| 254 null_char_, labels); | |
| 255 } | |
| 256 // Static version operates on supplied unicharset, encoder, simple_text. | |
| 257 static bool EncodeString(const std::string &str, const UNICHARSET &unicharset, | |
| 258 const UnicharCompress *recoder, bool simple_text, | |
| 259 int null_char, std::vector<int> *labels); | |
| 260 | |
| 261 // Performs forward-backward on the given trainingdata. | |
| 262 // Returns the sample that was used or nullptr if the next sample was deemed | |
| 263 // unusable. samples_trainer could be this or an alternative trainer that | |
| 264 // holds the training samples. | |
| 265 const ImageData *TrainOnLine(LSTMTrainer *samples_trainer, bool batch) { | |
| 266 int sample_index = sample_iteration(); | |
| 267 const ImageData *image = | |
| 268 samples_trainer->training_data_.GetPageBySerial(sample_index); | |
| 269 if (image != nullptr) { | |
| 270 Trainability trainable = TrainOnLine(image, batch); | |
| 271 if (trainable == UNENCODABLE || trainable == NOT_BOXED) { | |
| 272 return nullptr; // Sample was unusable. | |
| 273 } | |
| 274 } else { | |
| 275 ++sample_iteration_; | |
| 276 } | |
| 277 return image; | |
| 278 } | |
| 279 Trainability TrainOnLine(const ImageData *trainingdata, bool batch); | |
| 280 | |
| 281 // Prepares the ground truth, runs forward, and prepares the targets. | |
| 282 // Returns a Trainability enum to indicate the suitability of the sample. | |
| 283 Trainability PrepareForBackward(const ImageData *trainingdata, | |
| 284 NetworkIO *fwd_outputs, NetworkIO *targets); | |
| 285 | |
| 286 // Writes the trainer to memory, so that the current training state can be | |
| 287 // restored. *this must always be the master trainer that retains the only | |
| 288 // copy of the training data and language model. trainer is the model that is | |
| 289 // actually serialized. | |
| 290 bool SaveTrainingDump(SerializeAmount serialize_amount, | |
| 291 const LSTMTrainer &trainer, | |
| 292 std::vector<char> *data) const; | |
| 293 | |
| 294 // Reads previously saved trainer from memory. *this must always be the | |
| 295 // master trainer that retains the only copy of the training data and | |
| 296 // language model. trainer is the model that is restored. | |
| 297 bool ReadTrainingDump(const std::vector<char> &data, | |
| 298 LSTMTrainer &trainer) const { | |
| 299 if (data.empty()) { | |
| 300 return false; | |
| 301 } | |
| 302 return ReadSizedTrainingDump(&data[0], data.size(), trainer); | |
| 303 } | |
| 304 bool ReadSizedTrainingDump(const char *data, int size, | |
| 305 LSTMTrainer &trainer) const { | |
| 306 return trainer.ReadLocalTrainingDump(&mgr_, data, size); | |
| 307 } | |
| 308 // Restores the model to *this. | |
| 309 bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, | |
| 310 int size); | |
| 311 | |
| 312 // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump. | |
| 313 void SetupCheckpointInfo(); | |
| 314 | |
| 315 // Writes the full recognition traineddata to the given filename. | |
| 316 bool SaveTraineddata(const char *filename); | |
| 317 | |
| 318 // Writes the recognizer to memory, so that it can be used for testing later. | |
| 319 void SaveRecognitionDump(std::vector<char> *data) const; | |
| 320 | |
| 321 // Returns a suitable filename for a training dump, based on the model_base_, | |
| 322 // the iteration and the error rates. | |
| 323 std::string DumpFilename() const; | |
| 324 | |
| 325 // Fills the whole error buffer of the given type with the given value. | |
| 326 void FillErrorBuffer(double new_error, ErrorTypes type); | |
| 327 // Helper generates a map from each current recoder_ code (ie softmax index) | |
| 328 // to the corresponding old_recoder code, or -1 if there isn't one. | |
| 329 std::vector<int> MapRecoder(const UNICHARSET &old_chset, | |
| 330 const UnicharCompress &old_recoder) const; | |
| 331 | |
| 332 protected: | |
| 333 // Private version of InitCharSet above finishes the job after initializing | |
| 334 // the mgr_ data member. | |
| 335 void InitCharSet(); | |
| 336 // Helper computes and sets the null_char_. | |
| 337 void SetNullChar(); | |
| 338 | |
| 339 // Factored sub-constructor sets up reasonable default values. | |
| 340 void EmptyConstructor(); | |
| 341 | |
| 342 // Outputs the string and periodically displays the given network inputs | |
| 343 // as an image in the given window, and the corresponding labels at the | |
| 344 // corresponding x_starts. | |
| 345 // Returns false if the truth string is empty. | |
| 346 bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, | |
| 347 const NetworkIO &fwd_outputs, | |
| 348 const std::vector<int> &truth_labels, | |
| 349 const NetworkIO &outputs); | |
| 350 // Displays the network targets as line a line graph. | |
| 351 void DisplayTargets(const NetworkIO &targets, const char *window_name, | |
| 352 ScrollView **window); | |
| 353 | |
| 354 // Builds a no-compromises target where the first positions should be the | |
| 355 // truth labels and the rest is padded with the null_char_. | |
| 356 bool ComputeTextTargets(const NetworkIO &outputs, | |
| 357 const std::vector<int> &truth_labels, | |
| 358 NetworkIO *targets); | |
| 359 | |
| 360 // Builds a target using standard CTC. truth_labels should be pre-padded with | |
| 361 // nulls wherever desired. They don't have to be between all labels. | |
| 362 // outputs is input-output, as it gets clipped to minimum probability. | |
| 363 bool ComputeCTCTargets(const std::vector<int> &truth_labels, | |
| 364 NetworkIO *outputs, NetworkIO *targets); | |
| 365 | |
| 366 // Computes network errors, and stores the results in the rolling buffers, | |
| 367 // along with the supplied text_error. | |
| 368 // Returns the delta error of the current sample (not running average.) | |
| 369 double ComputeErrorRates(const NetworkIO &deltas, double char_error, | |
| 370 double word_error); | |
| 371 | |
| 372 // Computes the network activation RMS error rate. | |
| 373 double ComputeRMSError(const NetworkIO &deltas); | |
| 374 | |
| 375 // Computes network activation winner error rate. (Number of values that are | |
| 376 // in error by >= 0.5 divided by number of time-steps.) More closely related | |
| 377 // to final character error than RMS, but still directly calculable from | |
| 378 // just the deltas. Because of the binary nature of the targets, zero winner | |
| 379 // error is a sufficient but not necessary condition for zero char error. | |
| 380 double ComputeWinnerError(const NetworkIO &deltas); | |
| 381 | |
| 382 // Computes a very simple bag of chars char error rate. | |
| 383 double ComputeCharError(const std::vector<int> &truth_str, | |
| 384 const std::vector<int> &ocr_str); | |
| 385 // Computes a very simple bag of words word recall error rate. | |
| 386 // NOTE that this is destructive on both input strings. | |
| 387 double ComputeWordError(std::string *truth_str, std::string *ocr_str); | |
| 388 | |
| 389 // Updates the error buffer and corresponding mean of the given type with | |
| 390 // the new_error. | |
| 391 void UpdateErrorBuffer(double new_error, ErrorTypes type); | |
| 392 | |
| 393 // Rolls error buffers and reports the current means. | |
| 394 void RollErrorBuffers(); | |
| 395 | |
| 396 // Given that error_rate is either a new min or max, updates the best/worst | |
| 397 // error rates, and record of progress. | |
| 398 std::string UpdateErrorGraph(int iteration, double error_rate, | |
| 399 const std::vector<char> &model_data, | |
| 400 const TestCallback &tester); | |
| 401 | |
| 402 protected: | |
| 403 #ifndef GRAPHICS_DISABLED | |
| 404 // Alignment display window. | |
| 405 ScrollView *align_win_; | |
| 406 // CTC target display window. | |
| 407 ScrollView *target_win_; | |
| 408 // CTC output display window. | |
| 409 ScrollView *ctc_win_; | |
| 410 // Reconstructed image window. | |
| 411 ScrollView *recon_win_; | |
| 412 #endif | |
| 413 // How often to display a debug image. | |
| 414 int debug_interval_; | |
| 415 // Iteration at which the last checkpoint was dumped. | |
| 416 int checkpoint_iteration_; | |
| 417 // Basename of files to save best models to. | |
| 418 std::string model_base_; | |
| 419 // Checkpoint filename. | |
| 420 std::string checkpoint_name_; | |
| 421 // Training data. | |
| 422 bool randomly_rotate_; | |
| 423 DocumentCache training_data_; | |
| 424 // Name to use when saving best_trainer_. | |
| 425 std::string best_model_name_; | |
| 426 // Number of available training stages. | |
| 427 int num_training_stages_; | |
| 428 | |
| 429 // ===Serialized data to ensure that a restart produces the same results.=== | |
| 430 // These members are only serialized when serialize_amount != LIGHT. | |
| 431 // Best error rate so far. | |
| 432 double best_error_rate_; | |
| 433 // Snapshot of all error rates at best_iteration_. | |
| 434 double best_error_rates_[ET_COUNT]; | |
| 435 // Iteration of best_error_rate_. | |
| 436 int best_iteration_; | |
| 437 // Worst error rate since best_error_rate_. | |
| 438 double worst_error_rate_; | |
| 439 // Snapshot of all error rates at worst_iteration_. | |
| 440 double worst_error_rates_[ET_COUNT]; | |
| 441 // Iteration of worst_error_rate_. | |
| 442 int worst_iteration_; | |
| 443 // Iteration at which the process will be thought stalled. | |
| 444 int stall_iteration_; | |
| 445 // Saved recognition models for computing test error for graph points. | |
| 446 std::vector<char> best_model_data_; | |
| 447 std::vector<char> worst_model_data_; | |
| 448 // Saved trainer for reverting back to last known best. | |
| 449 std::vector<char> best_trainer_; | |
| 450 // A subsidiary trainer running with a different learning rate until either | |
| 451 // *this or sub_trainer_ hits a new best. | |
| 452 std::unique_ptr<LSTMTrainer> sub_trainer_; | |
| 453 // Error rate at which last best model was dumped. | |
| 454 float error_rate_of_last_saved_best_; | |
| 455 // Current stage of training. | |
| 456 int training_stage_; | |
| 457 // History of best error rate against iteration. Used for computing the | |
| 458 // number of steps to each 2% improvement. | |
| 459 std::vector<double> best_error_history_; | |
| 460 std::vector<int32_t> best_error_iterations_; | |
| 461 // Number of iterations since the best_error_rate_ was 2% more than it is now. | |
| 462 int32_t improvement_steps_; | |
| 463 // Number of iterations that yielded a non-zero delta error and thus provided | |
| 464 // significant learning. learning_iteration_ <= training_iteration_. | |
| 465 // learning_iteration_ is used to measure rate of learning progress. | |
| 466 int learning_iteration_; | |
| 467 // Saved value of sample_iteration_ before looking for the next sample. | |
| 468 int prev_sample_iteration_; | |
| 469 // How often to include a PERFECT training sample in backprop. | |
| 470 // A PERFECT training sample is used if the current | |
| 471 // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_, | |
| 472 // so with perfect_delay_ == 0, all samples are used, and with | |
| 473 // perfect_delay_ == 4, at most 1 in 5 samples will be perfect. | |
| 474 int perfect_delay_; | |
| 475 // Value of training_iteration_ at which the last PERFECT training sample | |
| 476 // was used in back prop. | |
| 477 int last_perfect_training_iteration_; | |
| 478 // Rolling buffers storing recent training errors are indexed by | |
| 479 // training_iteration % kRollingBufferSize_. | |
| 480 static const int kRollingBufferSize_ = 1000; | |
| 481 std::vector<double> error_buffers_[ET_COUNT]; | |
| 482 // Rounded mean percent trailing training errors in the buffers. | |
| 483 double error_rates_[ET_COUNT]; // RMS training error. | |
| 484 // Traineddata file with optional dawgs + UNICHARSET and recoder. | |
| 485 TessdataManager mgr_; | |
| 486 }; | |
| 487 | |
| 488 } // namespace tesseract. | |
| 489 | |
| 490 #endif // TESSERACT_LSTM_LSTMTRAINER_H_ |
