comparison mupdf-source/thirdparty/tesseract/src/arch/intsimdmatrixavx2.cpp @ 2:b50eed0cc0ef upstream
ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4.
The directory name has changed: the expanded directory no longer includes a version number.
| author | Franz Glasner <fzglas.hg@dom66.de> |
|---|---|
| date | Mon, 15 Sep 2025 11:43:07 +0200 |
| parents | |
| children | |
Comparing file revision 1:1d09e1dec1d9 with 2:b50eed0cc0ef:
```cpp
///////////////////////////////////////////////////////////////////////
// File: intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author: Ray Smith
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include "intsimdmatrix.h"

#if !defined(__AVX2__)
#  if defined(__i686__) || defined(__x86_64__)
#    error Implementation only for AVX2 capable architectures
#  endif
#else
#  include <immintrin.h>
#  include <algorithm>
#  include <cstdint>
#  include <vector>

#  if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \
      defined(_WIN32) && !defined(_WIN64)
// Optimize for size (/Os) instead of using the default optimization for some
// versions of the 32 bit Visual Studio compiler which generate buggy code.
#    pragma optimize("", off)
#    pragma optimize("s", on)
#  endif

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
                                 __m256i &weights, __m256i &reps, __m256i &result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input, weights, so weights is always +ve.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added. What we really want is a horizontal add of
  // 16+16=32 bit result, but there is no such instruction, so multiply by
  // 16-bit ones instead. It is probably faster than all the sign-extending,
  // permuting and adding that would otherwise be required.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}

// Load 64 bits into the bottom of a 128bit register.
// We don't actually care what the top 64bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_) {
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
  return _mm_set_epi64x(0, wi[0]);
}

#if defined(FAST_FLOAT)

static inline void ExtractResults8(__m256i result, const int8_t *wi,
                                   const float *scales, float *v) {
  __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256); // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1,
                                    const int8_t *&wi, const float *&scales,
                                    float *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 8x8bit vals in bottom of 128bit reg
  const __m256i bias_scale =
      _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale01234567 = _mm256_loadu_ps(scales + 8);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
  res01234567 = _mm256_cvtepi32_ps(result1);
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v + 8, res01234567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
                                           int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
                            const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial_func_ produces group_size outputs, except the
  // last one, which can produce less.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
                                   double *v) {
  __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256); // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const double *&scales, double *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 8x8bit vals in bottom of 128bit reg
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
                                           int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
                            const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial_func_ produces group_size outputs, except the
  // last one, which can produce less.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#endif

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup
};

} // namespace tesseract.

#endif
```
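For orientation: every `PartialMatrixDotVector*` variant in the file computes the same arithmetic that the comments describe, namely an integer dot product of a weight row with `u`, plus the row's bias weight applied as if the corresponding input were 127, and a final per-output floating-point scale. The sketch below is not part of the file above; it is a minimal scalar model of that semantics, assuming for illustration a plain row-major layout with the bias stored as the last weight of each row (the real `wi` uses the shuffled block layout documented before `PartialMatrixDotVector64`), and the name `MatrixDotVectorScalar` is hypothetical.

```cpp
#include <cstdint>

// Scalar model of v = scales * (W.u + 127 * bias), the computation the AVX2
// kernels perform. Here w is assumed row-major, num_in weights per row
// followed by one bias weight -- NOT the packed order required by wi above.
static void MatrixDotVectorScalar(int num_out, int num_in, const int8_t *w,
                                  const double *scales, const int8_t *u, double *v) {
  for (int i = 0; i < num_out; ++i) {
    const int8_t *row = w + i * (num_in + 1);
    int32_t total = 0;
    for (int j = 0; j < num_in; ++j) {
      // 8-bit x 8-bit products accumulated in 32 bits, as the SIMD code does
      // via _mm256_maddubs_epi16 followed by _mm256_madd_epi16 with ones.
      total += static_cast<int32_t>(row[j]) * u[j];
    }
    // The bias weight is treated as a weight whose input is 127 (the maximum
    // int8_t value), matching the "result += bias * 127" steps above.
    total += static_cast<int32_t>(row[num_in]) * 127;
    // Scale back to floating point with the per-output scale factor.
    v[i] = total * scales[i];
  }
}
```

The AVX2 code reaches the same result by blocking: each `MultiplyGroup` call consumes 32 weight bytes and 4 replicated inputs, and the sign-transfer via `_mm256_sign_epi8` exists only because `_mm256_maddubs_epi16` requires its first operand to be unsigned.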
