Comparison: mupdf-source/thirdparty/tesseract/src/lstm/network.cpp @ 2:b50eed0cc0ef (upstream)
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: the expanded directory no longer carries a version number.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children | |
Revisions compared: 1:1d09e1dec1d9 and 2:b50eed0cc0ef. File contents:

```cpp
///////////////////////////////////////////////////////////////////////
// File: network.cpp
// Description: Base class for neural network implementations.
// Author: Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

// Include automatically generated configuration file if running autoconf.
#ifdef HAVE_CONFIG_H
# include "config_auto.h"
#endif

#include "network.h"

#include <cstdlib>

// This base class needs to know about all its sub-classes because of the
// factory deserializing method: CreateFromFile.
#include <allheaders.h>
#include "convolve.h"
#include "fullyconnected.h"
#include "input.h"
#include "lstm.h"
#include "maxpool.h"
#include "parallel.h"
#include "reconfig.h"
#include "reversed.h"
#include "scrollview.h"
#include "series.h"
#include "statistc.h"
#include "tprintf.h"

namespace tesseract {

#ifndef GRAPHICS_DISABLED

// Min and max window sizes.
const int kMinWinSize = 500;
const int kMaxWinSize = 2000;
// Window frame sizes need adding on to make the content fit.
const int kXWinFrameSize = 30;
const int kYWinFrameSize = 80;

#endif // !GRAPHICS_DISABLED

// String names corresponding to the NetworkType enum.
// Keep in sync with NetworkType.
// Names used in Serialization to allow re-ordering/addition/deletion of
// layer types in NetworkType without invalidating existing network files.
static char const *const kTypeNames[NT_COUNT] = {
    "Invalid", "Input",
    "Convolve", "Maxpool",
    "Parallel", "Replicated",
    "ParBidiLSTM", "DepParUDLSTM",
    "Par2dLSTM", "Series",
    "Reconfig", "RTLReversed",
    "TTBReversed", "XYTranspose",
    "LSTM", "SummLSTM",
    "Logistic", "LinLogistic",
    "LinTanh", "Tanh",
    "Relu", "Linear",
    "Softmax", "SoftmaxNoCTC",
    "LSTMSoftmax", "LSTMBinarySoftmax",
    "TensorFlow",
};

Network::Network()
    : type_(NT_NONE)
    , training_(TS_ENABLED)
    , needs_to_backprop_(true)
    , network_flags_(0)
    , ni_(0)
    , no_(0)
    , num_weights_(0)
    , forward_win_(nullptr)
    , backward_win_(nullptr)
    , randomizer_(nullptr) {}
Network::Network(NetworkType type, const std::string &name, int ni, int no)
    : type_(type)
    , training_(TS_ENABLED)
    , needs_to_backprop_(true)
    , network_flags_(0)
    , ni_(ni)
    , no_(no)
    , num_weights_(0)
    , name_(name)
    , forward_win_(nullptr)
    , backward_win_(nullptr)
    , randomizer_(nullptr) {}

// Suspends/Enables/Permanently disables training by setting the training_
// flag. Serialize and DeSerialize only operate on the run-time data if state
// is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
// temporarily disable layers in state TS_ENABLED, allowing a trainer to
// serialize as if it were a recognizer.
// TS_RE_ENABLE will re-enable layers that were previously in any disabled
// state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
// TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
// recognizer can be converted back to a trainer.
void Network::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) {
      training_ = TS_ENABLED;
    }
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) {
      training_ = state;
    }
  } else {
    training_ = state;
  }
}

// Sets flags that control the action of the network. See NetworkFlags enum
// for bit values.
void Network::SetNetworkFlags(uint32_t flags) {
  network_flags_ = flags;
}

// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int Network::InitWeights([[maybe_unused]] float range, TRand *randomizer) {
  randomizer_ = randomizer;
  return 0;
}

// Provides a pointer to a TRand for any networks that care to use it.
// Note that randomizer is a borrowed pointer that should outlive the network
// and should not be deleted by any of the networks.
void Network::SetRandomizer(TRand *randomizer) {
  randomizer_ = randomizer;
}

// Sets needs_to_backprop_ to needs_backprop and returns true if
// needs_backprop || any weights in this network so the next layer forward
// can be told to produce backprop for this layer if needed.
bool Network::SetupNeedsBackprop(bool needs_backprop) {
  needs_to_backprop_ = needs_backprop;
  return needs_backprop || num_weights_ > 0;
}

// Writes to the given file. Returns false in case of error.
bool Network::Serialize(TFile *fp) const {
  int8_t data = NT_NONE;
  if (!fp->Serialize(&data)) {
    return false;
  }
  std::string type_name = kTypeNames[type_];
  if (!fp->Serialize(type_name)) {
    return false;
  }
  data = training_;
  if (!fp->Serialize(&data)) {
    return false;
  }
  data = needs_to_backprop_;
  if (!fp->Serialize(&data)) {
    return false;
  }
  if (!fp->Serialize(&network_flags_)) {
    return false;
  }
  if (!fp->Serialize(&ni_)) {
    return false;
  }
  if (!fp->Serialize(&no_)) {
    return false;
  }
  if (!fp->Serialize(&num_weights_)) {
    return false;
  }
  uint32_t length = name_.length();
  if (!fp->Serialize(&length)) {
    return false;
  }
  return fp->Serialize(name_.c_str(), length);
}

static NetworkType getNetworkType(TFile *fp) {
  int8_t data;
  if (!fp->DeSerialize(&data)) {
    return NT_NONE;
  }
  if (data == NT_NONE) {
    std::string type_name;
    if (!fp->DeSerialize(type_name)) {
      return NT_NONE;
    }
    for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
    }
    if (data == NT_COUNT) {
      tprintf("Invalid network layer type:%s\n", type_name.c_str());
      return NT_NONE;
    }
  }
  return static_cast<NetworkType>(data);
}

// Reads from the given file. Returns nullptr in case of error.
// Determines the type of the serialized class and calls its DeSerialize
// on a new object of the appropriate type, which is returned.
Network *Network::CreateFromFile(TFile *fp) {
  NetworkType type;        // Type of the derived network class.
  TrainingState training;  // Are we currently training?
  bool needs_to_backprop;  // This network needs to output back_deltas.
  int32_t network_flags;   // Behavior control flags in NetworkFlags.
  int32_t ni;              // Number of input values.
  int32_t no;              // Number of output values.
  int32_t num_weights;     // Number of weights in this and sub-network.
  std::string name;        // A unique name for this layer.
  int8_t data;
  Network *network = nullptr;
  type = getNetworkType(fp);
  if (!fp->DeSerialize(&data)) {
    return nullptr;
  }
  training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
  if (!fp->DeSerialize(&data)) {
    return nullptr;
  }
  needs_to_backprop = data != 0;
  if (!fp->DeSerialize(&network_flags)) {
    return nullptr;
  }
  if (!fp->DeSerialize(&ni)) {
    return nullptr;
  }
  if (!fp->DeSerialize(&no)) {
    return nullptr;
  }
  if (!fp->DeSerialize(&num_weights)) {
    return nullptr;
  }
  if (!fp->DeSerialize(name)) {
    return nullptr;
  }

  switch (type) {
    case NT_CONVOLVE:
      network = new Convolve(name, ni, 0, 0);
      break;
    case NT_INPUT:
      network = new Input(name, ni, no);
      break;
    case NT_LSTM:
    case NT_LSTM_SOFTMAX:
    case NT_LSTM_SOFTMAX_ENCODED:
    case NT_LSTM_SUMMARY:
      network = new LSTM(name, ni, no, no, false, type);
      break;
    case NT_MAXPOOL:
      network = new Maxpool(name, ni, 0, 0);
      break;
    // All variants of Parallel.
    case NT_PARALLEL:
    case NT_REPLICATED:
    case NT_PAR_RL_LSTM:
    case NT_PAR_UD_LSTM:
    case NT_PAR_2D_LSTM:
      network = new Parallel(name, type);
      break;
    case NT_RECONFIG:
      network = new Reconfig(name, ni, 0, 0);
      break;
    // All variants of reversed.
    case NT_XREVERSED:
    case NT_YREVERSED:
    case NT_XYTRANSPOSE:
      network = new Reversed(name, type);
      break;
    case NT_SERIES:
      network = new Series(name);
      break;
    case NT_TENSORFLOW:
      tprintf("Unsupported TensorFlow model\n");
      break;
    // All variants of FullyConnected.
    case NT_SOFTMAX:
    case NT_SOFTMAX_NO_CTC:
    case NT_RELU:
    case NT_TANH:
    case NT_LINEAR:
    case NT_LOGISTIC:
    case NT_POSCLIP:
    case NT_SYMCLIP:
      network = new FullyConnected(name, ni, no, type);
      break;
    default:
      break;
  }
  if (network) {
    network->training_ = training;
    network->needs_to_backprop_ = needs_to_backprop;
    network->network_flags_ = network_flags;
    network->num_weights_ = num_weights;
    if (!network->DeSerialize(fp)) {
      delete network;
      network = nullptr;
    }
  }
  return network;
}

// Returns a random number in [-range, range].
TFloat Network::Random(TFloat range) {
  ASSERT_HOST(randomizer_ != nullptr);
  return randomizer_->SignedRand(range);
}

#ifndef GRAPHICS_DISABLED

// === Debug image display methods. ===
// Displays the image of the matrix to the forward window.
void Network::DisplayForward(const NetworkIO &matrix) {
  Image image = matrix.ToPix();
  ClearWindow(false, name_.c_str(), pixGetWidth(image), pixGetHeight(image), &forward_win_);
  DisplayImage(image, forward_win_);
  forward_win_->Update();
}

// Displays the image of the matrix to the backward window.
void Network::DisplayBackward(const NetworkIO &matrix) {
  Image image = matrix.ToPix();
  std::string window_name = name_ + "-back";
  ClearWindow(false, window_name.c_str(), pixGetWidth(image), pixGetHeight(image), &backward_win_);
  DisplayImage(image, backward_win_);
  backward_win_->Update();
}

// Creates the window if needed, otherwise clears it.
void Network::ClearWindow(bool tess_coords, const char *window_name, int width, int height,
                          ScrollView **window) {
  if (*window == nullptr) {
    int min_size = std::min(width, height);
    if (min_size < kMinWinSize) {
      if (min_size < 1) {
        min_size = 1;
      }
      width = width * kMinWinSize / min_size;
      height = height * kMinWinSize / min_size;
    }
    width += kXWinFrameSize;
    height += kYWinFrameSize;
    if (width > kMaxWinSize) {
      width = kMaxWinSize;
    }
    if (height > kMaxWinSize) {
      height = kMaxWinSize;
    }
    *window = new ScrollView(window_name, 80, 100, width, height, width, height, tess_coords);
    tprintf("Created window %s of size %d, %d\n", window_name, width, height);
  } else {
    (*window)->Clear();
  }
}

// Displays the pix in the given window. and returns the height of the pix.
// The pix is pixDestroyed.
int Network::DisplayImage(Image pix, ScrollView *window) {
  int height = pixGetHeight(pix);
  window->Draw(pix, 0, 0);
  pix.destroy();
  return height;
}
#endif // !GRAPHICS_DISABLED

} // namespace tesseract.
```
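The two serialization entry points defined above are the virtual Serialize() and the static factory CreateFromFile(); the string-name lookup in getNetworkType() is what allows NetworkType values to be reordered or extended without invalidating previously written model files, as the comment above kTypeNames notes. A minimal round-trip sketch (not part of the file above, and assuming a tesseract::TFile that has already been opened on the model stream via serialis.h):

```cpp
// Sketch only, not part of the file above. Assumes a tesseract::TFile that
// has already been opened on a model stream (serialis.h), and a fully
// constructed Network subtree to write.
#include "network.h"  // tesseract::Network, Network::CreateFromFile
#include "serialis.h" // tesseract::TFile

// Writing: the base-class Serialize() emits the NT_NONE sentinel byte, the
// kTypeNames[] string for the layer type, and the common header fields;
// derived classes append their own data after calling it.
bool SaveNetwork(const tesseract::Network &net, tesseract::TFile *fp) {
  return net.Serialize(fp);
}

// Reading: CreateFromFile() decodes the type tag via getNetworkType(),
// instantiates the matching subclass (Input, LSTM, Series, ...), restores the
// common header fields, and lets the subclass DeSerialize() the rest.
tesseract::Network *LoadNetwork(tesseract::TFile *fp) {
  return tesseract::Network::CreateFromFile(fp); // nullptr on error
}
```

As in the listing, the factory returns nullptr on any read error and deletes the partially built object; otherwise the caller owns the returned pointer.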
