Please, help us to better know about our user community by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
Eigen  3.4.0
 
Loading...
Searching...
No Matches
SVE/PacketMath.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2020, Arm Limited and Contributors
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_PACKET_MATH_SVE_H
11#define EIGEN_PACKET_MATH_SVE_H
12
13namespace Eigen
14{
15namespace internal
16{
17#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
18#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
19#endif
20
21#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
22#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
23#endif
24
25#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
26
// Number of Scalar lanes that fit in an SVE vector of SVEVectorLength bits,
// e.g. a 512-bit VL with a 4-byte Scalar yields 16 lanes.
template <typename Scalar, int SVEVectorLength>
struct sve_packet_size_selector {
  enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
};
31
/********************************* int32 **************************************/

// Fixed-length SVE vector of int32 lanes. The arm_sve_vector_bits attribute
// pins the otherwise sizeless svint32_t to the build-time vector length
// EIGEN_ARM64_SVE_VL, making it usable as a regular (sized) type.
typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

// Capabilities of the int32 SVE packet, consumed by Eigen's vectorization
// machinery to decide which operations may be vectorized.
template <>
struct packet_traits<numext::int32_t> : default_packet_traits {
  typedef PacketXi type;
  typedef PacketXi half;  // Half packets not implemented yet; alias full packet
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    HasHalfPacket = 0,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0  // Not implemented in SVE
  };
};
61
// Reverse mapping: properties of the PacketXi type itself.
template <>
struct unpacket_traits<PacketXi> {
  typedef numext::int32_t type;
  typedef PacketXi half;  // Half not yet implemented
  enum {
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    // NOTE(review): fixed Aligned64 regardless of the actual VL — presumably a
    // conservative choice; confirm it matches allocator alignment for large VLs.
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
74
// Prefetch int32 data at addr for a load, targeting L1 with "keep" policy.
template <>
EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr)
{
  svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
}
80
// Broadcast a single int32 value into every lane.
template <>
EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from)
{
  return svdup_n_s32(from);
}
86
// Returns the linear sequence {a, a+1, a+2, ...}.
// svindex_s32 generates {base, base+step, base+2*step, ...} directly in a
// register, avoiding the original stack-buffer round-trip (fill an index
// array in memory, load it, then add the broadcast of 'a'). Integer
// wrap-around behavior is identical to the add-based formulation.
template <>
EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a)
{
  return svindex_s32(a, 1);
}
94
// Lane-wise int32 arithmetic. All operations use an all-true predicate
// (svptrue_b32), so the _z (zeroing) predication never actually zeroes lanes.
template <>
EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svadd_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svsub_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a)
{
  return svneg_s32_z(svptrue_b32(), a);
}

// Conjugate of a real (integer) packet is the identity.
template <>
EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a)
{
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmul_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdiv_s32_z(svptrue_b32(), a, b);
}
130
// Fused multiply-add: returns a*b + c (svmla computes acc + x*y, acc = c).
template <>
EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c)
{
  return svmla_s32_z(svptrue_b32(), c, a, b);
}
136
// Lane-wise minimum.
template <>
EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmin_s32_z(svptrue_b32(), a, b);
}

// Lane-wise maximum.
template <>
EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmax_s32_z(svptrue_b32(), a, b);
}
148
// Lane-wise a <= b: all-ones (0xffffffff) in lanes where true, 0 elsewhere.
// SVE compares produce a predicate; svdup_n_s32_z materializes it as a mask
// vector (active lanes -> -1, inactive -> 0).
// Fix: use svcmple (<=) instead of svcmplt (<), which wrongly reported
// false for equal elements.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
}
154
// Lane-wise a < b as an all-ones/zero integer mask. SVE compares return a
// predicate; svdup_n_s32_z expands it to a vector mask.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
}

// Lane-wise a == b as an all-ones/zero integer mask.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
}
166
// All-ones packet (every bit set); the argument only selects the overload.
template <>
EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/)
{
  return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
}

// All-zeros packet; the argument only selects the overload.
template <>
EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/)
{
  return svdup_n_s32_z(svptrue_b32(), 0);
}
178
// Lane-wise bitwise AND.
template <>
EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svand_s32_z(svptrue_b32(), a, b);
}

