comparison mupdf-source/thirdparty/tesseract/src/training/common/ctc.h @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: no version number in the expanded directory now.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
parents
children
comparison
equal deleted inserted replaced
1:1d09e1dec1d9 2:b50eed0cc0ef
1 ///////////////////////////////////////////////////////////////////////
2 // File: ctc.h
3 // Description: Slightly improved standard CTC to compute the targets.
4 // Author: Ray Smith
5 // Created: Wed Jul 13 15:17:06 PDT 2016
6 //
7 // (C) Copyright 2016, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////
18
19 #ifndef TESSERACT_LSTM_CTC_H_
20 #define TESSERACT_LSTM_CTC_H_
21
#include "export.h"
#include "network.h"
#include "networkio.h"
#include "scrollview.h"

#include <cmath>  // std::exp
#include <vector> // std::vector
26
27 namespace tesseract {
28
29 // Class to encapsulate CTC and simple target generation.
30 class TESS_COMMON_TRAINING_API CTC {
31 public:
32 // Normalizes the probabilities such that no target has a prob below min_prob,
33 // and, provided that the initial total is at least min_total_prob, then all
34 // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
35 // probability is thus 1 - (num_classes-1)*min_prob.
36 static void NormalizeProbs(NetworkIO *probs) {
37 NormalizeProbs(probs->mutable_float_array());
38 }
39
40 // Builds a target using CTC. Slightly improved as follows:
41 // Includes normalizations and clipping for stability.
42 // labels should be pre-padded with nulls wherever desired, but they don't
43 // have to be between all labels. Allows for multi-label codes with no
44 // nulls between.
45 // labels can be longer than the time sequence, but the total number of
46 // essential labels (non-null plus nulls between equal labels) must not exceed
47 // the number of timesteps in outputs.
48 // outputs is the output of the network, and should have already been
49 // normalized with NormalizeProbs.
50 // On return targets is filled with the computed targets.
51 // Returns false if there is insufficient time for the labels.
52 static bool ComputeCTCTargets(const std::vector<int> &truth_labels, int null_char,
53 const GENERIC_2D_ARRAY<float> &outputs, NetworkIO *targets);
54
55 private:
56 // Constructor is private as the instance only holds information specific to
57 // the current labels, outputs etc, and is built by the static function.
58 CTC(const std::vector<int> &labels, int null_char, const GENERIC_2D_ARRAY<float> &outputs);
59
60 // Computes vectors of min and max label index for each timestep, based on
61 // whether skippability of nulls makes it possible to complete a valid path.
62 bool ComputeLabelLimits();
63 // Computes targets based purely on the labels by spreading the labels evenly
64 // over the available timesteps.
65 void ComputeSimpleTargets(GENERIC_2D_ARRAY<float> *targets) const;
66 // Computes mean positions and half widths of the simple targets by spreading
67 // the labels even over the available timesteps.
68 void ComputeWidthsAndMeans(std::vector<float> *half_widths, std::vector<int> *means) const;
69 // Calculates and returns a suitable fraction of the simple targets to add
70 // to the network outputs.
71 float CalculateBiasFraction();
72 // Runs the forward CTC pass, filling in log_probs.
73 void Forward(GENERIC_2D_ARRAY<double> *log_probs) const;
74 // Runs the backward CTC pass, filling in log_probs.
75 void Backward(GENERIC_2D_ARRAY<double> *log_probs) const;
76 // Normalizes and brings probs out of log space with a softmax over time.
77 void NormalizeSequence(GENERIC_2D_ARRAY<double> *probs) const;
78 // For each timestep computes the max prob for each class over all
79 // instances of the class in the labels_, and sets the targets to
80 // the max observed prob.
81 void LabelsToClasses(const GENERIC_2D_ARRAY<double> &probs, NetworkIO *targets) const;
82 // Normalizes the probabilities such that no target has a prob below min_prob,
83 // and, provided that the initial total is at least min_total_prob, then all
84 // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
85 // probability is thus 1 - (num_classes-1)*min_prob.
86 static void NormalizeProbs(GENERIC_2D_ARRAY<float> *probs);
87 // Returns true if the label at index is a needed null.
88 bool NeededNull(int index) const;
89 // Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/
90 // underflow.
91 static double ClippedExp(double x) {
92 if (x < -kMaxExpArg_) {
93 return exp(-kMaxExpArg_);
94 }
95 if (x > kMaxExpArg_) {
96 return exp(kMaxExpArg_);
97 }
98 return exp(x);
99 }
100
101 // Minimum probability limit for softmax input to ctc_loss.
102 static const float kMinProb_;
103 // Maximum absolute argument to exp().
104 static const double kMaxExpArg_;
105 // Minimum probability for total prob in time normalization.
106 static const double kMinTotalTimeProb_;
107 // Minimum probability for total prob in final normalization.
108 static const double kMinTotalFinalProb_;
109
110 // The truth label indices that are to be matched to outputs_.
111 const std::vector<int> &labels_;
112 // The network outputs.
113 GENERIC_2D_ARRAY<float> outputs_;
114 // The null or "blank" label.
115 int null_char_;
116 // Number of timesteps in outputs_.
117 int num_timesteps_;
118 // Number of classes in outputs_.
119 int num_classes_;
120 // Number of labels in labels_.
121 int num_labels_;
122 // Min and max valid label indices for each timestep.
123 std::vector<int> min_labels_;
124 std::vector<int> max_labels_;
125 };
126
127 } // namespace tesseract
128
129 #endif // TESSERACT_LSTM_CTC_H_