diff mupdf-source/thirdparty/tesseract/src/lstm/lstm.cpp @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: the expanded source directory no longer carries a version number.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
parents
children
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mupdf-source/thirdparty/tesseract/src/lstm/lstm.cpp	Mon Sep 15 11:43:07 2025 +0200
@@ -0,0 +1,857 @@
+///////////////////////////////////////////////////////////////////////
+// File:        lstm.cpp
+// Description: Long Short-Term Memory (LSTM) recurrent neural network.
+// Author:      Ray Smith
+//
+// (C) Copyright 2013, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifdef HAVE_CONFIG_H
+#  include "config_auto.h"
+#endif
+
+#include "lstm.h"
+
+#ifdef _OPENMP
+#  include <omp.h>
+#endif
+#include <cstdio>
+#include <cstdlib>
+#include <sstream> // for std::ostringstream
+
+#if defined(_MSC_VER) && !defined(__clang__)
+#  include <intrin.h> // _BitScanReverse
+#endif
+
+#include "fullyconnected.h"
+#include "functions.h"
+#include "networkscratch.h"
+#include "tprintf.h"
+
+// Macros for OpenMP code if it is available, otherwise empty macros.
+#ifdef _OPENMP
+#  define PARALLEL_IF_OPENMP(__num_threads)                                  \
+    PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
+      PRAGMA(omp sections nowait) {                                          \
+        PRAGMA(omp section) {
+#  define SECTION_IF_OPENMP \
+    }                       \
+    PRAGMA(omp section) {
+#  define END_PARALLEL_IF_OPENMP \
+    }                            \
+    } /* end of sections */      \
+    } /* end of parallel section */
+
+// Define the portable PRAGMA macro.
+#  ifdef _MSC_VER // Different _Pragma
+#    define PRAGMA(x) __pragma(x)
+#  else
+#    define PRAGMA(x) _Pragma(#    x)
+#  endif // _MSC_VER
+
+#else // _OPENMP
+#  define PARALLEL_IF_OPENMP(__num_threads)
+#  define SECTION_IF_OPENMP
+#  define END_PARALLEL_IF_OPENMP
+#endif // _OPENMP
+
+namespace tesseract {
+
+// Max absolute value of state_. It is reasonably high to enable the state
+// to count things.
+const TFloat kStateClip = 100.0;
+// Max absolute value of gate_errors (the gradients).
+const TFloat kErrClip = 1.0f;
+
+// Calculate ceil(log2(n)).
+static inline uint32_t ceil_log2(uint32_t n) {
+  // l2 = (unsigned)log2(n).
+#if defined(__GNUC__)
+  // Use the fast compiler builtin for gcc or clang.
+  uint32_t l2 = 31 - __builtin_clz(n);
+#elif defined(_MSC_VER)
+  // Use fast intrinsic function for MS compiler.
+  unsigned long l2 = 0;
+  _BitScanReverse(&l2, n);
+#else
+  if (n == 0)
+    return UINT_MAX;
+  if (n == 1)
+    return 0;
+  uint32_t val = n;
+  uint32_t l2 = 0;
+  while (val > 1) {
+    val >>= 1;
+    l2++;
+  }
+#endif
+  // Round up if n is not a power of 2.
+  return (n == (1u << l2)) ? l2 : l2 + 1;
+}
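+
+// Note: ceil_log2(n) is the number of bits needed to enumerate n distinct
+// values, e.g. ceil_log2(256) == 8 and ceil_log2(257) == 9. The constructor
+// below uses it to size the binary-coded softmax feedback (nf_) for
+// NT_LSTM_SOFTMAX_ENCODED.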
+
+LSTM::LSTM(const std::string &name, int ni, int ns, int no, bool two_dimensional, NetworkType type)
+    : Network(type, name, ni, no)
+    , na_(ni + ns)
+    , ns_(ns)
+    , nf_(0)
+    , is_2d_(two_dimensional)
+    , softmax_(nullptr)
+    , input_width_(0) {
+  if (two_dimensional) {
+    na_ += ns_;
+  }
+  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
+    nf_ = 0;
+    // networkbuilder ensures this is always true.
+    ASSERT_HOST(no == ns);
+  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
+    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
+    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
+  } else {
+    tprintf("%d is invalid type of LSTM!\n", type);
+    ASSERT_HOST(false);
+  }
+  na_ += nf_;
+}
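+
+// Note on sizes: after construction na_ = ni_ + nf_ + ns_ (plus another ns_
+// in 2-D mode), i.e. each gate sees the external input, the optional softmax
+// feedback, and the recurrent output(s) from the previous timestep.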
+
+LSTM::~LSTM() {
+  delete softmax_;
+}
+
+// Returns the shape output from the network given an input shape (which may
+// be partially unknown, i.e. zero).
+StaticShape LSTM::OutputShape(const StaticShape &input_shape) const {
+  StaticShape result = input_shape;
+  result.set_depth(no_);
+  if (type_ == NT_LSTM_SUMMARY) {
+    result.set_width(1);
+  }
+  if (softmax_ != nullptr) {
+    return softmax_->OutputShape(result);
+  }
+  return result;
+}
+
+// Suspends/Enables training by setting the training_ flag. Serialize and
+// DeSerialize only operate on the run-time data if state is false.
+void LSTM::SetEnableTraining(TrainingState state) {
+  if (state == TS_RE_ENABLE) {
+    // Enable only from temp disabled.
+    if (training_ == TS_TEMP_DISABLE) {
+      training_ = TS_ENABLED;
+    }
+  } else if (state == TS_TEMP_DISABLE) {
+    // Temp disable only from enabled.
+    if (training_ == TS_ENABLED) {
+      training_ = state;
+    }
+  } else {
+    if (state == TS_ENABLED && training_ != TS_ENABLED) {
+      for (int w = 0; w < WT_COUNT; ++w) {
+        if (w == GFS && !Is2D()) {
+          continue;
+        }
+        gate_weights_[w].InitBackward();
+      }
+    }
+    training_ = state;
+  }
+  if (softmax_ != nullptr) {
+    softmax_->SetEnableTraining(state);
+  }
+}
+
+// Sets up the network for training. Initializes the weights to random values
+// of scale `range`, drawn using the random number generator `randomizer`.
+int LSTM::InitWeights(float range, TRand *randomizer) {
+  Network::SetRandomizer(randomizer);
+  num_weights_ = 0;
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    num_weights_ +=
+        gate_weights_[w].InitWeightsFloat(ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
+  }
+  if (softmax_ != nullptr) {
+    num_weights_ += softmax_->InitWeights(range, randomizer);
+  }
+  return num_weights_;
+}
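+
+// Note: each gate holds an ns_ x (na_ + 1) weight matrix; the extra column
+// (index na_) is the bias, which is why InitWeightsFloat is given na_ + 1.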
+
+// Recursively searches the network for softmaxes with old_no outputs,
+// and remaps their outputs according to code_map. See network.h for details.
+int LSTM::RemapOutputs(int old_no, const std::vector<int> &code_map) {
+  if (softmax_ != nullptr) {
+    num_weights_ -= softmax_->num_weights();
+    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
+  }
+  return num_weights_;
+}
+
+// Converts a float network to an int network.
+void LSTM::ConvertToInt() {
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    gate_weights_[w].ConvertToInt();
+  }
+  if (softmax_ != nullptr) {
+    softmax_->ConvertToInt();
+  }
+}
+
+// Provides debug output on the weights.
+void LSTM::DebugWeights() {
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    std::ostringstream msg;
+    msg << name_ << " Gate weights " << w;
+    gate_weights_[w].Debug2D(msg.str().c_str());
+  }
+  if (softmax_ != nullptr) {
+    softmax_->DebugWeights();
+  }
+}
+
+// Writes to the given file. Returns false in case of error.
+bool LSTM::Serialize(TFile *fp) const {
+  if (!Network::Serialize(fp)) {
+    return false;
+  }
+  if (!fp->Serialize(&na_)) {
+    return false;
+  }
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    if (!gate_weights_[w].Serialize(IsTraining(), fp)) {
+      return false;
+    }
+  }
+  if (softmax_ != nullptr && !softmax_->Serialize(fp)) {
+    return false;
+  }
+  return true;
+}
+
+// Reads from the given file. Returns false in case of error.
+bool LSTM::DeSerialize(TFile *fp) {
+  if (!fp->DeSerialize(&na_)) {
+    return false;
+  }
+  if (type_ == NT_LSTM_SOFTMAX) {
+    nf_ = no_;
+  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
+    nf_ = ceil_log2(no_);
+  } else {
+    nf_ = 0;
+  }
+  is_2d_ = false;
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) {
+      return false;
+    }
+    if (w == CI) {
+      ns_ = gate_weights_[CI].NumOutputs();
+      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
+    }
+  }
+  delete softmax_;
+  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
+    softmax_ = static_cast<FullyConnected *>(Network::CreateFromFile(fp));
+    if (softmax_ == nullptr) {
+      return false;
+    }
+  } else {
+    softmax_ = nullptr;
+  }
+  return true;
+}
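+
+// Note: DeSerialize recovers ns_ from the CI weight matrix and then infers
+// is_2d_: if na_ - nf_ == ni_ + 2 * ns_, the serialized net carried a second
+// recurrent block, so it must have been a 2-D LSTM.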
+
+// Runs forward propagation of activations on the input line.
+// See NetworkCpp for a detailed discussion of the arguments.
+void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
+                   NetworkScratch *scratch, NetworkIO *output) {
+  input_map_ = input.stride_map();
+  input_width_ = input.Width();
+  if (softmax_ != nullptr) {
+    output->ResizeFloat(input, no_);
+  } else if (type_ == NT_LSTM_SUMMARY) {
+    output->ResizeXTo1(input, no_);
+  } else {
+    output->Resize(input, no_);
+  }
+  ResizeForward(input);
+  // Temporary storage of forward computation for each gate.
+  NetworkScratch::FloatVec temp_lines[WT_COUNT];
+  int ro = ns_;
+  if (source_.int_mode() && IntSimdMatrix::intSimdMatrix) {
+    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
+  }
+  for (auto &temp_line : temp_lines) {
+    temp_line.Init(ns_, ro, scratch);
+  }
+  // Single timestep buffers for the current/recurrent output and state.
+  NetworkScratch::FloatVec curr_state, curr_output;
+  curr_state.Init(ns_, scratch);
+  ZeroVector<TFloat>(ns_, curr_state);
+  curr_output.Init(ns_, scratch);
+  ZeroVector<TFloat>(ns_, curr_output);
+  // Rotating buffers of width buf_width allow storage of the state and output
+  // for the other dimension, used only when working in true 2D mode. The width
+  // is enough to hold an entire strip of the major direction.
+  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
+  std::vector<NetworkScratch::FloatVec> states, outputs;
+  if (Is2D()) {
+    states.resize(buf_width);
+    outputs.resize(buf_width);
+    for (int i = 0; i < buf_width; ++i) {
+      states[i].Init(ns_, scratch);
+      ZeroVector<TFloat>(ns_, states[i]);
+      outputs[i].Init(ns_, scratch);
+      ZeroVector<TFloat>(ns_, outputs[i]);
+    }
+  }
+  // Used only if a softmax LSTM.
+  NetworkScratch::FloatVec softmax_output;
+  NetworkScratch::IO int_output;
+  if (softmax_ != nullptr) {
+    softmax_output.Init(no_, scratch);
+    ZeroVector<TFloat>(no_, softmax_output);
+    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
+    if (input.int_mode()) {
+      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
+    }
+    softmax_->SetupForward(input, nullptr);
+  }
+  NetworkScratch::FloatVec curr_input;
+  curr_input.Init(na_, scratch);
+  StrideMap::Index src_index(input_map_);
+  // Used only by NT_LSTM_SUMMARY.
+  StrideMap::Index dest_index(output->stride_map());
+  do {
+    int t = src_index.t();
+    // True if there is a valid old state for the 2nd dimension.
+    bool valid_2d = Is2D();
+    if (valid_2d) {
+      StrideMap::Index dim_index(src_index);
+      if (!dim_index.AddOffset(-1, FD_HEIGHT)) {
+        valid_2d = false;
+      }
+    }
+    // Index of the 2-D revolving buffers (outputs, states).
+    int mod_t = Modulo(t, buf_width); // Current timestep.
+    // Setup the padded input in source.
+    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
+    if (softmax_ != nullptr) {
+      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
+    }
+    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
+    if (Is2D()) {
+      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
+    }
+    if (!source_.int_mode()) {
+      source_.ReadTimeStep(t, curr_input);
+    }
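+
+    // At this point timestep t of source_ holds, in order: the external
+    // input (ni_ values), the softmax feedback (nf_ values, softmax LSTMs
+    // only), the recurrent output from the previous step (ns_ values) and,
+    // in 2-D mode, the output from the neighbouring row (another ns_ values).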
+    // Matrix multiply the inputs with the source.
+    PARALLEL_IF_OPENMP(GFS)
+    // It looks inefficient to create the threads on each t iteration, but the
+    // alternative of putting the parallel outside the t loop, with an omp
+    // single around the t loop and tasks in place of the sections, is a *lot*
+    // slower.
+    // Cell inputs.
+    if (source_.int_mode()) {
+      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
+    } else {
+      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
+    }
+    FuncInplace<GFunc>(ns_, temp_lines[CI]);
+
+    SECTION_IF_OPENMP
+    // Input Gates.
+    if (source_.int_mode()) {
+      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
+    } else {
+      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
+    }
+    FuncInplace<FFunc>(ns_, temp_lines[GI]);
+
+    SECTION_IF_OPENMP
+    // 1-D forget gates.
+    if (source_.int_mode()) {
+      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
+    } else {
+      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
+    }
+    FuncInplace<FFunc>(ns_, temp_lines[GF1]);
+
+    // 2-D forget gates.
+    if (Is2D()) {
+      if (source_.int_mode()) {
+        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
+      } else {
+        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
+      }
+      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
+    }
+
+    SECTION_IF_OPENMP
+    // Output gates.
+    if (source_.int_mode()) {
+      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
+    } else {
+      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
+    }
+    FuncInplace<FFunc>(ns_, temp_lines[GO]);
+    END_PARALLEL_IF_OPENMP
+
+    // Apply forget gate to state.
+    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
+    if (Is2D()) {
+      // Max-pool the forget gates (in 2-d) instead of blindly adding.
+      int8_t *which_fg_col = which_fg_[t];
+      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
+      if (valid_2d) {
+        const TFloat *stepped_state = states[mod_t];
+        for (int i = 0; i < ns_; ++i) {
+          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
+            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
+            which_fg_col[i] = 2;
+          }
+        }
+      }
+    }
+    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
+    // Clip curr_state to a sane range.
+    ClipVector<TFloat>(ns_, -kStateClip, kStateClip, curr_state);
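+    // Net effect of the lines above: curr_state = forget_gate * previous
+    // state (taking, per cell, whichever of the 1-D and 2-D forget paths is
+    // stronger in 2-D mode) + cell_input * input_gate, clipped to
+    // [-kStateClip, kStateClip].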
+    if (IsTraining()) {
+      // Save the gate node values.
+      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
+      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
+      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
+      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
+      if (Is2D()) {
+        node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
+      }
+    }
+    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
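+    // i.e. curr_output = HFunc(curr_state) * output_gate, the usual LSTM
+    // cell output.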
+    if (IsTraining()) {
+      state_.WriteTimeStep(t, curr_state);
+    }
+    if (softmax_ != nullptr) {
+      if (input.int_mode()) {
+        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
+        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
+      } else {
+        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
+      }
+      output->WriteTimeStep(t, softmax_output);
+      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
+        CodeInBinary(no_, nf_, softmax_output);
+      }
+    } else if (type_ == NT_LSTM_SUMMARY) {
+      // Output only at the end of a row.
+      if (src_index.IsLast(FD_WIDTH)) {
+        output->WriteTimeStep(dest_index.t(), curr_output);
+        dest_index.Increment();
+      }
+    } else {
+      output->WriteTimeStep(t, curr_output);
+    }
+    // Save states for use by the 2nd dimension only if needed.
+    if (Is2D()) {
+      CopyVector(ns_, curr_state, states[mod_t]);
+      CopyVector(ns_, curr_output, outputs[mod_t]);
+    }
+    // Always zero the states at the end of every row, but only for the major
+    // direction. The 2-D state remains intact.
+    if (src_index.IsLast(FD_WIDTH)) {
+      ZeroVector<TFloat>(ns_, curr_state);
+      ZeroVector<TFloat>(ns_, curr_output);
+    }
+  } while (src_index.Increment());
+#if DEBUG_DETAIL > 0
+  tprintf("Source:%s\n", name_.c_str());
+  source_.Print(10);
+  tprintf("State:%s\n", name_.c_str());
+  state_.Print(10);
+  tprintf("Output:%s\n", name_.c_str());
+  output->Print(10);
+#endif
+#ifndef GRAPHICS_DISABLED
+  if (debug) {
+    DisplayForward(*output);
+  }
+#endif
+}
+
+// Runs backward propagation of errors on the deltas line.
+// See NetworkCpp for a detailed discussion of the arguments.
+bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
+                    NetworkIO *back_deltas) {
+#ifndef GRAPHICS_DISABLED
+  if (debug) {
+    DisplayBackward(fwd_deltas);
+  }
+#endif
+  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
+  // ======Scratch space.======
+  // Output errors from deltas with recurrence from sourceerr.
+  NetworkScratch::FloatVec outputerr;
+  outputerr.Init(ns_, scratch);
+  // Recurrent error in the state/source.
+  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
+  curr_stateerr.Init(ns_, scratch);
+  curr_sourceerr.Init(na_, scratch);
+  ZeroVector<TFloat>(ns_, curr_stateerr);
+  ZeroVector<TFloat>(na_, curr_sourceerr);
+  // Errors in the gates.
+  NetworkScratch::FloatVec gate_errors[WT_COUNT];
+  for (auto &gate_error : gate_errors) {
+    gate_error.Init(ns_, scratch);
+  }
+  // Rotating buffers of width buf_width allow storage of the recurrent time-
+  // steps used only for true 2-D. Stores one full strip of the major direction.
+  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
+  std::vector<NetworkScratch::FloatVec> stateerr, sourceerr;
+  if (Is2D()) {
+    stateerr.resize(buf_width);
+    sourceerr.resize(buf_width);
+    for (int t = 0; t < buf_width; ++t) {
+      stateerr[t].Init(ns_, scratch);
+      sourceerr[t].Init(na_, scratch);
+      ZeroVector<TFloat>(ns_, stateerr[t]);
+      ZeroVector<TFloat>(na_, sourceerr[t]);
+    }
+  }
+  // Parallel-generated sourceerr from each of the gates.
+  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
+  for (auto &sourceerr_temp : sourceerr_temps) {
+    sourceerr_temp.Init(na_, scratch);
+  }
+  int width = input_width_;
+  // Transposed gate errors stored over all timesteps for sum outer.
+  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
+  for (auto &w : gate_errors_t) {
+    w.Init(ns_, width, scratch);
+  }
+  // Used only if softmax_ != nullptr.
+  NetworkScratch::FloatVec softmax_errors;
+  NetworkScratch::GradientStore softmax_errors_t;
+  if (softmax_ != nullptr) {
+    softmax_errors.Init(no_, scratch);
+    softmax_errors_t.Init(no_, width, scratch);
+  }
+  TFloat state_clip = Is2D() ? 9.0 : 4.0;
+#if DEBUG_DETAIL > 1
+  tprintf("fwd_deltas:%s\n", name_.c_str());
+  fwd_deltas.Print(10);
+#endif
+  StrideMap::Index dest_index(input_map_);
+  dest_index.InitToLast();
+  // Used only by NT_LSTM_SUMMARY.
+  StrideMap::Index src_index(fwd_deltas.stride_map());
+  src_index.InitToLast();
+  do {
+    int t = dest_index.t();
+    bool at_last_x = dest_index.IsLast(FD_WIDTH);
+    // up_pos is the 2-D back step, down_pos is the 2-D fwd step; they are
+    // only valid if >= 0, which holds in 2-D mode away from the top/bottom
+    // edge.
+    int up_pos = -1;
+    int down_pos = -1;
+    if (Is2D()) {
+      if (dest_index.index(FD_HEIGHT) > 0) {
+        StrideMap::Index up_index(dest_index);
+        if (up_index.AddOffset(-1, FD_HEIGHT)) {
+          up_pos = up_index.t();
+        }
+      }
+      if (!dest_index.IsLast(FD_HEIGHT)) {
+        StrideMap::Index down_index(dest_index);
+        if (down_index.AddOffset(1, FD_HEIGHT)) {
+          down_pos = down_index.t();
+        }
+      }
+    }
+    // Index of the 2-D revolving buffers (sourceerr, stateerr).
+    int mod_t = Modulo(t, buf_width); // Current timestep.
+    // Zero the state in the major direction only at the end of every row.
+    if (at_last_x) {
+      ZeroVector<TFloat>(na_, curr_sourceerr);
+      ZeroVector<TFloat>(ns_, curr_stateerr);
+    }
+    // Setup the outputerr.
+    if (type_ == NT_LSTM_SUMMARY) {
+      if (dest_index.IsLast(FD_WIDTH)) {
+        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
+        src_index.Decrement();
+      } else {
+        ZeroVector<TFloat>(ns_, outputerr);
+      }
+    } else if (softmax_ == nullptr) {
+      fwd_deltas.ReadTimeStep(t, outputerr);
+    } else {
+      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors, softmax_errors_t.get(), outputerr);
+    }
+    if (!at_last_x) {
+      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
+    }
+    if (down_pos >= 0) {
+      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
+    }
+    // Apply the 1-d forget gates.
+    if (!at_last_x) {
+      const float *next_node_gf1 = node_values_[GF1].f(t + 1);
+      for (int i = 0; i < ns_; ++i) {
+        curr_stateerr[i] *= next_node_gf1[i];
+      }
+    }
+    if (Is2D() && t + 1 < width) {
+      for (int i = 0; i < ns_; ++i) {
+        if (which_fg_[t + 1][i] != 1) {
+          curr_stateerr[i] = 0.0;
+        }
+      }
+      if (down_pos >= 0) {
+        const float *right_node_gfs = node_values_[GFS].f(down_pos);
+        const TFloat *right_stateerr = stateerr[mod_t];
+        for (int i = 0; i < ns_; ++i) {
+          if (which_fg_[down_pos][i] == 2) {
+            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
+          }
+        }
+      }
+    }
+    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr, curr_stateerr);
+    // Clip stateerr_ to a sane range.
+    ClipVector<TFloat>(ns_, -state_clip, state_clip, curr_stateerr);
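+    // curr_stateerr now carries the state-gradient recurrence through the
+    // 1-D (and, in 2-D, the max-pooled) forget gates, plus the current output
+    // error pushed back through the output gate and HPrime.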
+#if DEBUG_DETAIL > 1
+    if (t + 10 > width) {
+      tprintf("t=%d, stateerr=", t);
+      for (int i = 0; i < ns_; ++i)
+        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i], curr_sourceerr[ni_ + nf_ + i]);
+      tprintf("\n");
+    }
+#endif
+    // Matrix multiply to get the source errors.
+    PARALLEL_IF_OPENMP(GFS)
+
+    // Cell inputs.
+    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t, curr_stateerr, gate_errors[CI]);
+    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
+    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
+    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
+
+    SECTION_IF_OPENMP
+    // Input Gates.
+    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t, curr_stateerr, gate_errors[GI]);
+    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
+    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
+    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
+
+    SECTION_IF_OPENMP
+    // 1-D forget Gates.
+    if (t > 0) {
+      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr, gate_errors[GF1]);
+      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
+      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1], sourceerr_temps[GF1]);
+    } else {
+      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
+      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
+    }
+    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
+
+    // 2-D forget Gates.
+    if (up_pos >= 0) {
+      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr, gate_errors[GFS]);
+      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
+      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS], sourceerr_temps[GFS]);
+    } else {
+      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
+      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
+    }
+    if (Is2D()) {
+      gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
+    }
+
+    SECTION_IF_OPENMP
+    // Output gates.
+    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr, gate_errors[GO]);
+    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
+    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
+    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
+    END_PARALLEL_IF_OPENMP
+
+    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI], sourceerr_temps[GF1],
+               sourceerr_temps[GO], sourceerr_temps[GFS], curr_sourceerr);
+    back_deltas->WriteTimeStep(t, curr_sourceerr);
+    // Save states for use by the 2nd dimension only if needed.
+    if (Is2D()) {
+      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
+      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
+    }
+  } while (dest_index.Decrement());
+#if DEBUG_DETAIL > 2
+  for (int w = 0; w < WT_COUNT; ++w) {
+    tprintf("%s gate errors[%d]\n", name_.c_str(), w);
+    gate_errors_t[w].get()->PrintUnTransposed(10);
+  }
+#endif
+  // Transposed source_ used to speed up SumOuter.
+  NetworkScratch::GradientStore source_t, state_t;
+  source_t.Init(na_, width, scratch);
+  source_.Transpose(source_t.get());
+  state_t.Init(ns_, width, scratch);
+  state_.Transpose(state_t.get());
+#ifdef _OPENMP
+#  pragma omp parallel for num_threads(GFS) if (!Is2D())
+#endif
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
+  }
+  if (softmax_ != nullptr) {
+    softmax_->FinishBackward(*softmax_errors_t);
+  }
+  return needs_to_backprop_;
+}
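+
+// Note: Backward only computes the per-gate weight gradients
+// (SumOuterTransposed above, FinishBackward for the softmax); the weights
+// themselves are only modified later, in Update().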
+
+// Updates the weights using the given learning rate, momentum and adam_beta.
+// num_samples is used in the adam computation iff use_adam_ is true.
+void LSTM::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
+#if DEBUG_DETAIL > 3
+  PrintW();
+#endif
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
+  }
+  if (softmax_ != nullptr) {
+    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
+  }
+#if DEBUG_DETAIL > 3
+  PrintDW();
+#endif
+}
+
+// Sums the products of weight updates in *this and other, splitting into
+// positive (same direction) in *same and negative (different direction) in
+// *changed.
+void LSTM::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
+  ASSERT_HOST(other.type() == type_);
+  const LSTM *lstm = static_cast<const LSTM *>(&other);
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
+  }
+  if (softmax_ != nullptr) {
+    softmax_->CountAlternators(*lstm->softmax_, same, changed);
+  }
+}
+
+#if DEBUG_DETAIL > 3
+
+// Prints the weights for debug purposes.
+void LSTM::PrintW() {
+  tprintf("Weight state:%s\n", name_.c_str());
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    tprintf("Gate %d, inputs\n", w);
+    for (int i = 0; i < ni_; ++i) {
+      tprintf("Row %d:", i);
+      for (int s = 0; s < ns_; ++s) {
+        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
+      }
+      tprintf("\n");
+    }
+    tprintf("Gate %d, outputs\n", w);
+    for (int i = ni_; i < ni_ + ns_; ++i) {
+      tprintf("Row %d:", i - ni_);
+      for (int s = 0; s < ns_; ++s) {
+        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
+      }
+      tprintf("\n");
+    }
+    tprintf("Gate %d, bias\n", w);
+    for (int s = 0; s < ns_; ++s) {
+      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
+    }
+    tprintf("\n");
+  }
+}
+
+// Prints the weight deltas for debug purposes.
+void LSTM::PrintDW() {
+  tprintf("Delta state:%s\n", name_.c_str());
+  for (int w = 0; w < WT_COUNT; ++w) {
+    if (w == GFS && !Is2D()) {
+      continue;
+    }
+    tprintf("Gate %d, inputs\n", w);
+    for (int i = 0; i < ni_; ++i) {
+      tprintf("Row %d:", i);
+      for (int s = 0; s < ns_; ++s) {
+        tprintf(" %g", gate_weights_[w].GetDW(s, i));
+      }
+      tprintf("\n");
+    }
+    tprintf("Gate %d, outputs\n", w);
+    for (int i = ni_; i < ni_ + ns_; ++i) {
+      tprintf("Row %d:", i - ni_);
+      for (int s = 0; s < ns_; ++s) {
+        tprintf(" %g", gate_weights_[w].GetDW(s, i));
+      }
+      tprintf("\n");
+    }
+    tprintf("Gate %d, bias\n", w);
+    for (int s = 0; s < ns_; ++s) {
+      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
+    }
+    tprintf("\n");
+  }
+}
+
+#endif
+
+// Resizes forward data to cope with an input image of the given width.
+void LSTM::ResizeForward(const NetworkIO &input) {
+  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
+  source_.Resize(input, rounded_inputs);
+  which_fg_.ResizeNoInit(input.Width(), ns_);
+  if (IsTraining()) {
+    state_.ResizeFloat(input, ns_);
+    for (int w = 0; w < WT_COUNT; ++w) {
+      if (w == GFS && !Is2D()) {
+        continue;
+      }
+      node_values_[w].ResizeFloat(input, ns_);
+    }
+  }
+}
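+
+// Note: RoundInputs here (like RoundOutputs in Forward) rounds the vector
+// length up to the block size expected by the integer SIMD matrix code, when
+// that code is in use.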
+
+} // namespace tesseract.