comparison mupdf-source/thirdparty/tesseract/src/arch/intsimdmatrixavx2.cpp @ 2:b50eed0cc0ef upstream

ADD: MuPDF v1.26.7: the MuPDF source as downloaded by a default build of PyMuPDF 1.26.4. The directory name has changed: the expanded directory no longer includes a version number.
author Franz Glasner <fzglas.hg@dom66.de>
date Mon, 15 Sep 2025 11:43:07 +0200
comparing 1:1d09e1dec1d9 with 2:b50eed0cc0ef
1 ///////////////////////////////////////////////////////////////////////
2 // File: intsimdmatrixavx2.cpp
3 // Description: matrix-vector product for 8-bit data on avx2.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2017, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 ///////////////////////////////////////////////////////////////////////
17
18 #include "intsimdmatrix.h"
19
20 #if !defined(__AVX2__)
21 # if defined(__i686__) || defined(__x86_64__)
22 # error Implementation only for AVX2 capable architectures
23 # endif
24 #else
25 # include <immintrin.h>
26 # include <algorithm>
27 # include <cstdint>
28 # include <vector>
29
30 # if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \
31 defined(_WIN32) && !defined(_WIN64)
32 // Optimize for size (/Os) instead of using the default optimization for some
33 // versions of the 32 bit Visual Studio compiler which generate buggy code.
34 # pragma optimize("", off)
35 # pragma optimize("s", on)
36 # endif
37
38 namespace tesseract {
39
40 // Number of outputs held in each register. 8 x 32 bit ints.
41 constexpr int kNumOutputsPerRegister = 8;
42 // Maximum number of registers that we will use.
43 constexpr int kMaxOutputRegisters = 8;
44 // Number of inputs in the inputs register.
45 constexpr int kNumInputsPerRegister = 32;
46 // Number of inputs in each weight group.
47 constexpr int kNumInputsPerGroup = 4;
48 // Number of groups of inputs to be broadcast.
49 constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
50
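// --- Editor's illustrative sketch (not part of the upstream file) ---
// The constants above tile the work as follows: one 256-bit load brings in 32
// int8 inputs, consumed as 32/4 = 8 groups of 4, and up to 8 accumulator
// registers of 8 int32 lanes each cover 64 outputs per pass. Two hedged
// compile-time checks make those relationships explicit:
static_assert(kNumInputsPerRegister % kNumInputsPerGroup == 0,
              "each 32-byte input load must split evenly into 4-input groups");
static_assert(kNumOutputsPerRegister * kMaxOutputRegisters == 64,
              "the widest kernel below (PartialMatrixDotVector64) covers 64 outputs");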
51 // Functions to compute part of a matrix.vector multiplication. The weights
52 // are in a very specific order (see above) in w, which is multiplied by
53 // u of length num_in, to produce output v after scaling the integer results
54 // by the corresponding member of scales.
55 // The amount of w and scales consumed is fixed and not available to the
56 // caller. The number of outputs written to v will be at most num_out.
57
58 // Computes one set of 4x8 products of inputs and weights, adding to result.
59 // Horizontally adds 4 adjacent results, making 8x32-bit results.
60 // rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
61 // Note that wi must previously have been re-organized with blocks of 4x8
62 // weights in contiguous memory.
63 // ones is a register of 16x16-bit values all equal to 1.
64 // Note: wi is incremented by the amount of data read.
65 // weights and reps are scratch registers.
66 // This function must be inlined with references in order for the compiler to
67 // correctly use the registers declared in the caller.
68 static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
69 __m256i &weights, __m256i &reps, __m256i &result) {
70 // Load a 4x8 block of weights.
71 weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
72 wi += kNumInputsPerRegister;
73 // Normalize the signs on rep_input, weights, so weights is always +ve.
74 reps = _mm256_sign_epi8(rep_input, weights);
75 weights = _mm256_sign_epi8(weights, weights);
76 // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
77 // with adjacent pairs added.
78 weights = _mm256_maddubs_epi16(weights, reps);
79 // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
80 // with adjacent pairs added. What we really want is a horizontal add of
81 // 16+16=32 bit result, but there is no such instruction, so multiply by
82 // 16-bit ones instead. It is probably faster than all the sign-extending,
83 // permuting and adding that would otherwise be required.
84 weights = _mm256_madd_epi16(weights, ones);
85 result = _mm256_add_epi32(result, weights);
86 }
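// --- Editor's illustrative sketch (not part of the upstream file) ---
// Scalar equivalent of one MultiplyGroup call, assuming the shaped weight
// layout described above: each of the 8 output lanes accumulates the dot
// product of the same 4 replicated inputs with its own 4 contiguous weights,
// and wi advances by the 32 bytes read. The helper name is hypothetical.
static inline void MultiplyGroupScalarSketch(const int8_t input4[4], const int8_t *&wi,
                                             int32_t result[8]) {
  for (int lane = 0; lane < 8; ++lane) {
    for (int i = 0; i < 4; ++i) {
      result[lane] += static_cast<int32_t>(input4[i]) * static_cast<int32_t>(wi[lane * 4 + i]);
    }
  }
  wi += kNumInputsPerRegister; // same advance as the AVX2 version (32 weight bytes)
}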
87
88 // Load 64 bits into the bottom of a 128bit register.
89 // We don't actually care what the top 64bits are, but this ends
90 // up with them being zero.
91 static inline __m128i load64_to_128(const int8_t *wi_) {
92 const auto *wi = reinterpret_cast<const int64_t *>(wi_);
93 return _mm_set_epi64x(0, wi[0]);
94 }
95
96 #if defined(FAST_FLOAT)
97
98 static inline void ExtractResults8(__m256i result, const int8_t *wi,
99 const float *scales, float *v) {
100 __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
101 __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
102 __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
103 __m256 scale01234567 = _mm256_loadu_ps(scales);
104 w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
105 result = _mm256_add_epi32(result, w256); // result += bias * 127
106 __m256 res01234567 = _mm256_cvtepi32_ps(result);
107 result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
108 res01234567 = _mm256_mul_ps(res01234567, scale01234567);
109 _mm256_storeu_ps(v, res01234567);
110 }
111
112 static inline void ExtractResults16(__m256i result0, __m256i result1,
113 const int8_t *&wi, const float *&scales,
114 float *&v) {
115 __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
116 // 8x8bit vals in bottom of 128bit reg
117 const __m256i bias_scale =
118 _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
119 __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
120 __m256 scale01234567 = _mm256_loadu_ps(scales);
121 w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
122 result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
123 __m256 res01234567 = _mm256_cvtepi32_ps(result0);
124 result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
125 res01234567 = _mm256_mul_ps(res01234567, scale01234567);
126 _mm256_storeu_ps(v, res01234567);
127 w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
128 w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
129 scale01234567 = _mm256_loadu_ps(scales + 8);
130 w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
131 result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
132 res01234567 = _mm256_cvtepi32_ps(result1);
133 result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
134 res01234567 = _mm256_mul_ps(res01234567, scale01234567);
135 _mm256_storeu_ps(v + 8, res01234567);
136 wi += 16;
137 scales += 16;
138 v += 16;
139 }
140
141 // Computes part of matrix.vector v = Wu. Computes N=64 results.
142 // The weights *must* be arranged so that consecutive reads from wi
143 // provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
144 // (kNumInputsPerGroup inputs))). After that there must be N consecutive
145 // bias weights, before continuing with any more weights.
146 // u must be padded out with zeros to
147 // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
148 static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
149 int num_in, float *v) {
150 // Register containing 16-bit ones for horizontal add with 16->32 bit
151 // conversion.
152 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
153 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
154 // Initialize all the results to 0.
155 __m256i result0 = _mm256_setzero_si256();
156 __m256i result1 = _mm256_setzero_si256();
157 __m256i result2 = _mm256_setzero_si256();
158 __m256i result3 = _mm256_setzero_si256();
159 __m256i result4 = _mm256_setzero_si256();
160 __m256i result5 = _mm256_setzero_si256();
161 __m256i result6 = _mm256_setzero_si256();
162 __m256i result7 = _mm256_setzero_si256();
163 // Iterate over the input (u), one registerful at a time.
164 for (int j = 0; j < num_in;) {
165 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
166 // Inputs are processed in groups of kNumInputsPerGroup, replicated
167 // kNumInputGroups times.
168 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
169 // Replicate the low 32 bits (4 inputs) 8 times.
170 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
171 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
172 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
173 __m256i weights, reps;
174 // Mul-add, with horizontal add of the 4 inputs to each of the results.
175 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
176 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
177 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
178 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
179 MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
180 MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
181 MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
182 MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
183 }
184 }
185 ExtractResults16(result0, result1, wi, scales, v);
186 ExtractResults16(result2, result3, wi, scales, v);
187 ExtractResults16(result4, result5, wi, scales, v);
188 ExtractResults16(result6, result7, wi, scales, v);
189 }
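// --- Editor's illustrative sketch (not part of the upstream file) ---
// Plain scalar restatement of what the loop above computes, reading the shaped
// weight stream in the same order: for every group of 4 inputs, 64 outputs x 4
// weights, followed by 64 bias bytes that are applied with the same *127 factor
// used by ExtractResults. Helper name and exact loop shape are assumptions.
static inline void PartialMatrixDotVector64ScalarSketch(const int8_t *wi, const float *scales,
                                                        const int8_t *u, int num_in, float *v) {
  int32_t acc[64] = {0};
  for (int j = 0; j < num_in; j += kNumInputsPerGroup) {
    for (int out = 0; out < 64; ++out) {
      for (int i = 0; i < kNumInputsPerGroup; ++i) {
        acc[out] += static_cast<int32_t>(u[j + i]) * static_cast<int32_t>(*wi++);
      }
    }
  }
  for (int out = 0; out < 64; ++out) {
    v[out] = (acc[out] + static_cast<int32_t>(wi[out]) * 127) * scales[out];
  }
}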
190
191 // Computes part of matrix.vector v = Wu. Computes N=32 results.
192 // For details see PartialMatrixDotVector64 with N=32.
193 static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
194 int num_in, float *v) {
195 // Register containing 16-bit ones for horizontal add with 16->32 bit
196 // conversion.
197 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
198 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
199 // Initialize all the results to 0.
200 __m256i result0 = _mm256_setzero_si256();
201 __m256i result1 = _mm256_setzero_si256();
202 __m256i result2 = _mm256_setzero_si256();
203 __m256i result3 = _mm256_setzero_si256();
204 // Iterate over the input (u), one registerful at a time.
205 for (int j = 0; j < num_in;) {
206 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
207 // Inputs are processed in groups of kNumInputsPerGroup, replicated
208 // kNumInputGroups times.
209 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
210 // Replicate the low 32 bits (4 inputs) 8 times.
211 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
212 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
213 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
214 __m256i weights, reps;
215 // Mul-add, with horizontal add of the 4 inputs to each of the results.
216 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
217 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
218 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
219 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
220 }
221 }
222 ExtractResults16(result0, result1, wi, scales, v);
223 ExtractResults16(result2, result3, wi, scales, v);
224 }
225
226 // Computes part of matrix.vector v = Wu. Computes N=16 results.
227 // For details see PartialMatrixDotVector64 with N=16.
228 static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
229 int num_in, float *v) {
230 // Register containing 16-bit ones for horizontal add with 16->32 bit
231 // conversion.
232 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
233 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
234 // Initialize all the results to 0.
235 __m256i result0 = _mm256_setzero_si256();
236 __m256i result1 = _mm256_setzero_si256();
237 // Iterate over the input (u), one registerful at a time.
238 for (int j = 0; j < num_in;) {
239 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
240 // Inputs are processed in groups of kNumInputsPerGroup, replicated
241 // kNumInputGroups times.
242 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
243 // Replicate the low 32 bits (4 inputs) 8 times.
244 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
245 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
246 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
247 __m256i weights, reps;
248 // Mul-add, with horizontal add of the 4 inputs to each of the results.
249 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
250 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
251 }
252 }
253 ExtractResults16(result0, result1, wi, scales, v);
254 }
255
256 // Computes part of matrix.vector v = Wu. Computes N=8 results.
257 // For details see PartialMatrixDotVector64 with N=8.
258 static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
259 int num_in, float *v) {
260 // Register containing 16-bit ones for horizontal add with 16->32 bit
261 // conversion.
262 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
263 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
264 // Initialize all the results to 0.
265 __m256i result0 = _mm256_setzero_si256();
266 // Iterate over the input (u), one registerful at a time.
267 for (int j = 0; j < num_in;) {
268 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
269 // Inputs are processed in groups of kNumInputsPerGroup, replicated
270 // kNumInputGroups times.
271 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
272 // Replicate the low 32 bits (4 inputs) 8 times.
273 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
274 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
275 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
276 __m256i weights, reps;
277 // Mul-add, with horizontal add of the 4 inputs to each of the results.
278 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
279 }
280 }
281 ExtractResults8(result0, wi, scales, v);
282 }
283
284 static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
285 const int8_t *u, float *v) {
286 const int num_out = dim1;
287 const int num_in = dim2 - 1;
288 // Each call to a partial_func_ produces group_size outputs, except the
289 // last one, which can produce less.
290 const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
291 const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
292 int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
293 int output = 0;
294
295 int w_step = (rounded_num_in + 1) * group_size;
296
297 // Run with this group size, until it would produce too much output, then
298 // switch to a smaller size.
299 for (; output + group_size <= rounded_num_out; output += group_size) {
300 PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
301 wi += w_step;
302 scales += group_size;
303 v += group_size;
304 }
305 group_size /= 2;
306 w_step /= 2;
307
308 if (output + group_size <= rounded_num_out) {
309 PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
310 wi += w_step;
311 scales += group_size;
312 v += group_size;
313 output += group_size;
314 }
315 group_size /= 2;
316 w_step /= 2;
317
318 if (output + group_size <= rounded_num_out) {
319 PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
320 wi += w_step;
321 scales += group_size;
322 v += group_size;
323 output += group_size;
324 }
325 group_size /= 2;
326 w_step /= 2;
327
328 if (output + group_size <= rounded_num_out) {
329 PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
330 }
331 }
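// --- Editor's worked example (illustrative comment only) ---
// Assuming num_out = 100 (so rounded_num_out = Roundup(100, 8) = 104):
//   pass 1: group_size 64 -> output = 64
//   pass 2: group_size 32 -> output = 96
//   group_size 16 is skipped (96 + 16 > 104)
//   pass 3: group_size 8  -> output = 104, covering all rounded outputs.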
332 #else
333 static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
334 double *v) {
335 __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
336 __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
337 __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
338 __m256d scale0123 = _mm256_loadu_pd(scales);
339 __m256d scale4567 = _mm256_loadu_pd(scales + 4);
340 w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
341 result = _mm256_add_epi32(result, w256); // result += bias * 127
342 __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
343 result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
344 __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
345 res0123 = _mm256_mul_pd(res0123, scale0123);
346 res4567 = _mm256_mul_pd(res4567, scale4567);
347 _mm256_storeu_pd(v, res0123);
348 _mm256_storeu_pd(v + 4, res4567);
349 }
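// --- Editor's note (illustration only, not part of the upstream file) ---
// The double path needs two 4-lane halves because _mm256_cvtepi32_pd widens
// only four 32-bit integers at a time; the permute above moves the upper four
// sums into the low 128 bits for the second conversion. Scalar equivalent of
// the convert/scale/store tail, after the bias*127 has already been added
// (hypothetical helper):
static inline void StoreScaled8Sketch(const int32_t biased_sums[8], const double *scales,
                                      double *v) {
  for (int k = 0; k < 8; ++k) {
    v[k] = static_cast<double>(biased_sums[k]) * scales[k];
  }
}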
350
351 static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
352 const double *&scales, double *&v) {
353 __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
354 // 8x8bit vals in bottom of 128bit reg
355 const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
356 __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
357 __m256d scale0123 = _mm256_loadu_pd(scales);
358 __m256d scale4567 = _mm256_loadu_pd(scales + 4);
359 w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
360 result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
361 __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
362 result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
363 __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
364 res0123 = _mm256_mul_pd(res0123, scale0123);
365 res4567 = _mm256_mul_pd(res4567, scale4567);
366 _mm256_storeu_pd(v, res0123);
367 _mm256_storeu_pd(v + 4, res4567);
368 w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
369 w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
370 scale0123 = _mm256_loadu_pd(scales + 8);
371 scale4567 = _mm256_loadu_pd(scales + 12);
372 w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
373 result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
374 res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
375 result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
376 res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
377 res0123 = _mm256_mul_pd(res0123, scale0123);
378 res4567 = _mm256_mul_pd(res4567, scale4567);
379 _mm256_storeu_pd(v + 8, res0123);
380 _mm256_storeu_pd(v + 12, res4567);
381 wi += 16;
382 scales += 16;
383 v += 16;
384 }
385
386 // Computes part of matrix.vector v = Wu. Computes N=64 results.
387 // The weights *must* be arranged so that consecutive reads from wi
388 // provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
389 // (kNumInputsPerGroup inputs))). After that there must be N consecutive
390 // bias weights, before continuing with any more weights.
391 // u must be padded out with zeros to
392 // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
393 static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
394 int num_in, double *v) {
395 // Register containing 16-bit ones for horizontal add with 16->32 bit
396 // conversion.
397 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
398 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
399 // Initialize all the results to 0.
400 __m256i result0 = _mm256_setzero_si256();
401 __m256i result1 = _mm256_setzero_si256();
402 __m256i result2 = _mm256_setzero_si256();
403 __m256i result3 = _mm256_setzero_si256();
404 __m256i result4 = _mm256_setzero_si256();
405 __m256i result5 = _mm256_setzero_si256();
406 __m256i result6 = _mm256_setzero_si256();
407 __m256i result7 = _mm256_setzero_si256();
408 // Iterate over the input (u), one registerful at a time.
409 for (int j = 0; j < num_in;) {
410 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
411 // Inputs are processed in groups of kNumInputsPerGroup, replicated
412 // kNumInputGroups times.
413 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
414 // Replicate the low 32 bits (4 inputs) 8 times.
415 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
416 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
417 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
418 __m256i weights, reps;
419 // Mul-add, with horizontal add of the 4 inputs to each of the results.
420 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
421 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
422 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
423 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
424 MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
425 MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
426 MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
427 MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
428 }
429 }
430 ExtractResults16(result0, result1, wi, scales, v);
431 ExtractResults16(result2, result3, wi, scales, v);
432 ExtractResults16(result4, result5, wi, scales, v);
433 ExtractResults16(result6, result7, wi, scales, v);
434 }
435
436 // Computes part of matrix.vector v = Wu. Computes N=32 results.
437 // For details see PartialMatrixDotVector64 with N=32.
438 static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
439 int num_in, double *v) {
440 // Register containing 16-bit ones for horizontal add with 16->32 bit
441 // conversion.
442 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
443 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
444 // Initialize all the results to 0.
445 __m256i result0 = _mm256_setzero_si256();
446 __m256i result1 = _mm256_setzero_si256();
447 __m256i result2 = _mm256_setzero_si256();
448 __m256i result3 = _mm256_setzero_si256();
449 // Iterate over the input (u), one registerful at a time.
450 for (int j = 0; j < num_in;) {
451 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
452 // Inputs are processed in groups of kNumInputsPerGroup, replicated
453 // kNumInputGroups times.
454 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
455 // Replicate the low 32 bits (4 inputs) 8 times.
456 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
457 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
458 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
459 __m256i weights, reps;
460 // Mul-add, with horizontal add of the 4 inputs to each of the results.
461 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
462 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
463 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
464 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
465 }
466 }
467 ExtractResults16(result0, result1, wi, scales, v);
468 ExtractResults16(result2, result3, wi, scales, v);
469 }
470
471 // Computes part of matrix.vector v = Wu. Computes N=16 results.
472 // For details see PartialMatrixDotVector64 with N=16.
473 static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
474 int num_in, double *v) {
475 // Register containing 16-bit ones for horizontal add with 16->32 bit
476 // conversion.
477 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
478 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
479 // Initialize all the results to 0.
480 __m256i result0 = _mm256_setzero_si256();
481 __m256i result1 = _mm256_setzero_si256();
482 // Iterate over the input (u), one registerful at a time.
483 for (int j = 0; j < num_in;) {
484 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
485 // Inputs are processed in groups of kNumInputsPerGroup, replicated
486 // kNumInputGroups times.
487 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
488 // Replicate the low 32 bits (4 inputs) 8 times.
489 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
490 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
491 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
492 __m256i weights, reps;
493 // Mul-add, with horizontal add of the 4 inputs to each of the results.
494 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
495 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
496 }
497 }
498 ExtractResults16(result0, result1, wi, scales, v);
499 }
500
501 // Computes part of matrix.vector v = Wu. Computes N=8 results.
502 // For details see PartialMatrixDotVector64 with N=8.
503 static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
504 int num_in, double *v) {
505 // Register containing 16-bit ones for horizontal add with 16->32 bit
506 // conversion.
507 __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
508 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
509 // Initialize all the results to 0.
510 __m256i result0 = _mm256_setzero_si256();
511 // Iterate over the input (u), one registerful at a time.
512 for (int j = 0; j < num_in;) {
513 __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
514 // Inputs are processed in groups of kNumInputsPerGroup, replicated
515 // kNumInputGroups times.
516 for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
517 // Replicate the low 32 bits (4 inputs) 8 times.
518 __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
519 // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
520 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
521 __m256i weights, reps;
522 // Mul-add, with horizontal add of the 4 inputs to each of the results.
523 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
524 }
525 }
526 ExtractResults8(result0, wi, scales, v);
527 }
528
529 static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
530 const int8_t *u, double *v) {
531 const int num_out = dim1;
532 const int num_in = dim2 - 1;
533 // Each call to a partial_func_ produces group_size outputs, except the
534 // last one, which can produce less.
535 const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
536 const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
537 int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
538 int output = 0;
539
540 int w_step = (rounded_num_in + 1) * group_size;
541
542 // Run with this group size, until it would produce too much output, then
543 // switch to a smaller size.
544 for (; output + group_size <= rounded_num_out; output += group_size) {
545 PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
546 wi += w_step;
547 scales += group_size;
548 v += group_size;
549 }
550 group_size /= 2;
551 w_step /= 2;
552
553 if (output + group_size <= rounded_num_out) {
554 PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
555 wi += w_step;
556 scales += group_size;
557 v += group_size;
558 output += group_size;
559 }
560 group_size /= 2;
561 w_step /= 2;
562
563 if (output + group_size <= rounded_num_out) {
564 PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
565 wi += w_step;
566 scales += group_size;
567 v += group_size;
568 output += group_size;
569 }
570 group_size /= 2;
571
572 if (output + group_size <= rounded_num_out) {
573 PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
574 }
575 }
576 #endif
577
578 const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
579 // Function.
580 matrixDotVector,
581 // Number of 32 bit outputs held in each register.
582 kNumOutputsPerRegister,
583 // Maximum number of registers that we will use to hold outputs.
584 kMaxOutputRegisters,
585 // Number of 8 bit inputs in the inputs register.
586 kNumInputsPerRegister,
587 // Number of inputs in each weight group.
588 kNumInputsPerGroup
589 };
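// --- Editor's illustrative sketch (not part of the upstream file) ---
// Size of the shaped weight buffer that matrixDotVector above walks through:
// every rounded output row carries rounded_num_in weights plus one bias byte.
// This mirrors the w_step arithmetic in matrixDotVector; the helper name is
// hypothetical, and the actual shaping is done elsewhere in Tesseract.
static inline size_t ShapedWeightBytesSketch(int num_out, int num_in) {
  const int rounded_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  return static_cast<size_t>(rounded_out) * (rounded_in + 1);
}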
590
591 } // namespace tesseract.
592
593 #endif