// Lane-wise bitwise OR.
template <>
EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svorr_s32_z(svptrue_b32(), a, b);
}

// Lane-wise bitwise XOR.
template <>
EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return sveor_s32_z(svptrue_b32(), a, b);
}

// a AND NOT b (bit clear: svbic computes a & ~b).
template <>
EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svbic_s32_z(svptrue_b32(), a, b);
}
202
// Arithmetic (sign-extending) right shift by compile-time constant N.
// NOTE(review): svasrd is the "shift right for divide" form, which rounds
// toward zero for negative inputs — this differs from a plain ASR (which
// rounds toward -infinity, like C's >> on negative values). Confirm which
// semantics callers of parithmetic_shift_right expect.
template <int N>
EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a)
{
  return svasrd_n_s32_z(svptrue_b32(), a, N);
}

// Logical (zero-filling) right shift by N; done in the unsigned domain via
// reinterpret casts since the packet's element type is signed.
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a)
{
  return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N)));
}

// Logical left shift by N (shift amount operand is an unsigned vector).
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a)
{
  return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N));
}
220
// Aligned load of a full packet. SVE svld1 has no alignment requirement, so
// the aligned and unaligned paths use the same instruction.
template <>
EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from)
{
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

// Unaligned load of a full packet.
template <>
EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from)
{
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}
232
// Load with each source element duplicated twice: {a0, a0, a1, a1, ...}.
// Implemented as a gather with a zipped index vector.
template <>
EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}

// Load with each source element duplicated four times:
// {a0, a0, a0, a0, a1, a1, a1, a1, ...}. Two zips quadruple each index.
template <>
EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}
249
// Aligned store of a full packet (svst1 itself has no alignment requirement).
template <>
EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
  EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}

// Unaligned store of a full packet.
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
  EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}
261
// Strided gather: loads from[0], from[stride], from[2*stride], ...
template <>
EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride)
{
  // Index format: {0, stride, stride*2, stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
}

// Strided scatter: stores lane i to to[i*stride].
template <>
EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from, Index stride)
{
  // Index format: {0, stride, stride*2, stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
}
277
// Extract lane 0.
template <>
EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a)
{
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_s32(svpfalse_b(), a);
}

// Reverse the order of the lanes.
template <>
EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a)
{
  return svrev_s32(a);
}
290
// Lane-wise absolute value.
template <>
EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a)
{
  return svabs_s32_z(svptrue_b32(), a);
}

