comparison: mupdf-source/thirdparty/tesseract/src/lstm/lstm.cpp @ 2:b50eed0cc0ef (branch: upstream)
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: the expanded directory no longer carries a version number.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children | |
File content at changeset 2:b50eed0cc0ef (compared against 1:1d09e1dec1d9):

```cpp
///////////////////////////////////////////////////////////////////////
// File:        lstm.cpp
// Description: Long Short-Term Memory (LSTM) recurrent neural network.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include "lstm.h"

#ifdef _OPENMP
#  include <omp.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <sstream> // for std::ostringstream

#if defined(_MSC_VER) && !defined(__clang__)
#  include <intrin.h> // _BitScanReverse
#endif

#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"

// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
#  define PARALLEL_IF_OPENMP(__num_threads) \
  PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
    PRAGMA(omp sections nowait) { \
      PRAGMA(omp section) {
#  define SECTION_IF_OPENMP \
  } \
  PRAGMA(omp section) {
#  define END_PARALLEL_IF_OPENMP \
  } \
  } /* end of sections */ \
  } /* end of parallel section */

// Define the portable PRAGMA macro.
#  ifdef _MSC_VER // Different _Pragma
#    define PRAGMA(x) __pragma(x)
#  else
#    define PRAGMA(x) _Pragma(#x)
#  endif // _MSC_VER

#else // _OPENMP
#  define PARALLEL_IF_OPENMP(__num_threads)
#  define SECTION_IF_OPENMP
#  define END_PARALLEL_IF_OPENMP
#endif // _OPENMP
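
// Editor's note (illustrative sketch, not part of the original file): with
// _OPENMP defined, a use such as
//
//   PARALLEL_IF_OPENMP(4)
//     work_a();
//   SECTION_IF_OPENMP
//     work_b();
//   END_PARALLEL_IF_OPENMP
//
// expands, roughly, to
//
//   #pragma omp parallel if (4 > 1) num_threads(4)
//   {                                // parallel region
//     #pragma omp sections nowait
//     {                              // sections
//       #pragma omp section
//       { work_a(); }
//       #pragma omp section
//       { work_b(); }
//     }
//   }
//
// so each SECTION_IF_OPENMP closes one section and opens the next, and the
// whole construct degrades to plain sequential code when OpenMP is absent.
// The call sites below pass the enum value GFS as the thread count.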

namespace tesseract {

// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const TFloat kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const TFloat kErrClip = 1.0f;

// Calculate ceil(log2(n)).
static inline uint32_t ceil_log2(uint32_t n) {
  // l2 = (unsigned)log2(n).
#if defined(__GNUC__)
  // Use fast inline assembler code for gcc or clang.
  uint32_t l2 = 31 - __builtin_clz(n);
#elif defined(_MSC_VER)
  // Use fast intrinsic function for MS compiler.
  unsigned long l2 = 0;
  _BitScanReverse(&l2, n);
#else
  if (n == 0)
    return UINT_MAX;
  if (n == 1)
    return 0;
  uint32_t val = n;
  uint32_t l2 = 0;
  while (val > 1) {
    val >>= 1;
    l2++;
  }
#endif
  // Round up if n is not a power of 2.
  return (n == (1u << l2)) ? l2 : l2 + 1;
}
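
// Editor's note (illustrative, not in the original file): ceil_log2 returns
// the number of bits needed to distinguish n codes, e.g. ceil_log2(8) == 3,
// ceil_log2(9) == 4, ceil_log2(1) == 0. The __builtin_clz/_BitScanReverse
// fast paths assume n > 0 (__builtin_clz(0) is undefined); only the portable
// fallback handles n == 0 explicitly. The constructor below uses it to size
// the binary-encoded feedback for NT_LSTM_SOFTMAX_ENCODED.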

LSTM::LSTM(const std::string &name, int ni, int ns, int no, bool two_dimensional, NetworkType type)
    : Network(type, name, ni, no)
    , na_(ni + ns)
    , ns_(ns)
    , nf_(0)
    , is_2d_(two_dimensional)
    , softmax_(nullptr)
    , input_width_(0) {
  if (two_dimensional) {
    na_ += ns_;
  }
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("%d is an invalid type for LSTM!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
}
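
// Editor's note (worked example, not in the original file): na_ is the full
// padded-input width seen by each gate matrix. For a hypothetical 1-D
// NT_LSTM_SOFTMAX_ENCODED layer with ni = 48 inputs, ns = 96 states and
// no = 111 outputs:
//   na_ = ni + ns + ceil_log2(no) = 48 + 96 + 7 = 151
// (plus one extra column per gate for the bias, hence na_ + 1 in InitWeights).
// A 2-D layer would add a further ns_ columns for the second-dimension
// recurrent output.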

LSTM::~LSTM() {
  delete softmax_;
}

// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape LSTM::OutputShape(const StaticShape &input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) {
    result.set_width(1);
  }
  if (softmax_ != nullptr) {
    return softmax_->OutputShape(result);
  }
  return result;
}
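
// Editor's note (illustrative, hypothetical numbers): for an NT_LSTM_SUMMARY
// layer with no_ = 96, an input shape of (batch=1, height=1, width=250,
// depth=48) maps to (1, 1, 1, 96): depth becomes no_ and the width collapses
// to 1 because the layer emits only the final timestep of each row.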

// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
void LSTM::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) {
      training_ = TS_ENABLED;
    }
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) {
      training_ = state;
    }
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      for (int w = 0; w < WT_COUNT; ++w) {
        if (w == GFS && !Is2D()) {
          continue;
        }
        gate_weights_[w].InitBackward();
      }
    }
    training_ = state;
  }
  if (softmax_ != nullptr) {
    softmax_->SetEnableTraining(state);
  }
}
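
// Editor's note (summary of the state machine above, not in the original):
//   requested state    current state       result
//   TS_RE_ENABLE       TS_TEMP_DISABLE     TS_ENABLED
//   TS_RE_ENABLE       anything else       unchanged
//   TS_TEMP_DISABLE    TS_ENABLED          TS_TEMP_DISABLE
//   TS_TEMP_DISABLE    anything else       unchanged
//   TS_ENABLED         not TS_ENABLED      TS_ENABLED (+ InitBackward on gates)
//   TS_DISABLED        any                 TS_DISABLED
// The softmax sub-network, when present, always receives the same request.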

// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand *randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    num_weights_ +=
        gate_weights_[w].InitWeightsFloat(ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
  }
  if (softmax_ != nullptr) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}

// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int LSTM::RemapOutputs(int old_no, const std::vector<int> &code_map) {
  if (softmax_ != nullptr) {
    num_weights_ -= softmax_->num_weights();
    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
  }
  return num_weights_;
}

// Converts a float network to an int network.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != nullptr) {
    softmax_->ConvertToInt();
  }
}

// Outputs the weights for debugging purposes.
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    std::ostringstream msg;
    msg << name_ << " Gate weights " << w;
    gate_weights_[w].Debug2D(msg.str().c_str());
  }
  if (softmax_ != nullptr) {
    softmax_->DebugWeights();
  }
}

// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile *fp) const {
  if (!Network::Serialize(fp)) {
    return false;
  }
  if (!fp->Serialize(&na_)) {
    return false;
  }
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    if (!gate_weights_[w].Serialize(IsTraining(), fp)) {
      return false;
    }
  }
  if (softmax_ != nullptr && !softmax_->Serialize(fp)) {
    return false;
  }
  return true;
}

// Reads from the given file. Returns false in case of error.
bool LSTM::DeSerialize(TFile *fp) {
  if (!fp->DeSerialize(&na_)) {
    return false;
  }
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = ceil_log2(no_);
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) {
      return false;
    }
    if (w == CI) {
      ns_ = gate_weights_[CI].NumOutputs();
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ = static_cast<FullyConnected *>(Network::CreateFromFile(fp));
    if (softmax_ == nullptr) {
      return false;
    }
  } else {
    softmax_ = nullptr;
  }
  return true;
}
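
// Editor's note (worked example, not in the original): DeSerialize recovers
// the geometry from the stored matrices rather than from stored flags. With a
// hypothetical 2-D layer of ni_ = 48, ns_ = 96, nf_ = 0: na_ = 48 + 2*96 = 240,
// and the test na_ - nf_ == ni_ + 2 * ns_ (240 == 240) sets is_2d_ = true.
// The corresponding 1-D layer has na_ = 144, so the test fails and is_2d_
// stays false.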

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
                   NetworkScratch *scratch, NetworkIO *output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != nullptr) {
    output->ResizeFloat(input, no_);
  } else if (type_ == NT_LSTM_SUMMARY) {
    output->ResizeXTo1(input, no_);
  } else {
    output->Resize(input, no_);
  }
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  int ro = ns_;
  if (source_.int_mode() && IntSimdMatrix::intSimdMatrix) {
    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
  }
  for (auto &temp_line : temp_lines) {
    temp_line.Init(ns_, ro, scratch);
  }
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<TFloat>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<TFloat>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
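  // Editor's note (illustrative, not in the original): with a 2-D layer on a
  // strip of width 250, buf_width = 250 and the index mod_t = Modulo(t,
  // buf_width) used below revisits each slot exactly once per row, so
  // states[mod_t]/outputs[mod_t] always hold the value computed at the same x
  // position on the previous row -- exactly the second-dimension recurrence
  // the gates consume.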
  std::vector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.resize(buf_width);
    outputs.resize(buf_width);
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<TFloat>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<TFloat>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != nullptr) {
    softmax_output.Init(no_, scratch);
    ZeroVector<TFloat>(no_, softmax_output);
    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
    if (input.int_mode()) {
      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
    }
    softmax_->SetupForward(input, nullptr);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) {
        valid_2d = false;
      }
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
    if (softmax_ != nullptr) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D()) {
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    }
    if (!source_.int_mode()) {
      source_.ReadTimeStep(t, curr_input);
    }
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel region outside the t loop, with a
    // single around the t-loop and tasks in place of the sections, is a *lot*
    // slower.
    // Cell inputs.
    if (source_.int_mode()) {
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    } else {
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    }
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode()) {
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    } else {
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode()) {
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    } else {
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode()) {
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      } else {
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      }
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode()) {
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    } else {
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP
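
    // Editor's note (standard LSTM cell, for orientation; not in the original
    // file). Writing g = temp_lines[CI], i = temp_lines[GI],
    // f = temp_lines[GF1] and o = temp_lines[GO], and assuming the usual
    // squashing functions from functions.h (GFunc/HFunc tanh-like, FFunc a
    // logistic sigmoid), the code below implements, element-wise,
    //   state_t  = f * state_{t-1} + i * g
    //   output_t = o * h(state_t)
    // with the 2-D variant replacing the plain f * state_{t-1} term by a
    // max-pool over the 1-D and 2-D forget paths.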

    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
      int8_t *which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
      if (valid_2d) {
        const TFloat *stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<TFloat>(ns_, -kStateClip, kStateClip, curr_state);
    if (IsTraining()) {
      // Save the gate node values.
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) {
        node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
      }
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (IsTraining()) {
      state_.WriteTimeStep(t, curr_state);
    }
    if (softmax_ != nullptr) {
      if (input.int_mode()) {
        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
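    // Editor's note (illustrative, not in the original): for
    // NT_LSTM_SOFTMAX_ENCODED the feedback recycled into the next timestep is
    // not the full softmax distribution but a binary code of it: with
    // no_ = 100 classes only nf_ = ceil_log2(100) = 7 values are written back
    // into source_, which keeps na_ (and so every gate matrix) much narrower
    // than feeding back all 100 outputs.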
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<TFloat>(ns_, curr_state);
      ZeroVector<TFloat>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.c_str());
  source_.Print(10);
  tprintf("State:%s\n", name_.c_str());
  state_.Print(10);
  tprintf("Output:%s\n", name_.c_str());
  output->Print(10);
#endif
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayForward(*output);
  }
#endif
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
                    NetworkIO *back_deltas) {
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayBackward(fwd_deltas);
  }
#endif
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<TFloat>(ns_, curr_stateerr);
  ZeroVector<TFloat>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (auto &gate_error : gate_errors) {
    gate_error.Init(ns_, scratch);
  }
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  std::vector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.resize(buf_width);
    sourceerr.resize(buf_width);
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<TFloat>(ns_, stateerr[t]);
      ZeroVector<TFloat>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (auto &sourceerr_temp : sourceerr_temps) {
    sourceerr_temp.Init(na_, scratch);
  }
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (auto &w : gate_errors_t) {
    w.Init(ns_, width, scratch);
  }
  // Used only if softmax_ != nullptr.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != nullptr) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  TFloat state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.c_str());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) {
          up_pos = up_index.t();
        }
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) {
          down_pos = down_index.t();
        }
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<TFloat>(na_, curr_sourceerr);
      ZeroVector<TFloat>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<TFloat>(ns_, outputerr);
      }
    } else if (softmax_ == nullptr) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors, softmax_errors_t.get(), outputerr);
    }
    if (!at_last_x) {
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    }
    if (down_pos >= 0) {
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    }
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float *next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) {
          curr_stateerr[i] = 0.0;
        }
      }
      if (down_pos >= 0) {
        const float *right_node_gfs = node_values_[GFS].f(down_pos);
        const TFloat *right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr, curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<TFloat>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i) {
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i], curr_sourceerr[ni_ + nf_ + i]);
      }
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t, curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t, curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr, gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1], sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr, gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS], sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) {
      gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr, gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI], sourceerr_temps[GF1],
               sourceerr_temps[GO], sourceerr_temps[GFS], curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.c_str(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
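  // Editor's note (orientation, not in the original): the loop below turns
  // the per-timestep gate error vectors into weight gradients. For each gate
  // w, SumOuterTransposed effectively accumulates the sum over t of the outer
  // product of gate_errors_t[w](:, t) with source_t(:, t), i.e. each gate's
  // error times the padded input it saw, which is why both operands were
  // stored transposed (timestep-major) above.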
#ifdef _OPENMP
#  pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != nullptr) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void LSTM::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
#if DEBUG_DETAIL > 3
  PrintW();
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
  }
  if (softmax_ != nullptr) {
    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
  }
#if DEBUG_DETAIL > 3
  PrintDW();
#endif
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
  ASSERT_HOST(other.type() == type_);
  const LSTM *lstm = static_cast<const LSTM *>(&other);
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
  }
  if (softmax_ != nullptr) {
    softmax_->CountAlternators(*lstm->softmax_, same, changed);
  }
}

#if DEBUG_DETAIL > 3

// Prints the weights for debug purposes.
void LSTM::PrintW() {
  tprintf("Weight state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      }
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      }
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s) {
      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
    }
    tprintf("\n");
  }
}

// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
  tprintf("Delta state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      }
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      }
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s) {
      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
    }
    tprintf("\n");
  }
}

#endif

// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO &input) {
  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
  source_.Resize(input, rounded_inputs);
  which_fg_.ResizeNoInit(input.Width(), ns_);
  if (IsTraining()) {
    state_.ResizeFloat(input, ns_);
    for (int w = 0; w < WT_COUNT; ++w) {
      if (w == GFS && !Is2D()) {
        continue;
      }
      node_values_[w].ResizeFloat(input, ns_);
    }
  }
}

} // namespace tesseract
```
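
The `ceil_log2` helper above is what determines how many feedback lines an NT_LSTM_SOFTMAX_ENCODED layer recycles. As a quick sanity check, here is a minimal standalone sketch (editorial, not part of the changeset) of the portable fallback branch and the resulting feedback widths for a few output sizes; `portable_ceil_log2` and the sample output counts are hypothetical names and values for this demo only.

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

// Re-implementation of the portable fallback branch of ceil_log2 above:
// returns the number of bits needed to represent n distinct codes.
static uint32_t portable_ceil_log2(uint32_t n) {
  if (n == 0) {
    return UINT_MAX; // Matches the n == 0 guard in the portable branch.
  }
  uint32_t l2 = 0;
  for (uint32_t val = n; val > 1; val >>= 1) {
    ++l2; // l2 ends as floor(log2(n)).
  }
  return (n == (1u << l2)) ? l2 : l2 + 1; // Round up for non-powers of 2.
}

int main() {
  // Feedback width nf_ = ceil_log2(no_) for some plausible output counts.
  const uint32_t output_counts[] = {1, 2, 8, 9, 100, 111};
  for (uint32_t no : output_counts) {
    std::printf("no_ = %3u -> nf_ = %u feedback lines\n", no, portable_ceil_log2(no));
  }
  return 0;
}
```

For example, `no_ = 100` prints `nf_ = 7`, matching the constructor's `nf_ = ceil_log2(no_)` sizing for the binary-coded softmax feedback.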
