comparison mupdf-source/thirdparty/tesseract/src/training/lstmtraining.cpp @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: the expanded directory no longer carries a version number.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children | |
Comparing revisions 1:1d09e1dec1d9 and 2:b50eed0cc0ef:
```cpp
///////////////////////////////////////////////////////////////////////
// File: lstmtraining.cpp
// Description: Training program for LSTM-based networks.
// Author: Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include <cerrno>
#include <locale> // for std::locale::classic
#if defined(__USE_GNU)
#  include <cfenv> // for feenableexcept
#endif
#include "commontraining.h"
#include "fileio.h" // for LoadFileLinesToStrings
#include "lstmtester.h"
#include "lstmtrainer.h"
#include "params.h"
#include "tprintf.h"
#include "unicharset_training_utils.h"

using namespace tesseract;

static INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
static STRING_PARAM_FLAG(net_spec, "", "Network specification");
static INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
static INT_PARAM_FLAG(perfect_sample_delay, 0, "How many imperfect samples between perfect ones.");
static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
static BOOL_PARAM_FLAG(reset_learning_rate, false,
                       "Resets all stored learning rates to the value specified by --learning_rate.");
static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
static STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
static STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
static STRING_PARAM_FLAG(train_listfile, "",
                         "File listing training files in lstmf training format.");
static STRING_PARAM_FLAG(eval_listfile, "", "File listing eval files in lstmf training format.");
#if defined(__USE_GNU)
static BOOL_PARAM_FLAG(debug_float, false, "Raise error on certain float errors.");
#endif
static BOOL_PARAM_FLAG(stop_training, false, "Just convert the training model to a runtime model.");
static BOOL_PARAM_FLAG(convert_to_int, false, "Convert the recognition model to an integer model.");
static BOOL_PARAM_FLAG(sequential_training, false,
                       "Use the training files sequentially instead of round-robin.");
static INT_PARAM_FLAG(append_index, -1,
                      "Index in continue_from Network at which to"
                      " attach the new network defined by net_spec");
static BOOL_PARAM_FLAG(debug_network, false, "Get info on distribution of weight values");
static INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
static STRING_PARAM_FLAG(traineddata, "", "Combined Dawgs/Unicharset/Recoder for language model");
static STRING_PARAM_FLAG(old_traineddata, "",
                         "When changing the character set, this specifies the old"
                         " character set that is to be replaced");
static BOOL_PARAM_FLAG(randomly_rotate, false,
                       "Train OSD and randomly turn training samples upside-down");

// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;

// Apart from command-line flags, input is a collection of lstmf files, that
// were previously created using tesseract with the lstm.train config file.
// The program iterates over the inputs, feeding the data to the network,
// until the error rate reaches a specified target or max_iterations is reached.
int main(int argc, char **argv) {
  tesseract::CheckSharedLibraryVersion();
  ParseArguments(&argc, &argv);
#if defined(__USE_GNU)
  if (FLAGS_debug_float) {
    // Raise SIGFPE for unwanted floating point calculations.
    feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
  }
#endif
  if (FLAGS_model_output.empty()) {
    tprintf("Must provide a --model_output!\n");
    return EXIT_FAILURE;
  }
  if (FLAGS_traineddata.empty()) {
    tprintf("Must provide a --traineddata see training documentation\n");
    return EXIT_FAILURE;
  }

  // Check write permissions.
  std::string test_file = FLAGS_model_output;
  test_file += "_wtest";
  FILE *f = fopen(test_file.c_str(), "wb");
  if (f != nullptr) {
    fclose(f);
    if (remove(test_file.c_str()) != 0) {
      tprintf("Error, failed to remove %s: %s\n", test_file.c_str(), strerror(errno));
      return EXIT_FAILURE;
    }
  } else {
    tprintf("Error, model output cannot be written: %s\n", strerror(errno));
    return EXIT_FAILURE;
  }

  // Setup the trainer.
  std::string checkpoint_file = FLAGS_model_output;
  checkpoint_file += "_checkpoint";
  std::string checkpoint_bak = checkpoint_file + ".bak";
  tesseract::LSTMTrainer trainer(FLAGS_model_output, checkpoint_file,
                                 FLAGS_debug_interval,
                                 static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
  if (!trainer.InitCharSet(FLAGS_traineddata.c_str())) {
    tprintf("Error, failed to read %s\n", FLAGS_traineddata.c_str());
    return EXIT_FAILURE;
  }

  // Reading something from an existing model doesn't require many flags,
  // so do it now and exit.
  if (FLAGS_stop_training || FLAGS_debug_network) {
    if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
      tprintf("Failed to read continue from: %s\n", FLAGS_continue_from.c_str());
      return EXIT_FAILURE;
    }
    if (FLAGS_debug_network) {
      trainer.DebugNetwork();
    } else {
      if (FLAGS_convert_to_int) {
        trainer.ConvertToInt();
      }
      if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
        tprintf("Failed to write recognition model : %s\n", FLAGS_model_output.c_str());
      }
    }
    return EXIT_SUCCESS;
  }

  // Get the list of files to process.
  if (FLAGS_train_listfile.empty()) {
    tprintf("Must supply a list of training filenames! --train_listfile\n");
    return EXIT_FAILURE;
  }
  std::vector<std::string> filenames;
  if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(), &filenames)) {
    tprintf("Failed to load list of training filenames from %s\n", FLAGS_train_listfile.c_str());
    return EXIT_FAILURE;
  }

  // Checkpoints always take priority if they are available.
  if (trainer.TryLoadingCheckpoint(checkpoint_file.c_str(), nullptr) ||
      trainer.TryLoadingCheckpoint(checkpoint_bak.c_str(), nullptr)) {
    tprintf("Successfully restored trainer from %s\n", checkpoint_file.c_str());
  } else {
    if (!FLAGS_continue_from.empty()) {
      // Load a past model file to improve upon.
      if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
                                        FLAGS_append_index >= 0 ? FLAGS_continue_from.c_str()
                                                                : FLAGS_old_traineddata.c_str())) {
        tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
        return EXIT_FAILURE;
      }
      tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
      if (FLAGS_reset_learning_rate) {
        trainer.SetLearningRate(FLAGS_learning_rate);
        tprintf("Set learning rate to %f\n", static_cast<float>(FLAGS_learning_rate));
      }
      trainer.InitIterations();
    }
    if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
      if (FLAGS_append_index >= 0) {
        tprintf("Appending a new network to an old one!!");
        if (FLAGS_continue_from.empty()) {
          tprintf("Must set --continue_from for appending!\n");
          return EXIT_FAILURE;
        }
      }
      // We are initializing from scratch.
      if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, FLAGS_net_mode,
                               FLAGS_weight_range, FLAGS_learning_rate, FLAGS_momentum,
                               FLAGS_adam_beta)) {
        tprintf("Failed to create network from spec: %s\n", FLAGS_net_spec.c_str());
        return EXIT_FAILURE;
      }
      trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
    }
  }
  if (!trainer.LoadAllTrainingData(
          filenames,
          FLAGS_sequential_training ? tesseract::CS_SEQUENTIAL : tesseract::CS_ROUND_ROBIN,
          FLAGS_randomly_rotate)) {
    tprintf("Load of images failed!!\n");
    return EXIT_FAILURE;
  }

  tesseract::LSTMTester tester(static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
  tesseract::TestCallback tester_callback = nullptr;
  if (!FLAGS_eval_listfile.empty()) {
    using namespace std::placeholders; // for _1, _2, _3...
    if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
      tprintf("Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str());
      return EXIT_FAILURE;
    }
    tester_callback = std::bind(&tesseract::LSTMTester::RunEvalAsync, &tester, _1, _2, _3, _4);
  }

  int max_iterations = FLAGS_max_iterations;
  if (max_iterations < 0) {
    // A negative value is interpreted as epochs
    max_iterations = filenames.size() * (-max_iterations);
  } else if (max_iterations == 0) {
    // "Infinite" iterations.
    max_iterations = INT_MAX;
  }

  do {
    // Train a few.
    int iteration = trainer.training_iteration();
    for (int target_iteration = iteration + kNumPagesPerBatch;
         iteration < target_iteration && iteration < max_iterations;
         iteration = trainer.training_iteration()) {
      trainer.TrainOnLine(&trainer, false);
    }
    std::stringstream log_str;
    log_str.imbue(std::locale::classic());
    trainer.MaintainCheckpoints(tester_callback, log_str);
    tprintf("%s\n", log_str.str().c_str());
  } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
           (trainer.training_iteration() < max_iterations));
  tprintf("Finished! Selected model with minimal training error rate (BCER) = %g\n",
          trainer.best_error_rate());
  return EXIT_SUCCESS;
} /* main */
```
