diff mupdf-source/thirdparty/tesseract/src/lstm/network.cpp @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: no version number in the expanded directory now.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mupdf-source/thirdparty/tesseract/src/lstm/network.cpp	Mon Sep 15 11:43:07 2025 +0200
@@ -0,0 +1,379 @@
+///////////////////////////////////////////////////////////////////////
+// File:        network.cpp
+// Description: Base class for neural network implementations.
+// Author:      Ray Smith
+//
+// (C) Copyright 2013, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+// Include automatically generated configuration file if running autoconf.
+#ifdef HAVE_CONFIG_H
+#  include "config_auto.h"
+#endif
+
+#include "network.h"
+
+#include <cstdlib>
+
+// This base class needs to know about all its sub-classes because of the
+// factory deserializing method: CreateFromFile.
+#include <allheaders.h>
+#include "convolve.h"
+#include "fullyconnected.h"
+#include "input.h"
+#include "lstm.h"
+#include "maxpool.h"
+#include "parallel.h"
+#include "reconfig.h"
+#include "reversed.h"
+#include "scrollview.h"
+#include "series.h"
+#include "statistc.h"
+#include "tprintf.h"
+
+namespace tesseract {
+
+#ifndef GRAPHICS_DISABLED
+
+// Min and max window sizes (pixels) used by ClearWindow when sizing the
+// debug display windows.
+const int kMinWinSize = 500;
+const int kMaxWinSize = 2000;
+// Window frame sizes need adding on to make the content fit.
+const int kXWinFrameSize = 30;
+const int kYWinFrameSize = 80;
+
+#endif // !GRAPHICS_DISABLED
+
+// String names corresponding to the NetworkType enum.
+// Keep in sync with NetworkType.
+// Names used in Serialization to allow re-ordering/addition/deletion of
+// layer types in NetworkType without invalidating existing network files.
+// (Serialize writes the name string; getNetworkType maps it back to the
+// enum by linear search, so each name must be unique.)
+static char const *const kTypeNames[NT_COUNT] = {
+    "Invalid",     "Input",
+    "Convolve",    "Maxpool",
+    "Parallel",    "Replicated",
+    "ParBidiLSTM", "DepParUDLSTM",
+    "Par2dLSTM",   "Series",
+    "Reconfig",    "RTLReversed",
+    "TTBReversed", "XYTranspose",
+    "LSTM",        "SummLSTM",
+    "Logistic",    "LinLogistic",
+    "LinTanh",     "Tanh",
+    "Relu",        "Linear",
+    "Softmax",     "SoftmaxNoCTC",
+    "LSTMSoftmax", "LSTMBinarySoftmax",
+    "TensorFlow",
+};
+
+// Default constructor: builds an invalid (NT_NONE) placeholder network.
+// Scalar fields are zeroed, training is enabled and backprop is requested;
+// CreateFromFile/DeSerialize overwrite this state when loading.
+Network::Network()
+    : type_(NT_NONE)
+    , training_(TS_ENABLED)
+    , needs_to_backprop_(true)
+    , network_flags_(0)
+    , ni_(0)
+    , no_(0)
+    , num_weights_(0)
+    , forward_win_(nullptr)
+    , backward_win_(nullptr)
+    , randomizer_(nullptr) {}
+// Constructs a network of the given type with the given name, with ni input
+// values and no output values.  No weights are created here (num_weights_
+// stays 0); see InitWeights.
+Network::Network(NetworkType type, const std::string &name, int ni, int no)
+    : type_(type)
+    , training_(TS_ENABLED)
+    , needs_to_backprop_(true)
+    , network_flags_(0)
+    , ni_(ni)
+    , no_(no)
+    , num_weights_(0)
+    , name_(name)
+    , forward_win_(nullptr)
+    , backward_win_(nullptr)
+    , randomizer_(nullptr) {}
+
+// Suspends/Enables/Permanently disables training by setting the training_
+// flag. Serialize and DeSerialize only operate on the run-time data if state
+// is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
+// temporarily disable layers in state TS_ENABLED, allowing a trainer to
+// serialize as if it were a recognizer.
+// TS_RE_ENABLE will re-enable layers that were previously in any disabled
+// state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
+// TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
+// recognizer can be converted back to a trainer.
+void Network::SetEnableTraining(TrainingState state) {
+  switch (state) {
+    case TS_RE_ENABLE:
+      // Re-enabling only applies to temporarily disabled layers; a layer in
+      // TS_DISABLED is left alone here (subclasses handle reinitialization).
+      if (training_ == TS_TEMP_DISABLE) {
+        training_ = TS_ENABLED;
+      }
+      break;
+    case TS_TEMP_DISABLE:
+      // Only an enabled layer can be temporarily disabled.
+      if (training_ == TS_ENABLED) {
+        training_ = TS_TEMP_DISABLE;
+      }
+      break;
+    default:
+      // TS_ENABLED / TS_DISABLED are applied unconditionally.
+      training_ = state;
+      break;
+  }
+}
+
+// Sets flags that control the action of the network. See NetworkFlags enum
+// for bit values.
+void Network::SetNetworkFlags(uint32_t flags) {
+  // Replaces (does not OR-in) any previously set flags.
+  network_flags_ = flags;
+}
+
+// Sets up the network for training. Initializes weights using weights of
+// scale `range` picked according to the random number generator `randomizer`.
+// The base class has no weights of its own: it only records the randomizer
+// for subclasses and returns 0, as no weights are initialized here.
+int Network::InitWeights([[maybe_unused]] float range, TRand *randomizer) {
+  randomizer_ = randomizer;
+  return 0;
+}
+
+// Provides a pointer to a TRand for any networks that care to use it.
+// Note that randomizer is a borrowed pointer that should outlive the network
+// and should not be deleted by any of the networks.
+void Network::SetRandomizer(TRand *randomizer) {
+  // Overwrites any randomizer previously recorded by InitWeights.
+  randomizer_ = randomizer;
+}
+
+// Sets needs_to_backprop_ to needs_backprop and returns true if
+// needs_backprop || any weights in this network so the next layer forward
+// can be told to produce backprop for this layer if needed.
+bool Network::SetupNeedsBackprop(bool needs_backprop) {
+  // Record the request, then report whether this layer will produce
+  // backprop output: either because the caller asked for it, or because
+  // this (sub)network has trainable weights of its own.
+  needs_to_backprop_ = needs_backprop;
+  if (needs_backprop) {
+    return true;
+  }
+  return num_weights_ > 0;
+}
+
+// Writes this network's header to the given file. Returns false in case of
+// a write error. The field order here must match CreateFromFile exactly.
+bool Network::Serialize(TFile *fp) const {
+  // A leading NT_NONE byte marks the modern format, in which the layer type
+  // is stored as its kTypeNames string rather than as a raw enum value.
+  int8_t marker = NT_NONE;
+  if (!fp->Serialize(&marker)) {
+    return false;
+  }
+  const std::string type_name = kTypeNames[type_];
+  if (!fp->Serialize(type_name)) {
+    return false;
+  }
+  // Narrow the enum/bool fields to single bytes, as the reader expects.
+  int8_t train_state = training_;
+  int8_t backprop = needs_to_backprop_;
+  uint32_t name_len = name_.length();
+  // Short-circuit evaluation preserves the original write order and stops
+  // at the first failure.
+  return fp->Serialize(&train_state) && fp->Serialize(&backprop) &&
+         fp->Serialize(&network_flags_) && fp->Serialize(&ni_) &&
+         fp->Serialize(&no_) && fp->Serialize(&num_weights_) &&
+         fp->Serialize(&name_len) && fp->Serialize(name_.c_str(), name_len);
+}
+
+// Reads the layer type from fp. Handles both the legacy format (raw enum
+// byte) and the modern format (NT_NONE marker followed by a kTypeNames
+// string). Returns NT_NONE on read error or unknown type name.
+static NetworkType getNetworkType(TFile *fp) {
+  int8_t raw_type;
+  if (!fp->DeSerialize(&raw_type)) {
+    return NT_NONE;
+  }
+  if (raw_type != NT_NONE) {
+    // Legacy format: the enum value itself was stored.
+    return static_cast<NetworkType>(raw_type);
+  }
+  // Modern format: look the serialized name up in kTypeNames.
+  std::string type_name;
+  if (!fp->DeSerialize(type_name)) {
+    return NT_NONE;
+  }
+  for (int i = 0; i < NT_COUNT; ++i) {
+    if (type_name == kTypeNames[i]) {
+      return static_cast<NetworkType>(i);
+    }
+  }
+  tprintf("Invalid network layer type:%s\n", type_name.c_str());
+  return NT_NONE;
+}
+
+// Reads from the given file. Returns nullptr in case of error.
+// Determines the type of the serialized class and calls its DeSerialize
+// on a new object of the appropriate type, which is returned.
+Network *Network::CreateFromFile(TFile *fp) {
+  NetworkType type;       // Type of the derived network class.
+  TrainingState training; // Are we currently training?
+  bool needs_to_backprop; // This network needs to output back_deltas.
+  int32_t network_flags;  // Behavior control flags in NetworkFlags.
+  int32_t ni;             // Number of input values.
+  int32_t no;             // Number of output values.
+  int32_t num_weights;    // Number of weights in this and sub-network.
+  std::string name;       // A unique name for this layer.
+  int8_t data;
+  Network *network = nullptr;
+  // Header fields are read in exactly the order Serialize() wrote them.
+  type = getNetworkType(fp);
+  if (!fp->DeSerialize(&data)) {
+    return nullptr;
+  }
+  // Anything other than TS_ENABLED was serialized as a recognizer, so
+  // collapse all disabled variants to TS_DISABLED.
+  training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
+  if (!fp->DeSerialize(&data)) {
+    return nullptr;
+  }
+  needs_to_backprop = data != 0;
+  if (!fp->DeSerialize(&network_flags)) {
+    return nullptr;
+  }
+  if (!fp->DeSerialize(&ni)) {
+    return nullptr;
+  }
+  if (!fp->DeSerialize(&no)) {
+    return nullptr;
+  }
+  if (!fp->DeSerialize(&num_weights)) {
+    return nullptr;
+  }
+  if (!fp->DeSerialize(name)) {
+    return nullptr;
+  }
+
+  // Dispatch on the layer type to construct the right subclass.  The
+  // constructors only get placeholder geometry where the subclass's own
+  // DeSerialize (called below) restores the real configuration.
+  switch (type) {
+    case NT_CONVOLVE:
+      network = new Convolve(name, ni, 0, 0);
+      break;
+    case NT_INPUT:
+      network = new Input(name, ni, no);
+      break;
+    case NT_LSTM:
+    case NT_LSTM_SOFTMAX:
+    case NT_LSTM_SOFTMAX_ENCODED:
+    case NT_LSTM_SUMMARY:
+      network = new LSTM(name, ni, no, no, false, type);
+      break;
+    case NT_MAXPOOL:
+      network = new Maxpool(name, ni, 0, 0);
+      break;
+    // All variants of Parallel.
+    case NT_PARALLEL:
+    case NT_REPLICATED:
+    case NT_PAR_RL_LSTM:
+    case NT_PAR_UD_LSTM:
+    case NT_PAR_2D_LSTM:
+      network = new Parallel(name, type);
+      break;
+    case NT_RECONFIG:
+      network = new Reconfig(name, ni, 0, 0);
+      break;
+    // All variants of reversed.
+    case NT_XREVERSED:
+    case NT_YREVERSED:
+    case NT_XYTRANSPOSE:
+      network = new Reversed(name, type);
+      break;
+    case NT_SERIES:
+      network = new Series(name);
+      break;
+    case NT_TENSORFLOW:
+      // network stays nullptr, so nullptr is returned below.
+      tprintf("Unsupported TensorFlow model\n");
+      break;
+    // All variants of FullyConnected.
+    case NT_SOFTMAX:
+    case NT_SOFTMAX_NO_CTC:
+    case NT_RELU:
+    case NT_TANH:
+    case NT_LINEAR:
+    case NT_LOGISTIC:
+    case NT_POSCLIP:
+    case NT_SYMCLIP:
+      network = new FullyConnected(name, ni, no, type);
+      break;
+    default:
+      // Unknown/invalid type (including NT_NONE from a read failure above).
+      break;
+  }
+  if (network) {
+    // Restore the header fields, then let the subclass read its own weights
+    // and sub-networks.  On failure the partially built network is freed.
+    network->training_ = training;
+    network->needs_to_backprop_ = needs_to_backprop;
+    network->network_flags_ = network_flags;
+    network->num_weights_ = num_weights;
+    if (!network->DeSerialize(fp)) {
+      delete network;
+      network = nullptr;
+    }
+  }
+  return network;
+}
+
+// Returns a random number in [-range, range], drawn from the TRand recorded
+// by SetRandomizer/InitWeights.  It is a fatal error (ASSERT_HOST) to call
+// this before a randomizer has been provided.
+TFloat Network::Random(TFloat range) {
+  ASSERT_HOST(randomizer_ != nullptr);
+  return randomizer_->SignedRand(range);
+}
+
+#ifndef GRAPHICS_DISABLED
+
+// === Debug image display methods. ===
+// Displays the image of the matrix to the forward window.
+void Network::DisplayForward(const NetworkIO &matrix) {
+  // The Image returned by ToPix is destroyed by DisplayImage after drawing.
+  Image image = matrix.ToPix();
+  ClearWindow(false, name_.c_str(), pixGetWidth(image), pixGetHeight(image), &forward_win_);
+  DisplayImage(image, forward_win_);
+  forward_win_->Update();
+}
+
+// Displays the image of the matrix to the backward window.
+void Network::DisplayBackward(const NetworkIO &matrix) {
+  // The Image returned by ToPix is destroyed by DisplayImage after drawing.
+  Image image = matrix.ToPix();
+  // The backward window is named after the layer with a "-back" suffix.
+  std::string window_name = name_ + "-back";
+  ClearWindow(false, window_name.c_str(), pixGetWidth(image), pixGetHeight(image), &backward_win_);
+  DisplayImage(image, backward_win_);
+  backward_win_->Update();
+}
+
+// Creates the window if needed, otherwise clears it.
+void Network::ClearWindow(bool tess_coords, const char *window_name, int width, int height,
+                          ScrollView **window) {
+  // An existing window is simply cleared and reused.
+  if (*window != nullptr) {
+    (*window)->Clear();
+    return;
+  }
+  // Scale small content up so the smaller dimension reaches kMinWinSize,
+  // guarding against zero/negative sizes.
+  int min_dim = std::min(width, height);
+  if (min_dim < kMinWinSize) {
+    if (min_dim < 1) {
+      min_dim = 1;
+    }
+    width = width * kMinWinSize / min_dim;
+    height = height * kMinWinSize / min_dim;
+  }
+  // Add room for the window frame, then cap the total size.
+  width += kXWinFrameSize;
+  height += kYWinFrameSize;
+  if (width > kMaxWinSize) {
+    width = kMaxWinSize;
+  }
+  if (height > kMaxWinSize) {
+    height = kMaxWinSize;
+  }
+  *window = new ScrollView(window_name, 80, 100, width, height, width, height, tess_coords);
+  tprintf("Created window %s of size %d, %d\n", window_name, width, height);
+}
+
+// Displays the pix in the given window and returns the height of the pix.
+// The pix is destroyed (pixDestroy) by this call, so the caller must not
+// use it afterwards.
+int Network::DisplayImage(Image pix, ScrollView *window) {
+  int height = pixGetHeight(pix);
+  window->Draw(pix, 0, 0);
+  pix.destroy();
+  return height;
+}
+#endif // !GRAPHICS_DISABLED
+
+} // namespace tesseract.