comparison mupdf-source/thirdparty/tesseract/src/training/common/networkbuilder.cpp @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: no version number in the expanded directory now.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
parents
children
comparison
equal deleted inserted replaced
1:1d09e1dec1d9 2:b50eed0cc0ef
1 ///////////////////////////////////////////////////////////////////////
2 // File: networkbuilder.cpp
3 // Description: Class to parse the network description language and
4 // build a corresponding network.
5 // Author: Ray Smith
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////
18
19 #include "networkbuilder.h"
20
21 #include "convolve.h"
22 #include "fullyconnected.h"
23 #include "input.h"
24 #include "lstm.h"
25 #include "maxpool.h"
26 #include "network.h"
27 #include "parallel.h"
28 #include "reconfig.h"
29 #include "reversed.h"
30 #include "series.h"
31 #include "unicharset.h"
32
33 namespace tesseract {
34
35 // Builds a network with a network_spec in the network description
36 // language, to recognize a character set of num_outputs size.
37 // If append_index is non-negative, then *network must be non-null and the
38 // given network_spec will be appended to *network AFTER append_index, with
39 // the top of the input *network discarded.
40 // Note that network_spec is call by value to allow a non-const char* pointer
41 // into the string for BuildFromString.
42 // net_flags control network behavior according to the NetworkFlags enum.
43 // The resulting network is returned via **network.
44 // Returns false if something failed.
45 bool NetworkBuilder::InitNetwork(int num_outputs, const char *network_spec, int append_index,
46 int net_flags, float weight_range, TRand *randomizer,
47 Network **network) {
48 NetworkBuilder builder(num_outputs);
49 Series *bottom_series = nullptr;
50 StaticShape input_shape;
51 if (append_index >= 0) {
52 // Split the current network after the given append_index.
53 ASSERT_HOST(*network != nullptr && (*network)->type() == NT_SERIES);
54 auto *series = static_cast<Series *>(*network);
55 Series *top_series = nullptr;
56 series->SplitAt(append_index, &bottom_series, &top_series);
57 if (bottom_series == nullptr || top_series == nullptr) {
58 tprintf("Yikes! Splitting current network failed!!\n");
59 return false;
60 }
61 input_shape = bottom_series->OutputShape(input_shape);
62 delete top_series;
63 }
64 *network = builder.BuildFromString(input_shape, &network_spec);
65 if (*network == nullptr) {
66 return false;
67 }
68 (*network)->SetNetworkFlags(net_flags);
69 (*network)->InitWeights(weight_range, randomizer);
70 (*network)->SetupNeedsBackprop(false);
71 if (bottom_series != nullptr) {
72 bottom_series->AppendSeries(*network);
73 *network = bottom_series;
74 }
75 (*network)->CacheXScaleFactor((*network)->XScaleFactor());
76 return true;
77 }
78
// Helper advances *str past any leading run of spaces, tabs and newlines.
// Stops at the first other character, including the terminating NUL.
static void SkipWhitespace(const char **str) {
  const char *cursor = *str;
  for (;;) {
    const char ch = *cursor;
    if (ch != ' ' && ch != '\t' && ch != '\n') {
      break;
    }
    ++cursor;
  }
  *str = cursor;
}
85
86 // Parses the given string and returns a network according to the network
87 // description language in networkbuilder.h
88 Network *NetworkBuilder::BuildFromString(const StaticShape &input_shape, const char **str) {
89 SkipWhitespace(str);
90 char code_ch = **str;
91 if (code_ch == '[') {
92 return ParseSeries(input_shape, nullptr, str);
93 }
94 if (input_shape.depth() == 0) {
95 // There must be an input at this point.
96 return ParseInput(str);
97 }
98 switch (code_ch) {
99 case '(':
100 return ParseParallel(input_shape, str);
101 case 'R':
102 return ParseR(input_shape, str);
103 case 'S':
104 return ParseS(input_shape, str);
105 case 'C':
106 return ParseC(input_shape, str);
107 case 'M':
108 return ParseM(input_shape, str);
109 case 'L':
110 return ParseLSTM(input_shape, str);
111 case 'F':
112 return ParseFullyConnected(input_shape, str);
113 case 'O':
114 return ParseOutput(input_shape, str);
115 default:
116 tprintf("Invalid network spec:%s\n", *str);
117 }
118 return nullptr;
119 }
120
121 // Parses an input specification and returns the result, which may include a
122 // series.
123 Network *NetworkBuilder::ParseInput(const char **str) {
124 // There must be an input at this point.
125 int length = 0;
126 int batch, height, width, depth;
127 int num_converted = sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
128 StaticShape shape;
129 shape.SetShape(batch, height, width, depth);
130 // num_converted may or may not include the length.
131 if (num_converted != 4 && num_converted != 5) {
132 tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
133 return nullptr;
134 }
135 *str += length;
136 auto *input = new Input("Input", shape);
137 // We want to allow [<input>rest of net... or <input>[rest of net... so we
138 // have to check explicitly for '[' here.
139 SkipWhitespace(str);
140 if (**str == '[') {
141 return ParseSeries(shape, input, str);
142 }
143 return input;
144 }
145
146 // Parses a sequential series of networks, defined by [<net><net>...].
147 Network *NetworkBuilder::ParseSeries(const StaticShape &input_shape, Input *input_layer,
148 const char **str) {
149 StaticShape shape = input_shape;
150 auto *series = new Series("Series");
151 ++*str;
152 if (input_layer != nullptr) {
153 series->AddToStack(input_layer);
154 shape = input_layer->OutputShape(shape);
155 }
156 Network *network = nullptr;
157 while (**str != '\0' && **str != ']' && (network = BuildFromString(shape, str)) != nullptr) {
158 shape = network->OutputShape(shape);
159 series->AddToStack(network);
160 }
161 if (**str != ']') {
162 tprintf("Missing ] at end of [Series]!\n");
163 delete series;
164 return nullptr;
165 }
166 ++*str;
167 return series;
168 }
169
170 // Parses a parallel set of networks, defined by (<net><net>...).
171 Network *NetworkBuilder::ParseParallel(const StaticShape &input_shape, const char **str) {
172 auto *parallel = new Parallel("Parallel", NT_PARALLEL);
173 ++*str;
174 Network *network = nullptr;
175 while (**str != '\0' && **str != ')' &&
176 (network = BuildFromString(input_shape, str)) != nullptr) {
177 parallel->AddToStack(network);
178 }
179 if (**str != ')') {
180 tprintf("Missing ) at end of (Parallel)!\n");
181 delete parallel;
182 return nullptr;
183 }
184 ++*str;
185 return parallel;
186 }
187
188 // Parses a network that begins with 'R'.
189 Network *NetworkBuilder::ParseR(const StaticShape &input_shape, const char **str) {
190 char dir = (*str)[1];
191 if (dir == 'x' || dir == 'y') {
192 std::string name = "Reverse";
193 name += dir;
194 *str += 2;
195 Network *network = BuildFromString(input_shape, str);
196 if (network == nullptr) {
197 return nullptr;
198 }
199 auto *rev = new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED);
200 rev->SetNetwork(network);
201 return rev;
202 }
203 char *end;
204 int replicas = strtol(*str + 1, &end, 10);
205 *str = end;
206 if (replicas <= 0) {
207 tprintf("Invalid R spec!:%s\n", end);
208 return nullptr;
209 }
210 auto *parallel = new Parallel("Replicated", NT_REPLICATED);
211 const char *str_copy = *str;
212 for (int i = 0; i < replicas; ++i) {
213 str_copy = *str;
214 Network *network = BuildFromString(input_shape, &str_copy);
215 if (network == nullptr) {
216 tprintf("Invalid replicated network!\n");
217 delete parallel;
218 return nullptr;
219 }
220 parallel->AddToStack(network);
221 }
222 *str = str_copy;
223 return parallel;
224 }
225
226 // Parses a network that begins with 'S'.
227 Network *NetworkBuilder::ParseS(const StaticShape &input_shape, const char **str) {
228 char *end;
229 int y = strtol(*str + 1, &end, 10);
230 *str = end;
231 if (**str == ',') {
232 int x = strtol(*str + 1, &end, 10);
233 *str = end;
234 if (y <= 0 || x <= 0) {
235 tprintf("Invalid S spec!:%s\n", *str);
236 return nullptr;
237 }
238 return new Reconfig("Reconfig", input_shape.depth(), x, y);
239 } else if (**str == '(') {
240 // TODO(rays) Add Generic reshape.
241 tprintf("Generic reshape not yet implemented!!\n");
242 return nullptr;
243 }
244 tprintf("Invalid S spec!:%s\n", *str);
245 return nullptr;
246 }
247
248 // Helper returns the fully-connected type for the character code.
249 static NetworkType NonLinearity(char func) {
250 switch (func) {
251 case 's':
252 return NT_LOGISTIC;
253 case 't':
254 return NT_TANH;
255 case 'r':
256 return NT_RELU;
257 case 'l':
258 return NT_LINEAR;
259 case 'm':
260 return NT_SOFTMAX;
261 case 'p':
262 return NT_POSCLIP;
263 case 'n':
264 return NT_SYMCLIP;
265 default:
266 return NT_NONE;
267 }
268 }
269
270 // Parses a network that begins with 'C'.
271 Network *NetworkBuilder::ParseC(const StaticShape &input_shape, const char **str) {
272 NetworkType type = NonLinearity((*str)[1]);
273 if (type == NT_NONE) {
274 tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
275 return nullptr;
276 }
277 int y = 0, x = 0, d = 0;
278 char *end;
279 if ((y = strtol(*str + 2, &end, 10)) <= 0 || *end != ',' ||
280 (x = strtol(end + 1, &end, 10)) <= 0 || *end != ',' || (d = strtol(end + 1, &end, 10)) <= 0) {
281 tprintf("Invalid C spec!:%s\n", end);
282 return nullptr;
283 }
284 *str = end;
285 if (x == 1 && y == 1) {
286 // No actual convolution. Just a FullyConnected on the current depth, to
287 // be slid over all batch,y,x.
288 return new FullyConnected("Conv1x1", input_shape.depth(), d, type);
289 }
290 auto *series = new Series("ConvSeries");
291 auto *convolve = new Convolve("Convolve", input_shape.depth(), x / 2, y / 2);
292 series->AddToStack(convolve);
293 StaticShape fc_input = convolve->OutputShape(input_shape);
294 series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type));
295 return series;
296 }
297
298 // Parses a network that begins with 'M'.
299 Network *NetworkBuilder::ParseM(const StaticShape &input_shape, const char **str) {
300 int y = 0, x = 0;
301 char *end;
302 if ((*str)[1] != 'p' || (y = strtol(*str + 2, &end, 10)) <= 0 || *end != ',' ||
303 (x = strtol(end + 1, &end, 10)) <= 0) {
304 tprintf("Invalid Mp spec!:%s\n", *str);
305 return nullptr;
306 }
307 *str = end;
308 return new Maxpool("Maxpool", input_shape.depth(), x, y);
309 }
310
311 // Parses an LSTM network, either individual, bi- or quad-directional.
312 Network *NetworkBuilder::ParseLSTM(const StaticShape &input_shape, const char **str) {
313 bool two_d = false;
314 NetworkType type = NT_LSTM;
315 const char *spec_start = *str;
316 int chars_consumed = 1;
317 int num_outputs = 0;
318 char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
319 if (key == 'S') {
320 type = NT_LSTM_SOFTMAX;
321 num_outputs = num_softmax_outputs_;
322 ++chars_consumed;
323 } else if (key == 'E') {
324 type = NT_LSTM_SOFTMAX_ENCODED;
325 num_outputs = num_softmax_outputs_;
326 ++chars_consumed;
327 } else if (key == '2' &&
328 (((*str)[2] == 'x' && (*str)[3] == 'y') || ((*str)[2] == 'y' && (*str)[3] == 'x'))) {
329 chars_consumed = 4;
330 dim = (*str)[3];
331 two_d = true;
332 } else if (key == 'f' || key == 'r' || key == 'b') {
333 dir = key;
334 dim = (*str)[2];
335 if (dim != 'x' && dim != 'y') {
336 tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
337 return nullptr;
338 }
339 chars_consumed = 3;
340 if ((*str)[chars_consumed] == 's') {
341 ++chars_consumed;
342 type = NT_LSTM_SUMMARY;
343 }
344 } else {
345 tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
346 return nullptr;
347 }
348 char *end;
349 int num_states = strtol(*str + chars_consumed, &end, 10);
350 if (num_states <= 0) {
351 tprintf("Invalid number of states in L Spec!:%s\n", *str);
352 return nullptr;
353 }
354 *str = end;
355 Network *lstm = nullptr;
356 if (two_d) {
357 lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
358 } else {
359 if (num_outputs == 0) {
360 num_outputs = num_states;
361 }
362 std::string name(spec_start, *str - spec_start);
363 lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type);
364 if (dir != 'f') {
365 auto *rev = new Reversed("RevLSTM", NT_XREVERSED);
366 rev->SetNetwork(lstm);
367 lstm = rev;
368 }
369 if (dir == 'b') {
370 name += "LTR";
371 auto *parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
372 parallel->AddToStack(
373 new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type));
374 parallel->AddToStack(lstm);
375 lstm = parallel;
376 }
377 }
378 if (dim == 'y') {
379 auto *rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
380 rev->SetNetwork(lstm);
381 lstm = rev;
382 }
383 return lstm;
384 }
385
386 // Builds a set of 4 lstms with x and y reversal, running in true parallel.
387 Network *NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
388 auto *parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
389 parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM));
390 auto *rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
391 rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states, true, NT_LSTM));
392 parallel->AddToStack(rev);
393 rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
394 rev->SetNetwork(new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
395 auto *rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
396 rev2->SetNetwork(rev);
397 parallel->AddToStack(rev2);
398 rev = new Reversed("L2DXRevY", NT_YREVERSED);
399 rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM));
400 parallel->AddToStack(rev);
401 return parallel;
402 }
403
404 // Helper builds a truly (0-d) fully connected layer of the given type.
405 static Network *BuildFullyConnected(const StaticShape &input_shape, NetworkType type,
406 const std::string &name, int depth) {
407 if (input_shape.height() == 0 || input_shape.width() == 0) {
408 tprintf("Fully connected requires positive height and width, had %d,%d\n", input_shape.height(),
409 input_shape.width());
410 return nullptr;
411 }
412 int input_size = input_shape.height() * input_shape.width();
413 int input_depth = input_size * input_shape.depth();
414 Network *fc = new FullyConnected(name, input_depth, depth, type);
415 if (input_size > 1) {
416 auto *series = new Series("FCSeries");
417 series->AddToStack(
418 new Reconfig("FCReconfig", input_shape.depth(), input_shape.width(), input_shape.height()));
419 series->AddToStack(fc);
420 fc = series;
421 }
422 return fc;
423 }
424
425 // Parses a Fully connected network.
426 Network *NetworkBuilder::ParseFullyConnected(const StaticShape &input_shape, const char **str) {
427 const char *spec_start = *str;
428 NetworkType type = NonLinearity((*str)[1]);
429 if (type == NT_NONE) {
430 tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
431 return nullptr;
432 }
433 char *end;
434 int depth = strtol(*str + 2, &end, 10);
435 if (depth <= 0) {
436 tprintf("Invalid F spec!:%s\n", *str);
437 return nullptr;
438 }
439 *str = end;
440 std::string name(spec_start, *str - spec_start);
441 return BuildFullyConnected(input_shape, type, name, depth);
442 }
443
444 // Parses an Output spec.
445 Network *NetworkBuilder::ParseOutput(const StaticShape &input_shape, const char **str) {
446 char dims_ch = (*str)[1];
447 if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
448 tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
449 return nullptr;
450 }
451 char type_ch = (*str)[2];
452 if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
453 tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
454 return nullptr;
455 }
456 char *end;
457 int depth = strtol(*str + 3, &end, 10);
458 if (depth != num_softmax_outputs_) {
459 tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
460 num_softmax_outputs_);
461 depth = num_softmax_outputs_;
462 }
463 *str = end;
464 NetworkType type = NT_SOFTMAX;
465 if (type_ch == 'l') {
466 type = NT_LOGISTIC;
467 } else if (type_ch == 's') {
468 type = NT_SOFTMAX_NO_CTC;
469 }
470 if (dims_ch == '0') {
471 // Same as standard fully connected.
472 return BuildFullyConnected(input_shape, type, "Output", depth);
473 } else if (dims_ch == '2') {
474 // We don't care if x and/or y are variable.
475 return new FullyConnected("Output2d", input_shape.depth(), depth, type);
476 }
477 // For 1-d y has to be fixed, and if not 1, moved to depth.
478 if (input_shape.height() == 0) {
479 tprintf("Fully connected requires fixed height!\n");
480 return nullptr;
481 }
482 int input_size = input_shape.height();
483 int input_depth = input_size * input_shape.depth();
484 Network *fc = new FullyConnected("Output", input_depth, depth, type);
485 if (input_size > 1) {
486 auto *series = new Series("FCSeries");
487 series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1, input_shape.height()));
488 series->AddToStack(fc);
489 fc = series;
490 }
491 return fc;
492 }
493
494 } // namespace tesseract.