Mercurial > hgrepos > Python2 > PyMuPDF
comparison mupdf-source/thirdparty/tesseract/src/lstm/lstmrecognizer.h @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: no version number in the expanded directory now.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 1:1d09e1dec1d9 | 2:b50eed0cc0ef |
|---|---|
| 1 /////////////////////////////////////////////////////////////////////// | |
| 2 // File: lstmrecognizer.h | |
| 3 // Description: Top-level line recognizer class for LSTM-based networks. | |
| 4 // Author: Ray Smith | |
| 5 // | |
| 6 // (C) Copyright 2013, Google Inc. | |
| 7 // Licensed under the Apache License, Version 2.0 (the "License"); | |
| 8 // you may not use this file except in compliance with the License. | |
| 9 // You may obtain a copy of the License at | |
| 10 // http://www.apache.org/licenses/LICENSE-2.0 | |
| 11 // Unless required by applicable law or agreed to in writing, software | |
| 12 // distributed under the License is distributed on an "AS IS" BASIS, | |
| 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 14 // See the License for the specific language governing permissions and | |
| 15 // limitations under the License. | |
| 16 /////////////////////////////////////////////////////////////////////// | |
| 17 | |
| 18 #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_ | |
| 19 #define TESSERACT_LSTM_LSTMRECOGNIZER_H_ | |
| 20 | |
| 21 #include "ccutil.h" | |
| 22 #include "helpers.h" | |
| 23 #include "matrix.h" | |
| 24 #include "network.h" | |
| 25 #include "networkscratch.h" | |
| 26 #include "params.h" | |
| 27 #include "recodebeam.h" | |
| 28 #include "series.h" | |
| 29 #include "unicharcompress.h" | |
| 30 | |
| 31 class BLOB_CHOICE_IT; | |
| 32 struct Pix; | |
| 33 class ROW_RES; | |
| 34 class ScrollView; | |
| 35 class TBOX; | |
| 36 class WERD_RES; | |
| 37 | |
| 38 namespace tesseract { | |
| 39 | |
| 40 class Dict; | |
| 41 class ImageData; | |
| 42 | |
| 43 // Bit flags controlling the training mode; tested against training_flags_ by IsIntMode() and IsRecoding(). | |
| 44 enum TrainingFlags { | |
| 45 TF_INT_MODE = 1, | |
| 46 TF_COMPRESS_UNICHARSET = 64, | |
| 47 }; | |
| 48 | |
| 49 // Top-level line recognizer class for LSTM-based networks. | |
| 50 // Note that a sub-class, LSTMTrainer, is used for training. | |
| 51 class TESS_API LSTMRecognizer { | |
| 52 public: | |
| 53 LSTMRecognizer(); | |
| 54 LSTMRecognizer(const std::string &language_data_path_prefix); | |
| 55 ~LSTMRecognizer(); | |
| 56 | |
| 57 int NumOutputs() const { | |
| 58 return network_->NumOutputs(); | |
| 59 } | |
| 60 | |
| 61 // Returns training_iteration_, the number of backward training steps used. | |
| 62 int training_iteration() const { | |
| 63 return training_iteration_; | |
| 64 } | |
| 65 | |
| 66 // Returns sample_iteration_, the index into the training sample set. | |
| 67 int sample_iteration() const { | |
| 68 return sample_iteration_; | |
| 69 } | |
| 70 | |
| 71 // Returns the learning rate (learning_rate_). | |
| 72 float learning_rate() const { | |
| 73 return learning_rate_; | |
| 74 } | |
| 75 | |
| 76 LossType OutputLossType() const { | |
| 77 if (network_ == nullptr) { | |
| 78 return LT_NONE; | |
| 79 } | |
| 80 StaticShape shape; | |
| 81 shape = network_->OutputShape(shape); | |
| 82 return shape.loss_type(); | |
| 83 } | |
| 84 bool SimpleTextOutput() const { | |
| 85 return OutputLossType() == LT_SOFTMAX; | |
| 86 } | |
| 87 bool IsIntMode() const { | |
| 88 return (training_flags_ & TF_INT_MODE) != 0; | |
| 89 } | |
| 90 // True if recoder_ is active to re-encode text to a smaller space. | |
| 91 bool IsRecoding() const { | |
| 92 return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0; | |
| 93 } | |
| 94 // Returns true if the network is a TensorFlow network. network_ is dereferenced without a null check. | |
| 95 bool IsTensorFlow() const { | |
| 96 return network_->type() == NT_TENSORFLOW; | |
| 97 } | |
| 98 // Returns a vector of layer ids that can be passed to other layer functions | |
| 99 // to access a specific layer. Asserts that network_ is a non-null NT_SERIES. | |
| 100 std::vector<std::string> EnumerateLayers() const { | |
| 101 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES); | |
| 102 auto *series = static_cast<Series *>(network_); | |
| 103 std::vector<std::string> layers; | |
| 104 series->EnumerateLayers(nullptr, layers); | |
| 105 return layers; | |
| 106 } | |
| 107 // Returns a specific layer from its id (from EnumerateLayers). The id must start with ':' and be longer than 1 char. | |
| 108 Network *GetLayer(const std::string &id) const { | |
| 109 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES); | |
| 110 ASSERT_HOST(id.length() > 1 && id[0] == ':'); | |
| 111 auto *series = static_cast<Series *>(network_); | |
| 112 return series->GetLayer(&id[1]); | |
| 113 } | |
| 114 // Returns the learning rate of the layer from its id, or learning_rate_ if NF_LAYER_SPECIFIC_LR is not set. | |
| 115 float GetLayerLearningRate(const std::string &id) const { | |
| 116 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES); | |
| 117 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { | |
| 118 ASSERT_HOST(id.length() > 1 && id[0] == ':'); | |
| 119 auto *series = static_cast<Series *>(network_); | |
| 120 return series->LayerLearningRate(&id[1]); | |
| 121 } else { | |
| 122 return learning_rate_; | |
| 123 } | |
| 124 } | |
| 125 | |
| 126 // Returns the network string (network_str_) as a C string. | |
| 127 const char *GetNetwork() const { | |
| 128 return network_str_.c_str(); | |
| 129 } | |
| 130 | |
| 131 // Returns adam_beta_, the smoothing factor for the 2nd moment of gradients. | |
| 132 float GetAdamBeta() const { | |
| 133 return adam_beta_; | |
| 134 } | |
| 135 | |
| 136 // Returns the momentum (momentum_). | |
| 137 float GetMomentum() const { | |
| 138 return momentum_; | |
| 139 } | |
| 140 | |
| 141 // Multiplies all the learning rate(s) by the given factor. | |
| 142 void ScaleLearningRate(double factor) { | |
| 143 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES); | |
| 144 learning_rate_ *= factor; | |
| 145 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { | |
| 146 std::vector<std::string> layers = EnumerateLayers(); | |
| 147 for (auto &layer : layers) { | |
| 148 ScaleLayerLearningRate(layer, factor); | |
| 149 } | |
| 150 } | |
| 151 } | |
| 152 // Multiplies the learning rate of the layer with the given id by the given factor. | |
| 153 void ScaleLayerLearningRate(const std::string &id, double factor) { | |
| 154 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES); | |
| 155 ASSERT_HOST(id.length() > 1 && id[0] == ':'); | |
| 156 auto *series = static_cast<Series *>(network_); | |
| 157 series->ScaleLayerLearningRate(&id[1], factor); | |
| 158 } | |
| 159 | |
| 160 // Sets all the learning rate(s) to the given value. | |
| 161 void SetLearningRate(float learning_rate) | |
| 162 { | |
| 163 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES); | |
| 164 learning_rate_ = learning_rate; | |
| 165 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) { | |
| 166 for (auto &id : EnumerateLayers()) { | |
| 167 SetLayerLearningRate(id, learning_rate); | |
| 168 } | |
| 169 } | |
| 170 } | |
| 171 // Sets the learning rate of the layer with the given id to the given value. | |
| 172 void SetLayerLearningRate(const std::string &id, float learning_rate) | |
| 173 { | |
| 174 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES); | |
| 175 ASSERT_HOST(id.length() > 1 && id[0] == ':'); | |
| 176 auto *series = static_cast<Series *>(network_); | |
| 177 series->SetLayerLearningRate(&id[1], learning_rate); | |
| 178 } | |
| 179 | |
| 180 // Converts the network to int and sets TF_INT_MODE, if that flag is not already set. | |
| 181 void ConvertToInt() { | |
| 182 if ((training_flags_ & TF_INT_MODE) == 0) { | |
| 183 network_->ConvertToInt(); | |
| 184 training_flags_ |= TF_INT_MODE; | |
| 185 } | |
| 186 } | |
| 187 | |
| 188 // Provides access to the UNICHARSET that this classifier works with. | |
| 189 const UNICHARSET &GetUnicharset() const { | |
| 190 return ccutil_.unicharset; | |
| 191 } | |
| 192 UNICHARSET &GetUnicharset() { | |
| 193 return ccutil_.unicharset; | |
| 194 } | |
| 195 // Provides access to the UnicharCompress that this classifier works with. | |
| 196 const UnicharCompress &GetRecoder() const { | |
| 197 return recoder_; | |
| 198 } | |
| 199 // Provides access to the Dict that this classifier works with. | |
| 200 const Dict *GetDict() const { | |
| 201 return dict_; | |
| 202 } | |
| 203 Dict *GetDict() { | |
| 204 return dict_; | |
| 205 } | |
| 206 // Sets the sample iteration to the given value. The sample_iteration_ | |
| 207 // determines the seed for the random number generator. The training | |
| 208 // iteration is incremented only by a successful training iteration. | |
| 209 void SetIteration(int iteration) { | |
| 210 sample_iteration_ = iteration; | |
| 211 } | |
| 212 // Returns the number of inputs of the network. network_ is dereferenced without a null check. | |
| 213 int NumInputs() const { | |
| 214 return network_->NumInputs(); | |
| 215 } | |
| 216 | |
| 217 // Returns the null char index (null_char_). | |
| 218 int null_char() const { | |
| 219 return null_char_; | |
| 220 } | |
| 221 | |
| 222 // Loads a model from mgr, including the dictionary only if lang is not null (presumably: not empty, as lang is a std::string reference - TODO confirm). | |
| 223 bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr); | |
| 224 | |
| 225 // Writes to the given file. Returns false in case of error. | |
| 226 // If mgr contains a unicharset and recoder, then they are not encoded to fp. | |
| 227 bool Serialize(const TessdataManager *mgr, TFile *fp) const; | |
| 228 // Reads from the given file. Returns false in case of error. | |
| 229 // If mgr contains a unicharset and recoder, then they are taken from there, | |
| 230 // otherwise, they are part of the serialization in fp. | |
| 231 bool DeSerialize(const TessdataManager *mgr, TFile *fp); | |
| 232 // Loads the charsets from mgr. | |
| 233 bool LoadCharsets(const TessdataManager *mgr); | |
| 234 // Loads the Recoder. | |
| 235 bool LoadRecoder(TFile *fp); | |
| 236 // Loads the dictionary if possible from the traineddata file. | |
| 237 // Prints a warning message, and returns false but otherwise fails silently | |
| 238 // and continues to work without it if loading fails. | |
| 239 // Note that dictionary load is independent from DeSerialize, but dependent | |
| 240 // on the unicharset matching. This enables training to deserialize a model | |
| 241 // from checkpoint or restore without having to go back and reload the | |
| 242 // dictionary. | |
| 243 bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr); | |
| 244 | |
| 245 // Recognizes the line image, contained within image_data, returning the | |
| 246 // recognized tesseract WERD_RES for the words. | |
| 247 // If invert_threshold > 0, tries inverted as well if the normal | |
| 248 // interpretation doesn't produce a result which at least reaches | |
| 249 // that threshold. The line_box is used for computing the | |
| 250 // box_word in the output words. worst_dict_cert is the worst certainty that | |
| 251 // will be used in a dictionary word. | |
| 252 void RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert, | |
| 253 const TBOX &line_box, PointerVector<WERD_RES> *words, int lstm_choice_mode = 0, | |
| 254 int lstm_choice_amount = 5); | |
| 255 | |
| 256 // Helper computes min and mean best results in the output. | |
| 257 void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd); | |
| 258 // Recognizes the image_data, returning the labels, | |
| 259 // scores, and corresponding pairs of start, end x-coords in coords. | |
| 260 // Returned in scale_factor is the reduction factor | |
| 261 // between the image and the output coords, for computing bounding boxes. | |
| 262 // If re_invert is true, the input is inverted back to its original | |
| 263 // photometric interpretation if inversion is attempted but fails to | |
| 264 // improve the results. This ensures that outputs contains the correct | |
| 265 // forward outputs for the best photometric interpretation. | |
| 266 // inputs is filled with the used inputs to the network. | |
| 267 bool RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, bool re_invert, | |
| 268 bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs); | |
| 269 | |
| 270 // Converts an array of labels to utf-8, whether or not the labels are | |
| 271 // augmented with character boundaries. | |
| 272 std::string DecodeLabels(const std::vector<int> &labels); | |
| 273 | |
| 274 // Displays the forward results in a window with the characters and | |
| 275 // boundaries as determined by the labels and label_coords. | |
| 276 void DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels, | |
| 277 const std::vector<int> &label_coords, const char *window_name, | |
| 278 ScrollView **window); | |
| 279 // Converts the network output to a sequence of labels. Outputs labels, scores | |
| 280 // and start xcoords of each char, and each null_char_, with an additional | |
| 281 // final xcoord for the end of the output. | |
| 282 // The conversion method is determined by internal state. | |
| 283 void LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels, | |
| 284 std::vector<int> *xcoords); | |
| 285 | |
| 286 protected: | |
| 287 // Sets the random seed from sample_iteration_. | |
| 288 void SetRandomSeed() { | |
| 289 int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001; | |
| 290 randomizer_.set_seed(seed); | |
| 291 randomizer_.IntRand(); | |
| 292 } | |
| 293 | |
| 294 // Displays the labels and cuts at the corresponding xcoords. | |
| 295 // Size of labels should match xcoords. | |
| 296 void DisplayLSTMOutput(const std::vector<int> &labels, const std::vector<int> &xcoords, | |
| 297 int height, ScrollView *window); | |
| 298 | |
| 299 // Prints debug output detailing the activation path that is implied by the | |
| 300 // xcoords. | |
| 301 void DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels, | |
| 302 const std::vector<int> &xcoords); | |
| 303 | |
| 304 // Prints debug output detailing activations and 2nd choice over a range | |
| 305 // of positions. | |
| 306 void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, | |
| 307 int x_start, int x_end); | |
| 308 | |
| 309 // As LabelsViaCTC except that this function constructs the best path that | |
| 310 // contains only legal sequences of subcodes for recoder_. | |
| 311 void LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels, | |
| 312 std::vector<int> *xcoords); | |
| 313 // Converts the network output to a sequence of labels, with scores, using | |
| 314 // the simple character model (each position is a char, and the null_char_ is | |
| 315 // mainly intended for tail padding.) | |
| 316 void LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels, | |
| 317 std::vector<int> *xcoords); | |
| 318 | |
| 319 // Returns a string corresponding to the label starting at start. Sets *end | |
| 320 // to the next start and if non-null, *decoded to the unichar id. | |
| 321 const char *DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end, int *decoded); | |
| 322 | |
| 323 // Returns a string corresponding to a given single label id, falling back to | |
| 324 // a default of ".." for part of a multi-label unichar-id. | |
| 325 const char *DecodeSingleLabel(int label); | |
| 326 | |
| 327 protected: | |
| 328 // The network hierarchy. | |
| 329 Network *network_; | |
| 330 // The unicharset. Only the unicharset element is serialized. | |
| 331 // Has to be a CCUtil, so Dict can point to it. | |
| 332 CCUtil ccutil_; | |
| 333 // For backward compatibility, recoder_ is serialized iff | |
| 334 // training_flags_ & TF_COMPRESS_UNICHARSET. | |
| 335 // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset. | |
| 336 UnicharCompress recoder_; | |
| 337 | |
| 338 // ==Training parameters that are serialized to provide a record of them.== | |
| 339 std::string network_str_; | |
| 340 // Flags used to determine the training method of the network. | |
| 341 // See enum TrainingFlags above. | |
| 342 int32_t training_flags_; | |
| 343 // Number of actual backward training steps used. | |
| 344 int32_t training_iteration_; | |
| 345 // Index into training sample set. sample_iteration_ >= training_iteration_. | |
| 346 int32_t sample_iteration_; | |
| 347 // Index in softmax of null character. May take the value UNICHAR_BROKEN or | |
| 348 // ccutil_.unicharset.size(). | |
| 349 int32_t null_char_; | |
| 350 // Learning rate and momentum multipliers of deltas in backprop. | |
| 351 float learning_rate_; | |
| 352 float momentum_; | |
| 353 // Smoothing factor for 2nd moment of gradients. | |
| 354 float adam_beta_; | |
| 355 | |
| 356 // === NOT SERIALIZED. | |
| 357 TRand randomizer_; | |
| 358 NetworkScratch scratch_space_; | |
| 359 // Language model (optional) to use with the beam search. | |
| 360 Dict *dict_; | |
| 361 // Beam search held between uses to optimize memory allocation/use. | |
| 362 RecodeBeamSearch *search_; | |
| 363 | |
| 364 // == Debugging parameters.== | |
| 365 // Recognition debug display window. | |
| 366 ScrollView *debug_win_; | |
| 367 }; | |
| 368 | |
| 369 } // namespace tesseract. | |
| 370 | |
| 371 #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_ |
