comparison: mupdf-source/thirdparty/tesseract/src/lstm/network.h @ 2:b50eed0cc0ef (upstream)

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: the expanded directory no longer carries a version number.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children | |
comparing revision 1:1d09e1dec1d9 with revision 2:b50eed0cc0ef
```cpp
///////////////////////////////////////////////////////////////////////
// File:        network.h
// Description: Base class for neural network implementations.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_NETWORK_H_
#define TESSERACT_LSTM_NETWORK_H_

#include "helpers.h"
#include "matrix.h"
#include "networkio.h"
#include "serialis.h"
#include "static_shape.h"
#include "tprintf.h"

#include <cmath>
#include <cstdio>

struct Pix;

namespace tesseract {

class ScrollView;
class TBOX;
class ImageData;
class NetworkScratch;

// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
  NT_NONE,  // The naked base class.
  NT_INPUT, // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,    // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,     // Chooses the max result from a rectangle.
  NT_PARALLEL,    // Runs networks in parallel.
  NT_REPLICATED,  // Runs identical networks in parallel.
  NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel.
  NT_SERIES,      // Executes a sequence of layers.
  NT_RECONFIG,    // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,   // Reverses the x-direction of the inputs/outputs.
  NT_YREVERSED,   // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE, // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,           // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY,   // LSTM that only keeps its last output.
  NT_LOGISTIC,       // Fully connected logistic nonlinearity.
  NT_POSCLIP,        // Fully connected rect lin version of logistic.
  NT_SYMCLIP,        // Fully connected rect lin version of tanh.
  NT_TANH,           // Fully connected with tanh nonlinearity.
  NT_RELU,           // Fully connected with rectifier nonlinearity.
  NT_LINEAR,         // Fully connected with no nonlinearity.
  NT_SOFTMAX,        // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with
  // the outputs fed back to the input of the LSTM at the next timestep.
  // The ENCODED version binary encodes the softmax outputs, providing log2 of
  // the number of outputs as additional inputs, and the other version just
  // provides all the softmax outputs as additional inputs.
  NT_LSTM_SOFTMAX,         // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT // Array size.
};
```
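The NT_LSTM_SOFTMAX_ENCODED comment above is easiest to see with concrete numbers: for a softmax over 111 classes, the plain variant feeds all 111 outputs back to the LSTM as extra inputs, while the encoded variant feeds back only ceil(log2(111)) = 7 binary values. A minimal sketch of that arithmetic (`FeedbackWidth` is our own illustrative helper, not a Tesseract function):

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical helper, not part of the header above: number of extra
// inputs fed back to the LSTM for a softmax over `num_outputs` classes.
static int FeedbackWidth(int num_outputs, bool encoded) {
  if (encoded) {
    // Binary encoding: ceil(log2(n)) bits suffice to name any class.
    return static_cast<int>(std::ceil(std::log2(num_outputs)));
  }
  return num_outputs; // Plain NT_LSTM_SOFTMAX feeds back every output.
}

int main() {
  // A 111-class softmax: 111 feedback inputs plain, 7 when binary encoded.
  std::printf("plain:   %d\n", FeedbackWidth(111, false));
  std::printf("encoded: %d\n", FeedbackWidth(111, true));
}
```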
```cpp
// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer.
  NF_ADAM = 128,             // Weight-specific learning rate.
};
```
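Both values are single bits (64 = bit 6, 128 = bit 7), so flags combine with bitwise OR and are queried with a mask, which is exactly the test `TestFlag()` performs in the class below. A standalone sketch of the pattern, outside any Tesseract code:

```cpp
#include <cstdint>
#include <cstdio>

// Mirror of the two flag values above, for a self-contained demonstration.
enum NetworkFlags {
  NF_LAYER_SPECIFIC_LR = 64, // bit 6
  NF_ADAM = 128,             // bit 7
};

int main() {
  // Combine flags with bitwise OR, as SetNetworkFlags() would receive them.
  uint32_t network_flags = NF_LAYER_SPECIFIC_LR | NF_ADAM; // == 192
  // Query with bitwise AND, the same check TestFlag() makes.
  bool uses_adam = (network_flags & NF_ADAM) != 0;
  std::printf("flags=%u, adam=%d\n", static_cast<unsigned>(network_flags),
              uses_adam);
}
```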
```cpp
// State of training and desired state used in SetEnableTraining.
enum TrainingState {
  // Valid states of training_.
  TS_DISABLED,     // Disabled permanently.
  TS_ENABLED,      // Enabled for backprop and to write a training dump.
                   // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
  // Valid only for SetEnableTraining.
  TS_RE_ENABLE,    // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};
```
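Reading the enum together with the `SetEnableTraining()` contract documented further down: TS_DISABLED, TS_ENABLED and TS_TEMP_DISABLE are stored states, while TS_RE_ENABLE is only ever a request. A sketch of the documented transitions, as a standalone function of our own (the real member function additionally reinitializes the weight-matrix deltas when a permanently disabled network is turned back into a trainer):

```cpp
enum TrainingState { TS_DISABLED, TS_ENABLED, TS_TEMP_DISABLE, TS_RE_ENABLE };

// Sketch of the state transitions documented for SetEnableTraining();
// not the actual implementation.
TrainingState NextState(TrainingState current, TrainingState requested) {
  switch (requested) {
    case TS_ENABLED: // Re-enables from ANY disabled state.
      return TS_ENABLED;
    case TS_RE_ENABLE: // Only undoes a temporary disable.
      return current == TS_TEMP_DISABLE ? TS_ENABLED : current;
    case TS_TEMP_DISABLE: // Only pauses layers that are currently training.
      return current == TS_ENABLED ? TS_TEMP_DISABLE : current;
    default: // TS_DISABLED: permanent.
      return TS_DISABLED;
  }
}
```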
```cpp
// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
class TESS_API Network {
public:
  Network();
  Network(NetworkType type, const std::string &name, int ni, int no);
  virtual ~Network() = default;

  // Accessors.
  NetworkType type() const {
    return type_;
  }
  bool IsTraining() const {
    return training_ == TS_ENABLED;
  }
  bool needs_to_backprop() const {
    return needs_to_backprop_;
  }
  int num_weights() const {
    return num_weights_;
  }
  int NumInputs() const {
    return ni_;
  }
  int NumOutputs() const {
    return no_;
  }
  // Returns the required shape input to the network.
  virtual StaticShape InputShape() const {
    StaticShape result;
    return result;
  }
  // Returns the shape output from the network given an input shape (which may
  // be partially unknown, i.e. zero).
  virtual StaticShape OutputShape(const StaticShape &input_shape) const {
    StaticShape result(input_shape);
    result.set_depth(no_);
    return result;
  }
  const std::string &name() const {
    return name_;
  }
  virtual std::string spec() const = 0;
  bool TestFlag(NetworkFlags flag) const {
    return (network_flags_ & flag) != 0;
  }

  // Initialization and administrative functions that are mostly provided
  // by Plumbing.
  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  virtual bool IsPlumbingType() const {
    return false;
  }

  // Suspends/Enables/Permanently disables training by setting the training_
  // flag. Serialize and DeSerialize only operate on the run-time data if state
  // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
  // serialize as if it were a recognizer.
  // TS_ENABLED will re-enable layers that were previously in any disabled
  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
  // recognizer can be converted back to a trainer.
  virtual void SetEnableTraining(TrainingState state);

  // Sets flags that control the action of the network. See NetworkFlags enum
  // for bit values.
  virtual void SetNetworkFlags(uint32_t flags);

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
  virtual int InitWeights(float range, TRand *randomizer);
  // Changes the number of outputs to the outside world to the size of the given
  // code_map. Recursively searches the entire network for Softmax layers that
  // have exactly old_no outputs, and operates only on those, leaving all others
  // unchanged. This enables networks with multiple output layers to get all
  // their softmaxes updated, but if an internal layer uses one of those
  // softmaxes for input, then the inputs will effectively be scrambled.
  // TODO(rays) Fix this before any such network is implemented.
  // The softmaxes are resized by copying the old weight matrix entries for each
  // output from code_map[output] where non-negative, and using the mean (over
  // all outputs) of the existing weights for all outputs with negative code_map
  // entries. Returns the new number of weights.
  virtual int RemapOutputs([[maybe_unused]] int old_no,
                           [[maybe_unused]] const std::vector<int> &code_map) {
    return 0;
  }

  // Converts a float network to an int network.
  virtual void ConvertToInt() {}

  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  virtual void SetRandomizer(TRand *randomizer);

  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop || any weights in this network so the next layer forward
  // can be told to produce backprop for this layer if needed.
  virtual bool SetupNeedsBackprop(bool needs_backprop);

  // Returns the most recent reduction factor that the network applied to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data and calculating result bounding boxes.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const {
    return 1;
  }

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor([[maybe_unused]] int factor) {}

  // Provides debug output on the weights.
  virtual void DebugWeights() = 0;

  // Writes to the given file. Returns false in case of error.
  // Should be overridden by subclasses, but called by their Serialize.
  virtual bool Serialize(TFile *fp) const;
  // Reads from the given file. Returns false in case of error.
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
  virtual bool DeSerialize(TFile *fp) = 0;

public:
  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  virtual void Update([[maybe_unused]] float learning_rate,
                      [[maybe_unused]] float momentum,
                      [[maybe_unused]] float adam_beta,
                      [[maybe_unused]] int num_samples) {}
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators([[maybe_unused]] const Network &other,
                                [[maybe_unused]] TFloat *same,
                                [[maybe_unused]] TFloat *changed) const {}

  // Reads from the given file. Returns nullptr in case of error.
  // Determines the type of the serialized class and calls its DeSerialize
  // on a new object of the appropriate type, which is returned.
  static Network *CreateFromFile(TFile *fp);

  // Runs forward propagation of activations on the input line.
  // Note that input and output are both 2-d arrays.
  // The 1st index is the time element. In a 1-d network, it might be the pixel
  // position on the textline. In a 2-d network, the linearization is defined
  // by the stride_map. (See networkio.h).
  // The 2nd index of input is the network inputs/outputs, and the dimension
  // of the input must match NumInputs() of this network.
  // The output array will be resized as needed so that its 1st dimension is
  // always equal to the number of output values, and its second dimension is
  // always NumOutputs(). Note that all this detail is encapsulated away inside
  // NetworkIO, as are the internals of the scratch memory space used by the
  // network. See networkscratch.h for that.
  // If input_transpose is not nullptr, then it contains the transpose of input,
  // and the caller guarantees that it will still be valid on the next call to
  // backward. The callee is therefore at liberty to save the pointer and
  // reference it on a call to backward. This is a bit ugly, but it makes it
  // possible for a replicating parallel to calculate the input transpose once
  // instead of all the replicated networks having to do it.
  virtual void Forward(bool debug, const NetworkIO &input,
                       const TransposedArray *input_transpose,
                       NetworkScratch *scratch, NetworkIO *output) = 0;

  // Runs backward propagation of errors on fwd_deltas.
  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
  // Returns false if back_deltas was not set, due to there being no point in
  // propagating further backwards. Thus most complete networks will always
  // return false from Backward!
  virtual bool Backward(bool debug, const NetworkIO &fwd_deltas,
                        NetworkScratch *scratch, NetworkIO *back_deltas) = 0;

  // === Debug image display methods. ===
  // Displays the image of the matrix to the forward window.
  void DisplayForward(const NetworkIO &matrix);
  // Displays the image of the matrix to the backward window.
  void DisplayBackward(const NetworkIO &matrix);

  // Creates the window if needed, otherwise clears it.
  static void ClearWindow(bool tess_coords, const char *window_name, int width,
                          int height, ScrollView **window);

  // Displays the pix in the given window, and returns the height of the pix.
  // The pix is pixDestroyed.
  static int DisplayImage(Image pix, ScrollView *window);

protected:
  // Returns a random number in [-range, range].
  TFloat Random(TFloat range);

protected:
  NetworkType type_;       // Type of the derived network class.
  TrainingState training_; // Are we currently training?
  bool needs_to_backprop_; // This network needs to output back_deltas.
  int32_t network_flags_;  // Behavior control flags in NetworkFlags.
  int32_t ni_;             // Number of input values.
  int32_t no_;             // Number of output values.
  int32_t num_weights_;    // Number of weights in this and sub-network.
  std::string name_;       // A unique name for this layer.

  // NOT-serialized debug data.
  ScrollView *forward_win_;  // Recognition debug display window.
  ScrollView *backward_win_; // Training debug display window.
  TRand *randomizer_;        // Random number generator.
};
```
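The code_map contract in `RemapOutputs()` above is compact, so here is a worked miniature: suppose a softmax goes from 4 outputs to 3 and code_map = {2, 0, -1}. New output 0 copies old output 2's weights, new output 1 copies old output 0's, and new output 2 (a negative entry) gets the mean over all old outputs. A standalone sketch over plain vectors (our own function; the real code operates on Tesseract weight-matrix rows):

```cpp
#include <vector>

// Illustration of the code_map rule described in RemapOutputs():
// non-negative entries copy an old output's weight row; negative entries
// receive the mean (over all old outputs) of the existing rows.
// Assumes old_rows is non-empty and rectangular.
std::vector<std::vector<float>> RemapRows(
    const std::vector<std::vector<float>> &old_rows,
    const std::vector<int> &code_map) {
  size_t width = old_rows.front().size();
  // Mean row over all old outputs, used for negative code_map entries.
  std::vector<float> mean(width, 0.0f);
  for (const auto &row : old_rows) {
    for (size_t i = 0; i < width; ++i) {
      mean[i] += row[i] / old_rows.size();
    }
  }
  std::vector<std::vector<float>> new_rows;
  for (int code : code_map) {
    new_rows.push_back(code >= 0 ? old_rows[code] : mean);
  }
  return new_rows;
}
```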
```cpp
} // namespace tesseract.

#endif // TESSERACT_LSTM_NETWORK_H_
```
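For orientation, here is roughly how a caller drives this interface: deserialize a concrete network with `CreateFromFile()`, then hand `Forward()` a NetworkIO plus a scratch area. Only the signatures shown in the header are certain; the NetworkIO/NetworkScratch construction idioms are assumptions that may differ between Tesseract versions:

```cpp
#include "network.h"
#include "networkio.h"
#include "networkscratch.h"
#include "serialis.h"

namespace tesseract {

// Hedged sketch, not from the header: one forward pass over a prepared
// input. Consult the Tesseract sources for the exact setup calls.
NetworkIO RunForward(TFile *model_file, const NetworkIO &input) {
  // CreateFromFile reads the serialized type, instantiates the matching
  // subclass, and calls its DeSerialize; it returns nullptr on error.
  Network *net = Network::CreateFromFile(model_file);
  NetworkIO output;
  if (net != nullptr) {
    NetworkScratch scratch;
    // Passing no input_transpose means the callee cannot cache one for a
    // later Backward call (see the Forward contract above).
    net->Forward(/*debug=*/false, input, /*input_transpose=*/nullptr,
                 &scratch, &output);
    delete net;
  }
  return output;
}

} // namespace tesseract
```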