// Horizontal sum of all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a)
{
  return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
}
302
// Horizontal product of all lanes via log2(size) pairwise reduction steps.
// EIGEN_ARM64_SVE_VL is a compile-time constant, so each 'if' below folds to
// either always-taken or dead code; larger VLs simply need more halvings.
// Only valid for VLs that are multiples of 128 bits (the static assert).
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a)
{
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
                      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);

  // Multiply the vector by its reverse, pairing lane i with lane size-1-i.
  svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
  svint32_t half_prod;

  // Each step shifts the upper half down (svtbl with indices starting at the
  // half count) and multiplies it into the lower half.
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_s32(prod, svindex_u32(32, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_s32(prod, svindex_u32(16, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_s32(prod, svindex_u32(8, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_s32(prod, svindex_u32(4, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  // Last reduction: fold the remaining 4 lanes (128 bits) down to lane 0.
  half_prod = svtbl_s32(prod, svindex_u32(2, 1));
  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXi>(prod);
}
337
// Horizontal minimum across all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a)
{
  return svminv_s32(svptrue_b32(), a);
}

// Horizontal maximum across all lanes.
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a)
{
  return svmaxv_s32(svptrue_b32(), a);
}
349
// Transpose an N-packet block through a stack buffer: packet i is scattered
// with stride N starting at column i, then the rows are reloaded
// contiguously. Consistency fix: use numext::int32_t for the buffer (the
// packet's element type, and what the s32 scatter/load intrinsics take)
// instead of plain int.
template <int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
  numext::int32_t buffer[packet_traits<numext::int32_t>::size * N] = {0};
  int i = 0;

  // Scatter indices {0, N, 2N, ...}: lane j of packet i lands at buffer[i + j*N].
  PacketXi stride_index = svindex_s32(0, N);

  for (i = 0; i < N; i++) {
    svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
  }
  for (i = 0; i < N; i++) {
    kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
  }
}
364
/********************************* float32 ************************************/

// Fixed-length SVE vector of float32 lanes (see PacketXi for the attribute).
typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

// Capabilities of the float SVE packet.
template <>
struct packet_traits<float> : default_packet_traits {
  typedef PacketXf type;
  typedef PacketXf half;

  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    HasHalfPacket = 0,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0,  // Not implemented in SVE

    HasDiv = 1,
    HasFloor = 1,

    // Transcendentals come from the generic packet math implementations.
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasLog = 1,
    HasExp = 1,
    HasSqrt = 0,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH
  };
};
407
// Properties of the PacketXf type; integer_packet is the same-width int
// packet used for bit-level manipulation (exponent tricks etc.).
template <>
struct unpacket_traits<PacketXf> {
  typedef float type;
  typedef PacketXf half;  // Half not yet implemented
  typedef PacketXi integer_packet;

  enum {
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    // NOTE(review): fixed Aligned64 regardless of VL — confirm for large VLs.
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
422
// Broadcast a single float value into every lane.
template <>
EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from)
{
  return svdup_n_f32(from);
}

// Broadcast a raw 32-bit pattern into every lane, reinterpreted as float.
template <>
EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
}
434
// Returns the linear sequence {a, a+1.f, a+2.f, ...}.
// Generates {0,1,2,...} with svindex_s32 and converts to float in-register,
// avoiding the original stack-buffer round-trip. Lane indices are small
// (at most 63 for a 2048-bit VL), so the int->float conversion is exact and
// the result matches the original element-by-element.
template <>
EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a)
{
  return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a),
                     svcvt_f32_s32_z(svptrue_b32(), svindex_s32(0, 1)));
}
442
// Lane-wise float arithmetic; all-true predicate, so _z zeroing never fires.
template <>
EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svadd_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svsub_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a)
{
  return svneg_f32_z(svptrue_b32(), a);
}

// Conjugate of a real packet is the identity.
template <>
EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a)
{
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmul_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svdiv_f32_z(svptrue_b32(), a, b);
}

// Fused multiply-add: returns a*b + c (svmla computes acc + x*y, acc = c).
template <>
EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c)
{
  return svmla_f32_z(svptrue_b32(), c, a, b);
}
484
// Default lane-wise minimum (svmin: FMIN semantics).
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmin_f32_z(svptrue_b32(), a, b);
}

// NaN-propagating min delegates to the default svmin-based pmin.
// NOTE(review): assumes svmin (FMIN) yields NaN when either operand is NaN —
// confirm against the architecture reference.
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return pmin<PacketXf>(a, b);
}

// Number-propagating min: svminnm (FMINNM) prefers the non-NaN operand.
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svminnm_f32_z(svptrue_b32(), a, b);
}

// Default lane-wise maximum (svmax: FMAX semantics).
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmax_f32_z(svptrue_b32(), a, b);
}

// NaN-propagating max delegates to the default svmax-based pmax.
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return pmax<PacketXf>(a, b);
}

// Number-propagating max: svmaxnm (FMAXNM) prefers the non-NaN operand.
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmaxnm_f32_z(svptrue_b32(), a, b);
}
520
// Float comparisons in SVE return svbool (predicate). Use svdup to set active
// lanes to 1 (0xffffffffu) and inactive lanes to 0.
// Fix: pcmp_le must use svcmple (<=) rather than svcmplt (<), which wrongly
// reported false for equal elements.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
}
528
// Lane-wise a < b as an all-ones/zero mask, reinterpreted as float.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
}

