comparison mupdf-source/thirdparty/tesseract/src/lstm/network.h @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The expanded directory name has changed: it no longer carries a version number.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
///////////////////////////////////////////////////////////////////////
// File: network.h
// Description: Base class for neural network implementations.
// Author: Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_NETWORK_H_
#define TESSERACT_LSTM_NETWORK_H_

#include "helpers.h"
#include "matrix.h"
#include "networkio.h"
#include "serialis.h"
#include "static_shape.h"
#include "tprintf.h"

#include <cmath>
#include <cstdio>

struct Pix;

namespace tesseract {

class ScrollView;
class TBOX;
class ImageData;
class NetworkScratch;

// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
  NT_NONE,  // The naked base class.
  NT_INPUT, // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,    // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,     // Chooses the max result from a rectangle.
  NT_PARALLEL,    // Runs networks in parallel.
  NT_REPLICATED,  // Runs identical networks in parallel.
  NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel.
  NT_SERIES,      // Executes a sequence of layers.
  NT_RECONFIG,    // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,   // Reverses the x-direction of the inputs/outputs.
  NT_YREVERSED,   // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE, // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,           // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY,   // LSTM that only keeps its last output.
  NT_LOGISTIC,       // Fully connected logistic nonlinearity.
  NT_POSCLIP,        // Fully connected rect lin version of logistic.
  NT_SYMCLIP,        // Fully connected rect lin version of tanh.
  NT_TANH,           // Fully connected with tanh nonlinearity.
  NT_RELU,           // Fully connected with rectifier nonlinearity.
  NT_LINEAR,         // Fully connected with no nonlinearity.
  NT_SOFTMAX,        // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with
  // the outputs fed back to the input of the LSTM at the next timestep.
  // The ENCODED version binary encodes the softmax outputs, providing log2 of
  // the number of outputs as additional inputs, and the other version just
  // provides all the softmax outputs as additional inputs.
  NT_LSTM_SOFTMAX,         // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT // Array size.
};
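
// Example (editor's sketch, not part of the upstream header): NetworkType is a
// plain run-time tag, so callers can branch on Network::type(). The helper
// below is hypothetical and only illustrates that the plumbing types form a
// contiguous range in the enum, as the comments above describe.
//
//   bool IsPlumbingRange(NetworkType type) {
//     return type >= NT_CONVOLVE && type <= NT_XYTRANSPOSE;
//   }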

// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer.
  NF_ADAM = 128,             // Weight-specific learning rate.
};
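
// Example (editor's sketch, not upstream code): the flags are bit values, so
// they combine with | and are queried through TestFlag(), declared below.
//
//   net->SetNetworkFlags(NF_LAYER_SPECIFIC_LR | NF_ADAM);
//   if (net->TestFlag(NF_ADAM)) {
//     // Per-weight (Adam) updates are in effect.
//   }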

// State of training and desired state used in SetEnableTraining.
enum TrainingState {
  // Valid states of training_.
  TS_DISABLED,     // Disabled permanently.
  TS_ENABLED,      // Enabled for backprop and to write a training dump.
                   // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
  // Valid only for SetEnableTraining.
  TS_RE_ENABLE, // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};

// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
class TESS_API Network {
public:
  Network();
  Network(NetworkType type, const std::string &name, int ni, int no);
  virtual ~Network() = default;

  // Accessors.
  NetworkType type() const {
    return type_;
  }
  bool IsTraining() const {
    return training_ == TS_ENABLED;
  }
  bool needs_to_backprop() const {
    return needs_to_backprop_;
  }
  int num_weights() const {
    return num_weights_;
  }
  int NumInputs() const {
    return ni_;
  }
  int NumOutputs() const {
    return no_;
  }
  // Returns the required shape input to the network.
  virtual StaticShape InputShape() const {
    StaticShape result;
    return result;
  }
  // Returns the shape output from the network given an input shape (which may
  // be partially unknown, i.e. zero).
  virtual StaticShape OutputShape(const StaticShape &input_shape) const {
    StaticShape result(input_shape);
    result.set_depth(no_);
    return result;
  }
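
  // Example (editor's sketch, not upstream code): shape propagation through
  // the default OutputShape(), assuming StaticShape::SetShape(batch, height,
  // width, depth) from static_shape.h.
  //
  //   StaticShape in;
  //   in.SetShape(1, 48, 0, 1); // Width 0 = not yet known.
  //   StaticShape out = net->OutputShape(in);
  //   // out matches in, except that out.depth() == net->NumOutputs().
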
  const std::string &name() const {
    return name_;
  }
  virtual std::string spec() const = 0;
  bool TestFlag(NetworkFlags flag) const {
    return (network_flags_ & flag) != 0;
  }

  // Initialization and administrative functions that are mostly provided
  // by Plumbing.
  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  virtual bool IsPlumbingType() const {
    return false;
  }

  // Suspends/Enables/Permanently disables training by setting the training_
  // flag. Serialize and DeSerialize only operate on the run-time data if state
  // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
  // serialize as if it were a recognizer.
  // TS_RE_ENABLE will re-enable layers that were previously in any disabled
  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
  // recognizer can be converted back to a trainer.
  virtual void SetEnableTraining(TrainingState state);
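
  // Example (editor's sketch, not upstream code): writing a recognition dump
  // from a live trainer, following the state machine documented above.
  //
  //   net->SetEnableTraining(TS_TEMP_DISABLE); // Hide training-only data.
  //   net->Serialize(&fp);                     // Serialized as a recognizer.
  //   net->SetEnableTraining(TS_RE_ENABLE);    // Back to TS_ENABLED.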

  // Sets flags that control the action of the network. See NetworkFlags enum
  // for bit values.
  virtual void SetNetworkFlags(uint32_t flags);

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
  virtual int InitWeights(float range, TRand *randomizer);
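
  // Example (editor's sketch, not upstream code): seeding and initializing,
  // assuming TRand::set_seed from helpers.h. The randomizer must outlive net.
  //
  //   TRand randomizer;
  //   randomizer.set_seed(297);
  //   int num_weights = net->InitWeights(0.1f, &randomizer);
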
  // Changes the number of outputs to the outside world to the size of the
  // given code_map. Recursively searches the entire network for Softmax layers
  // that have exactly old_no outputs, and operates only on those, leaving all
  // others unchanged. This enables networks with multiple output layers to get
  // all their softmaxes updated, but if an internal layer uses one of those
  // softmaxes for input, then the inputs will effectively be scrambled.
  // TODO(rays) Fix this before any such network is implemented.
  // The softmaxes are resized by copying the old weight matrix entries for
  // each output from code_map[output] where non-negative, and using the mean
  // (over all outputs) of the existing weights for all outputs with negative
  // code_map entries. Returns the new number of weights.
  virtual int RemapOutputs([[maybe_unused]] int old_no,
                           [[maybe_unused]] const std::vector<int> &code_map) {
    return 0;
  }
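
  // Example (editor's sketch, not upstream code): growing a 3-class softmax to
  // 4 classes. New output i copies the weights of old output code_map[i]; the
  // -1 entry gets the mean of the existing weights.
  //
  //   std::vector<int> code_map = {0, 1, -1, 2};
  //   int new_num_weights = net->RemapOutputs(3, code_map);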

  // Converts a float network to an int network.
  virtual void ConvertToInt() {}

  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  virtual void SetRandomizer(TRand *randomizer);

  // Sets needs_to_backprop_ to needs_backprop, and returns true if
  // needs_backprop is set or this network has any weights, so that the next
  // layer forward can be told to produce backprop for this layer if needed.
  virtual bool SetupNeedsBackprop(bool needs_backprop);

  // Returns the most recent reduction factor that the network applied to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data and calculating result bounding boxes.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const {
    return 1;
  }

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor([[maybe_unused]] int factor) {}

  // Provides debug output on the weights.
  virtual void DebugWeights() = 0;

  // Writes to the given file. Returns false in case of error.
  // Should be overridden by subclasses, but called by their Serialize.
  virtual bool Serialize(TFile *fp) const;
  // Reads from the given file. Returns false in case of error.
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
  virtual bool DeSerialize(TFile *fp) = 0;

public:
  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  virtual void Update([[maybe_unused]] float learning_rate,
                      [[maybe_unused]] float momentum,
                      [[maybe_unused]] float adam_beta,
                      [[maybe_unused]] int num_samples) {}
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators([[maybe_unused]] const Network &other,
                                [[maybe_unused]] TFloat *same,
                                [[maybe_unused]] TFloat *changed) const {}

  // Reads from the given file. Returns nullptr in case of error.
  // Determines the type of the serialized class and calls its DeSerialize
  // on a new object of the appropriate type, which is returned.
  static Network *CreateFromFile(TFile *fp);
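
  // Example (editor's sketch, not upstream code): loading a serialized
  // network, assuming TFile::Open(filename, reader) from serialis.h (a
  // nullptr reader selects the default loader).
  //
  //   TFile fp;
  //   if (fp.Open("model.lstm", nullptr)) {
  //     Network *net = Network::CreateFromFile(&fp);
  //     if (net != nullptr) {
  //       tprintf("Loaded %s with %d weights\n", net->name().c_str(),
  //               net->num_weights());
  //     }
  //   }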

  // Runs forward propagation of activations on the input line.
  // Note that input and output are both 2-d arrays.
  // The 1st index is the time element. In a 1-d network, it might be the pixel
  // position on the textline. In a 2-d network, the linearization is defined
  // by the stride_map. (See networkio.h).
  // The 2nd index of input is the network inputs/outputs, and the dimension
  // of the input must match NumInputs() of this network.
  // The output array will be resized as needed so that its 1st dimension is
  // always equal to the number of output values, and its second dimension is
  // always NumOutputs(). Note that all this detail is encapsulated away inside
  // NetworkIO, as are the internals of the scratch memory space used by the
  // network. See networkscratch.h for that.
  // If input_transpose is not nullptr, then it contains the transpose of input,
  // and the caller guarantees that it will still be valid on the next call to
  // backward. The callee is therefore at liberty to save the pointer and
  // reference it on a call to backward. This is a bit ugly, but it makes it
  // possible for a replicating parallel to calculate the input transpose once
  // instead of all the replicated networks having to do it.
  virtual void Forward(bool debug, const NetworkIO &input,
                       const TransposedArray *input_transpose,
                       NetworkScratch *scratch, NetworkIO *output) = 0;

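  // Example (editor's sketch, not upstream code): a recognition-style forward
  // pass. Populating `input` from an image line is elided; see networkio.h.
  //
  //   NetworkScratch scratch;
  //   NetworkIO input, output;
  //   // ... fill input with one line of NumInputs()-deep features ...
  //   net->Forward(false, input, nullptr, &scratch, &output);
  //   // output now holds NumOutputs() values per timestep.
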
  // Runs backward propagation of errors on fwd_deltas.
  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
  // Returns false if back_deltas was not set, due to there being no point in
  // propagating further backwards. Thus most complete networks will always
  // return false from Backward!
  virtual bool Backward(bool debug, const NetworkIO &fwd_deltas,
                        NetworkScratch *scratch, NetworkIO *back_deltas) = 0;

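  // Example (editor's sketch, not upstream code): one training step at the
  // top of a complete network, where Backward() is expected to return false.
  // Computing fwd_deltas from the truth/CTC targets, and the hyperparameters
  // passed to Update(), are supplied by the trainer and elided here.
  //
  //   NetworkIO fwd_deltas, back_deltas;
  //   // ... fill fwd_deltas with output-layer errors ...
  //   net->Backward(false, fwd_deltas, &scratch, &back_deltas);
  //   net->Update(learning_rate, momentum, adam_beta, num_samples);
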
  // === Debug image display methods. ===
  // Displays the image of the matrix to the forward window.
  void DisplayForward(const NetworkIO &matrix);
  // Displays the image of the matrix to the backward window.
  void DisplayBackward(const NetworkIO &matrix);

  // Creates the window if needed, otherwise clears it.
  static void ClearWindow(bool tess_coords, const char *window_name, int width,
                          int height, ScrollView **window);

  // Displays the pix in the given window and returns the height of the pix.
  // The pix is pixDestroyed.
  static int DisplayImage(Image pix, ScrollView *window);

protected:
  // Returns a random number in [-range, range].
  TFloat Random(TFloat range);

protected:
  NetworkType type_;       // Type of the derived network class.
  TrainingState training_; // Are we currently training?
  bool needs_to_backprop_; // This network needs to output back_deltas.
  int32_t network_flags_;  // Behavior control flags in NetworkFlags.
  int32_t ni_;             // Number of input values.
  int32_t no_;             // Number of output values.
  int32_t num_weights_;    // Number of weights in this and sub-network.
  std::string name_;       // A unique name for this layer.

  // NOT-serialized debug data.
  ScrollView *forward_win_;  // Recognition debug display window.
  ScrollView *backward_win_; // Training debug display window.
  TRand *randomizer_;        // Random number generator.
};

} // namespace tesseract.

#endif // TESSERACT_LSTM_NETWORK_H_