comparison mupdf-source/thirdparty/tesseract/src/lstm/fullyconnected.cpp @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: the expanded directory no longer carries a version number.
| | |
|---|---|
| author | Franz Glasner <fzglas.hg@dom66.de> |
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children | |
File contents as compared between revisions 1:1d09e1dec1d9 and 2:b50eed0cc0ef:
```cpp
///////////////////////////////////////////////////////////////////////
// File: fullyconnected.cpp
// Description: Simple feed-forward layer with various non-linearities.
// Author: Ray Smith
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include "fullyconnected.h"

#ifdef _OPENMP
#  include <omp.h>
#endif
#include <cstdio>
#include <cstdlib>

#include "functions.h"
#include "networkscratch.h"

// Number of threads to use for parallel calculation of Forward and Backward.
#ifdef _OPENMP
const int kNumThreads = 4;
#else
const int kNumThreads = 1;
#endif

namespace tesseract {

FullyConnected::FullyConnected(const std::string &name, int ni, int no, NetworkType type)
    : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {}

// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape FullyConnected::OutputShape(const StaticShape &input_shape) const {
  LossType loss_type = LT_NONE;
  if (type_ == NT_SOFTMAX) {
    loss_type = LT_CTC;
  } else if (type_ == NT_SOFTMAX_NO_CTC) {
    loss_type = LT_SOFTMAX;
  } else if (type_ == NT_LOGISTIC) {
    loss_type = LT_LOGISTIC;
  }
  StaticShape result(input_shape);
  result.set_depth(no_);
  result.set_loss_type(loss_type);
  return result;
}

// Suspends/Enables training by setting the training_ flag.
void FullyConnected::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) {
      training_ = TS_ENABLED;
    }
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) {
      training_ = state;
    }
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      weights_.InitBackward();
    }
    training_ = state;
  }
}

// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int FullyConnected::InitWeights(float range, TRand *randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADAM), range, randomizer);
  return num_weights_;
}

// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.

int FullyConnected::RemapOutputs(int old_no, const std::vector<int> &code_map) {
  if (type_ == NT_SOFTMAX && no_ == old_no) {
    num_weights_ = weights_.RemapOutputs(code_map);
    no_ = code_map.size();
  }
  return num_weights_;
}

// Converts a float network to an int network.
void FullyConnected::ConvertToInt() {
  weights_.ConvertToInt();
}

// Provides debug output on the weights.
void FullyConnected::DebugWeights() {
  weights_.Debug2D(name_.c_str());
}

// Writes to the given file. Returns false in case of error.
bool FullyConnected::Serialize(TFile *fp) const {
  if (!Network::Serialize(fp)) {
    return false;
  }
  if (!weights_.Serialize(IsTraining(), fp)) {
    return false;
  }
  return true;
}

// Reads from the given file. Returns false in case of error.
bool FullyConnected::DeSerialize(TFile *fp) {
  return weights_.DeSerialize(IsTraining(), fp);
}

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void FullyConnected::Forward(bool debug, const NetworkIO &input,
                             const TransposedArray *input_transpose, NetworkScratch *scratch,
                             NetworkIO *output) {
  int width = input.Width();
  if (type_ == NT_SOFTMAX) {
    output->ResizeFloat(input, no_);
  } else {
    output->Resize(input, no_);
  }
  SetupForward(input, input_transpose);
  std::vector<NetworkScratch::FloatVec> temp_lines(kNumThreads);
  std::vector<NetworkScratch::FloatVec> curr_input(kNumThreads);
  int ro = no_;
  if (IntSimdMatrix::intSimdMatrix) {
    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
  }
  for (int i = 0; i < kNumThreads; ++i) {
    temp_lines[i].Init(ro, scratch);
    curr_input[i].Init(ni_, scratch);
  }
#ifdef _OPENMP
#  pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = 0;
#endif
    TFloat *temp_line = temp_lines[thread_id];
    if (input.int_mode()) {
      ForwardTimeStep(input.i(t), t, temp_line);
    } else {
      input.ReadTimeStep(t, curr_input[thread_id]);
      ForwardTimeStep(curr_input[thread_id], t, temp_line);
    }
    output->WriteTimeStep(t, temp_line);
    if (IsTraining() && type_ != NT_SOFTMAX) {
      acts_.CopyTimeStepFrom(t, *output, t);
    }
  }
  // Zero all the elements that are in the padding around images that allows
  // multiple different-sized images to exist in a single array.
  // acts_ is only used if this is not a softmax op.
  if (IsTraining() && type_ != NT_SOFTMAX) {
    acts_.ZeroInvalidElements();
  }
  output->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
  tprintf("F Output:%s\n", name_.c_str());
  output->Print(10);
#endif
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayForward(*output);
  }
#endif
}

// Components of Forward so FullyConnected can be reused inside LSTM.
void FullyConnected::SetupForward(const NetworkIO &input, const TransposedArray *input_transpose) {
  // Softmax output is always float, so save the input type.
  int_mode_ = input.int_mode();
  if (IsTraining()) {
    acts_.Resize(input, no_);
    // Source_ is a transposed copy of input. It isn't needed if provided.
    external_source_ = input_transpose;
    if (external_source_ == nullptr) {
      source_t_.ResizeNoInit(ni_, input.Width());
    }
  }
}

void FullyConnected::ForwardTimeStep(int t, TFloat *output_line) {
  if (type_ == NT_TANH) {
    FuncInplace<GFunc>(no_, output_line);
  } else if (type_ == NT_LOGISTIC) {
    FuncInplace<FFunc>(no_, output_line);
  } else if (type_ == NT_POSCLIP) {
    FuncInplace<ClipFFunc>(no_, output_line);
  } else if (type_ == NT_SYMCLIP) {
    FuncInplace<ClipGFunc>(no_, output_line);
  } else if (type_ == NT_RELU) {
    FuncInplace<Relu>(no_, output_line);
  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
    SoftmaxInPlace(no_, output_line);
  } else if (type_ != NT_LINEAR) {
    ASSERT_HOST("Invalid fully-connected type!" == nullptr);
  }
}

void FullyConnected::ForwardTimeStep(const TFloat *d_input, int t, TFloat *output_line) {
  // input is copied to source_ line-by-line for cache coherency.
  if (IsTraining() && external_source_ == nullptr) {
    source_t_.WriteStrided(t, d_input);
  }
  weights_.MatrixDotVector(d_input, output_line);
  ForwardTimeStep(t, output_line);
}

void FullyConnected::ForwardTimeStep(const int8_t *i_input, int t, TFloat *output_line) {
  // input is copied to source_ line-by-line for cache coherency.
  weights_.MatrixDotVector(i_input, output_line);
  ForwardTimeStep(t, output_line);
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool FullyConnected::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
                              NetworkIO *back_deltas) {
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayBackward(fwd_deltas);
  }
#endif
  back_deltas->Resize(fwd_deltas, ni_);
  std::vector<NetworkScratch::FloatVec> errors(kNumThreads);
  for (int i = 0; i < kNumThreads; ++i) {
    errors[i].Init(no_, scratch);
  }
  std::vector<NetworkScratch::FloatVec> temp_backprops;
  if (needs_to_backprop_) {
    temp_backprops.resize(kNumThreads);
    for (int i = 0; i < kNumThreads; ++i) {
      temp_backprops[i].Init(ni_, scratch);
    }
  }
  int width = fwd_deltas.Width();
  NetworkScratch::GradientStore errors_t;
  errors_t.Init(no_, width, scratch);
#ifdef _OPENMP
#  pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    int thread_id = 0;
#endif
    TFloat *backprop = nullptr;
    if (needs_to_backprop_) {
      backprop = temp_backprops[thread_id];
    }
    TFloat *curr_errors = errors[thread_id];
    BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
    if (backprop != nullptr) {
      back_deltas->WriteTimeStep(t, backprop);
    }
  }
  FinishBackward(*errors_t.get());
  if (needs_to_backprop_) {
    back_deltas->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
    tprintf("F Backprop:%s\n", name_.c_str());
    back_deltas->Print(10);
#endif
    return true;
  }
  return false; // No point going further back.
}

void FullyConnected::BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors,
                                      TransposedArray *errors_t, TFloat *backprop) {
  if (type_ == NT_TANH) {
    acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_LOGISTIC) {
    acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_POSCLIP) {
    acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_SYMCLIP) {
    acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_RELU) {
    acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC || type_ == NT_LINEAR) {
    fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
  } else {
    ASSERT_HOST("Invalid fully-connected type!" == nullptr);
  }
  // Generate backprop only if needed by the lower layer.
  if (backprop != nullptr) {
    weights_.VectorDotMatrix(curr_errors, backprop);
  }
  errors_t->WriteStrided(t, curr_errors);
}

void FullyConnected::FinishBackward(const TransposedArray &errors_t) {
  if (external_source_ == nullptr) {
    weights_.SumOuterTransposed(errors_t, source_t_, true);
  } else {
    weights_.SumOuterTransposed(errors_t, *external_source_, true);
  }
}

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void FullyConnected::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void FullyConnected::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
  ASSERT_HOST(other.type() == type_);
  const auto *fc = static_cast<const FullyConnected *>(&other);
  weights_.CountAlternators(fc->weights_, same, changed);
}

} // namespace tesseract.
```
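Not part of the file above: for readers skimming the listing, each `ForwardTimeStep` call amounts to y = f(W · [x; 1]), a matrix-vector product against a weight matrix with `ni_ + 1` columns (the extra column being the bias, applied to an implicit trailing input of 1 by `weights_.MatrixDotVector`), followed by the type-dependent non-linearity. The sketch below is a standalone, hypothetical illustration of that computation for the `NT_TANH` case; it uses none of Tesseract's types, and the names `ForwardStep`, `w`, and `x` are invented for the example.

```cpp
// Standalone illustration (not Tesseract code): one forward time step of a
// fully connected layer with a tanh non-linearity, i.e. y = tanh(W * [x; 1]),
// where the trailing 1 supplies the bias term.
#include <cmath>
#include <cstdio>
#include <vector>

// Computes one output line from one input line.
static std::vector<double> ForwardStep(const std::vector<std::vector<double>> &w,
                                       const std::vector<double> &x) {
  std::vector<double> y(w.size(), 0.0);
  for (size_t o = 0; o < w.size(); ++o) {
    double sum = w[o].back();  // Bias: weight for the implicit trailing 1.
    for (size_t i = 0; i < x.size(); ++i) {
      sum += w[o][i] * x[i];
    }
    y[o] = std::tanh(sum);  // Corresponds to the NT_TANH branch above.
  }
  return y;
}

int main() {
  // 2 inputs, 3 outputs: each row holds ni weights plus one bias weight.
  std::vector<std::vector<double>> w = {
      {0.5, -0.2, 0.1}, {0.3, 0.8, -0.4}, {-0.7, 0.6, 0.0}};
  std::vector<double> x = {1.0, -1.0};
  for (double v : ForwardStep(w, x)) {
    std::printf("%f\n", v);
  }
  return 0;
}
```

In the class itself the per-output loop is delegated to `weights_.MatrixDotVector`, and `Forward` parallelises the loop over time steps with OpenMP (`kNumThreads`) when `_OPENMP` is defined.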
