comparison mupdf-source/thirdparty/tesseract/src/lstm/fullyconnected.cpp @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: the expanded directory no longer includes a version number.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
comparing 1:1d09e1dec1d9 with 2:b50eed0cc0ef
///////////////////////////////////////////////////////////////////////
// File:        fullyconnected.cpp
// Description: Simple feed-forward layer with various non-linearities.
// Author:      Ray Smith
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include "fullyconnected.h"

#ifdef _OPENMP
#  include <omp.h>
#endif
#include <cstdio>
#include <cstdlib>

#include "functions.h"
#include "networkscratch.h"

// Number of threads to use for parallel calculation of Forward and Backward.
#ifdef _OPENMP
const int kNumThreads = 4;
#else
const int kNumThreads = 1;
#endif

namespace tesseract {

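// A fully-connected layer computes, for each timestep t,
//   output_t = f(W * [input_t; 1]),
// where W is the no_ x (ni_ + 1) weight matrix and f is the non-linearity
// selected by type_ (see ForwardTimeStep).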
FullyConnected::FullyConnected(const std::string &name, int ni, int no, NetworkType type)
    : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {}

// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape FullyConnected::OutputShape(const StaticShape &input_shape) const {
  LossType loss_type = LT_NONE;
  if (type_ == NT_SOFTMAX) {
    loss_type = LT_CTC;
  } else if (type_ == NT_SOFTMAX_NO_CTC) {
    loss_type = LT_SOFTMAX;
  } else if (type_ == NT_LOGISTIC) {
    loss_type = LT_LOGISTIC;
  }
  StaticShape result(input_shape);
  result.set_depth(no_);
  result.set_loss_type(loss_type);
  return result;
}

// Suspends/Enables training by setting the training_ flag.
void FullyConnected::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) {
      training_ = TS_ENABLED;
    }
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) {
      training_ = state;
    }
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      weights_.InitBackward();
    }
    training_ = state;
  }
}

// Sets up the network for training. Initializes weights to random values of
// scale `range`, picked according to the random number generator `randomizer`.
int FullyConnected::InitWeights(float range, TRand *randomizer) {
  Network::SetRandomizer(randomizer);
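  // ni_ + 1: each input vector is treated as if it had an extra trailing 1,
  // so the last column of the weight matrix holds the bias for each output.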
  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADAM), range, randomizer);
  return num_weights_;
}

// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int FullyConnected::RemapOutputs(int old_no, const std::vector<int> &code_map) {
  if (type_ == NT_SOFTMAX && no_ == old_no) {
    num_weights_ = weights_.RemapOutputs(code_map);
    no_ = code_map.size();
  }
  return num_weights_;
}

// Converts a float network to an int network.
void FullyConnected::ConvertToInt() {
  weights_.ConvertToInt();
}

// Provides debug output on the weights.
void FullyConnected::DebugWeights() {
  weights_.Debug2D(name_.c_str());
}

// Writes to the given file. Returns false in case of error.
bool FullyConnected::Serialize(TFile *fp) const {
  if (!Network::Serialize(fp)) {
    return false;
  }
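  // In training mode, WeightMatrix also serializes its update state along
  // with the weights, so that training can be resumed.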
  if (!weights_.Serialize(IsTraining(), fp)) {
    return false;
  }
  return true;
}

// Reads from the given file. Returns false in case of error.
bool FullyConnected::DeSerialize(TFile *fp) {
  return weights_.DeSerialize(IsTraining(), fp);
}

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void FullyConnected::Forward(bool debug, const NetworkIO &input,
                             const TransposedArray *input_transpose, NetworkScratch *scratch,
                             NetworkIO *output) {
  int width = input.Width();
  if (type_ == NT_SOFTMAX) {
    output->ResizeFloat(input, no_);
  } else {
    output->Resize(input, no_);
  }
  SetupForward(input, input_transpose);
  std::vector<NetworkScratch::FloatVec> temp_lines(kNumThreads);
  std::vector<NetworkScratch::FloatVec> curr_input(kNumThreads);
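  // Round the output size up so vectorized kernels may write a whole SIMD
  // vector past no_ without overrunning the scratch buffer.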
  int ro = no_;
  if (IntSimdMatrix::intSimdMatrix) {
    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
  }
  for (int i = 0; i < kNumThreads; ++i) {
    temp_lines[i].Init(ro, scratch);
    curr_input[i].Init(ni_, scratch);
  }
#ifdef _OPENMP
#  pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = 0;
#endif
    TFloat *temp_line = temp_lines[thread_id];
    if (input.int_mode()) {
      ForwardTimeStep(input.i(t), t, temp_line);
    } else {
      input.ReadTimeStep(t, curr_input[thread_id]);
      ForwardTimeStep(curr_input[thread_id], t, temp_line);
    }
    output->WriteTimeStep(t, temp_line);
    if (IsTraining() && type_ != NT_SOFTMAX) {
      acts_.CopyTimeStepFrom(t, *output, t);
    }
  }
  // Zero all the elements that lie in the padding around images; the padding
  // allows multiple different-sized images to exist in a single array.
  // acts_ is only used if this is not a softmax op.
  if (IsTraining() && type_ != NT_SOFTMAX) {
    acts_.ZeroInvalidElements();
  }
  output->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
  tprintf("F Output:%s\n", name_.c_str());
  output->Print(10);
#endif
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayForward(*output);
  }
#endif
}

// Components of Forward so FullyConnected can be reused inside LSTM.
void FullyConnected::SetupForward(const NetworkIO &input, const TransposedArray *input_transpose) {
  // Softmax output is always float, so save the input type.
  int_mode_ = input.int_mode();
  if (IsTraining()) {
    acts_.Resize(input, no_);
    // source_t_ is a transposed copy of the input; it is not needed when the
    // caller already supplies input_transpose.
    external_source_ = input_transpose;
    if (external_source_ == nullptr) {
      source_t_.ResizeNoInit(ni_, input.Width());
    }
  }
}

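// Applies this layer's non-linearity to output_line in place. The timestep t
// is unused here; it is kept for symmetry with the overloads below.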
void FullyConnected::ForwardTimeStep(int t, TFloat *output_line) {
  if (type_ == NT_TANH) {
    FuncInplace<GFunc>(no_, output_line);
  } else if (type_ == NT_LOGISTIC) {
    FuncInplace<FFunc>(no_, output_line);
  } else if (type_ == NT_POSCLIP) {
    FuncInplace<ClipFFunc>(no_, output_line);
  } else if (type_ == NT_SYMCLIP) {
    FuncInplace<ClipGFunc>(no_, output_line);
  } else if (type_ == NT_RELU) {
    FuncInplace<Relu>(no_, output_line);
  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
    SoftmaxInPlace(no_, output_line);
  } else if (type_ != NT_LINEAR) {
    ASSERT_HOST("Invalid fully-connected type!" == nullptr);
  }
}

void FullyConnected::ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line) {
  // input is copied to source_ line-by-line for cache coherency.
  if (IsTraining() && external_source_ == nullptr) {
    source_t_.WriteStrided(t, d_input);
  }
  weights_.MatrixDotVector(d_input, output_line);
  ForwardTimeStep(t, output_line);
}

void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line) {
  // Int input is used in place; no transposed copy is kept, as int mode is
  // inference-only and no backward pass will need it.
  weights_.MatrixDotVector(i_input, output_line);
  ForwardTimeStep(t, output_line);
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
                              NetworkIO *back_deltas) {
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayBackward(fwd_deltas);
  }
#endif
  back_deltas->Resize(fwd_deltas, ni_);
  std::vector<NetworkScratch::FloatVec> errors(kNumThreads);
  for (int i = 0; i < kNumThreads; ++i) {
    errors[i].Init(no_, scratch);
  }
  std::vector<NetworkScratch::FloatVec> temp_backprops;
  if (needs_to_backprop_) {
    temp_backprops.resize(kNumThreads);
    for (int i = 0; i < kNumThreads; ++i) {
      temp_backprops[i].Init(ni_, scratch);
    }
  }
  int width = fwd_deltas.Width();
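  // errors_t collects each timestep's errors in transposed form so that
  // FinishBackward can compute the weight gradient as one matrix product.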
  NetworkScratch::GradientStore errors_t;
  errors_t.Init(no_, width, scratch);
#ifdef _OPENMP
#  pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    int thread_id = 0;
#endif
    TFloat *backprop = nullptr;
    if (needs_to_backprop_) {
      backprop = temp_backprops[thread_id];
    }
    TFloat *curr_errors = errors[thread_id];
    BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
    if (backprop != nullptr) {
      back_deltas->WriteTimeStep(t, backprop);
    }
  }
  FinishBackward(*errors_t.get());
  if (needs_to_backprop_) {
    back_deltas->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
    tprintf("F Backprop:%s\n", name_.c_str());
    back_deltas->Print(10);
#endif
    return true;
  }
  return false; // No point going further back.
}

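// Computes the error for a single timestep: multiplies the incoming deltas by
// the derivative of the non-linearity (evaluated via the saved activations in
// acts_), optionally back-projects the errors through the weights for the
// layer below, and stores the result in transposed form in errors_t.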
void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors,
                                      TransposedArray *errors_t, TFloat *backprop) {
  if (type_ == NT_TANH) {
    acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_LOGISTIC) {
    acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_POSCLIP) {
    acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_SYMCLIP) {
    acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_RELU) {
    acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC || type_ == NT_LINEAR) {
    fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
  } else {
    ASSERT_HOST("Invalid fully-connected type!" == nullptr);
  }
  // Generate backprop only if needed by the lower layer.
  if (backprop != nullptr) {
    weights_.VectorDotMatrix(curr_errors, backprop);
  }
  errors_t->WriteStrided(t, curr_errors);
}

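// Accumulates the weight gradients as the outer product of the transposed
// errors with the transposed input (source_t_, or the externally supplied
// transpose when one was provided in Forward).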
void FullyConnected::FinishBackward(const TransposedArray &errors_t) {
  if (external_source_ == nullptr) {
    weights_.SumOuterTransposed(errors_t, source_t_, true);
  } else {
    weights_.SumOuterTransposed(errors_t, *external_source_, true);
  }
}

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void FullyConnected::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void FullyConnected::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
  ASSERT_HOST(other.type() == type_);
  const auto *fc = static_cast<const FullyConnected *>(&other);
  weights_.CountAlternators(fc->weights_, same, changed);
}

} // namespace tesseract.
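
// Usage sketch: a minimal example of constructing this layer and running one
// forward pass, assuming a NetworkIO `input` already filled with
// ni-dimensional feature vectors (its setup calls are elided here because
// they depend on the caller).
//
//   tesseract::TRand randomizer;
//   tesseract::FullyConnected fc("fc1", /*ni=*/64, /*no=*/10,
//                                tesseract::NT_SOFTMAX);
//   fc.InitWeights(/*range=*/0.1f, &randomizer);  // random init for training
//   tesseract::NetworkScratch scratch;
//   tesseract::NetworkIO output;
//   fc.Forward(/*debug=*/false, input, /*input_transpose=*/nullptr,
//              &scratch, &output);                // output is width x no_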