// Lane-wise a == b as an all-ones/zero mask, reinterpreted as float.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
}
540
// Lane-wise "a < b or unordered (either is NaN)": computed as NOT(a >= b),
// since a >= b is false whenever either operand is NaN. Do a predicate
// inverse (svnot_b_z) on the predicate from the greater-or-equal comparison
// (svcmpge_f32), then fill a float vector with the active elements.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
}
549
// Round each lane toward minus infinity (svrintm = round-to-minus-inf).
template <>
EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a)
{
  return svrintm_f32_z(svptrue_b32(), a);
}
555
// All-ones bit pattern (NaN as a float); the argument only selects the overload.
template <>
EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
}
561
// Logical operations are not supported for float, so reinterpret-cast to
// u32, operate bitwise, and cast back.
template <>
EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

// a AND NOT b (bit clear: svbic computes a & ~b) on the bit patterns.
template <>
EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
586
// Aligned load; SVE svld1 has no alignment requirement, so it matches ploadu.
template <>
EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from)
{
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

// Unaligned load of a full packet.
template <>
EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from)
{
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}
598
// Load with each source element duplicated twice: {a0, a0, a1, a1, ...},
// implemented as a gather with a zipped index vector.
template <>
EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}

// Load with each source element duplicated four times (two zips).
template <>
EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}
615
// Aligned store of a full packet (svst1 itself has no alignment requirement).
template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from)
{
  EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}

// Unaligned store of a full packet.
template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from)
{
  EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}
627
// Strided gather: loads from[0], from[stride], from[2*stride], ...
template <>
EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride)
{
  // Index format: {0, stride, stride*2, stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
}

// Strided scatter: stores lane i to to[i*stride].
template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride)
{
  // Index format: {0, stride, stride*2, stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
}
643
// Extract lane 0.
template <>
EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a)
{
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_f32(svpfalse_b(), a);
}

// Reverse the order of the lanes.
template <>
EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a)
{
  return svrev_f32(a);
}

// Lane-wise absolute value.
template <>
EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
{
  return svabs_f32_z(svptrue_b32(), a);
}
662
// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
// all vector extensions and the generic version.
// Split each lane into mantissa (returned) and exponent (out-param), using
// the generic packet-math implementation.
template <>
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
{
  return pfrexp_generic(a, exponent);
}

// Horizontal sum of all lanes.
template <>
EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a)
{
  return svaddv_f32(svptrue_b32(), a);
}
676
// Horizontal product of all lanes; only works for SVE VLs that are multiples
// of 128 bits (static assert). EIGEN_ARM64_SVE_VL is a compile-time constant,
// so each 'if' below folds away; larger VLs simply need more halving steps.
template <>
EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a)
{
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
                      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
  // Multiply the vector by its reverse, pairing lane i with lane size-1-i.
  svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
  svfloat32_t half_prod;

  // Each step shifts the upper half down (svtbl with indices starting at the
  // half count) and multiplies it into the lower half.
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_f32(prod, svindex_u32(32, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_f32(prod, svindex_u32(16, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_f32(prod, svindex_u32(8, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_f32(prod, svindex_u32(4, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  // Last reduction: fold the remaining 4 lanes (128 bits) down to lane 0.
  half_prod = svtbl_f32(prod, svindex_u32(2, 1));
  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXf>(prod);
}
713
// Horizontal minimum across all lanes.
template <>
EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a)
{
  return svminv_f32(svptrue_b32(), a);
}

// Horizontal maximum across all lanes.
template <>
EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a)
{
  return svmaxv_f32(svptrue_b32(), a);
}
725
// Transpose an N-packet block through a stack buffer: packet i is scattered
// with stride N starting at column i, then the rows are reloaded contiguously.
template<int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
{
  float buffer[packet_traits<float>::size * N] = {0};
  int i = 0;

  // Scatter indices {0, N, 2N, ...}: lane j of packet i lands at buffer[i + j*N].
  PacketXi stride_index = svindex_s32(0, N);

  for (i = 0; i < N; i++) {
    svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
  }

  for (i = 0; i < N; i++) {
    kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
  }
}
742
// Scale each lane by 2^exponent, via the generic packet-math implementation.
template<>
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
{
  return pldexp_generic(a, exponent);
}
748
749} // namespace internal
750} // namespace Eigen
751
752#endif // EIGEN_PACKET_MATH_SVE_H
@ Aligned64
Definition: Constants.h:237
Namespace containing all symbols from the Eigen library.
Definition: Core:141
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74