TensorContractionSycl.h
1// This file is part of Eigen, a lightweight C++ template library for linear algebra.
2//
3// Mehdi Goli Codeplay Software Ltd.
4// Ralph Potter Codeplay Software Ltd.
5// Luke Iwanski Codeplay Software Ltd.
6// Contact: <eigen@codeplay.com>
7//
8// This Source Code Form is subject to the terms of the Mozilla Public License v. 2.0. If a copy of the MPL was not
9// distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11/*****************************************************************
12 * TensorContractionSycl.h
13 *
14 * \brief:
15 * TensorContractionSycl.h provides various tensor contraction kernels for the SYCL backend
16 *
17 *****************************************************************/
18
19#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
20#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
21
22namespace Eigen {
23
24namespace TensorSycl {
25namespace internal {
26
27#ifndef EIGEN_SYCL_DISABLE_GEMV
42template <typename Scalar, typename StorageIndex, StorageIndex NCWindow, StorageIndex CFactor, StorageIndex NCFactor>
43struct TVPanelSize {
44 // LocalThreadSizeC: determines the total number of threads per workgroup for the contracting dimension
45 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeC = EIGEN_SYCL_LOCAL_THREAD_DIM0;
46 // LocalThreadSizeNC: determines the total number of threads per workgroup for the non-contracting dimension
47 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC = EIGEN_SYCL_LOCAL_THREAD_DIM1;
48 // TileSizeDimNC: determines the tile size for the non-contracting dimension
49 static EIGEN_CONSTEXPR StorageIndex TileSizeDimNC = NCWindow / NCFactor;
50 // TileSizeDimC: determines the tile size for the contracting dimension
51 static EIGEN_CONSTEXPR StorageIndex TileSizeDimC = CFactor * LocalThreadSizeNC * LocalThreadSizeC;
52 // WorkLoadPerThreadNC : determines workload per thread for loading the non-contracting dimension
53 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC = TileSizeDimNC / LocalThreadSizeNC;
54 // WorkLoadPerThreadC: determines the workload per thread for loading the contracting dimension
55 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadC = TileSizeDimC / LocalThreadSizeC;
56 // BC: determines whether padding to avoid bank conflicts is required
57 static EIGEN_CONSTEXPR bool BC = false;
58};
59#endif
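// A minimal sketch (illustrative, not part of the original source) of how the TVPanelSize constants relate for one
// hypothetical instantiation, assuming the default work-group shape
// EIGEN_SYCL_LOCAL_THREAD_DIM0 == EIGEN_SYCL_LOCAL_THREAD_DIM1 == 16:
//
//   typedef TVPanelSize<float, int, /*NCWindow=*/64, /*CFactor=*/4, /*NCFactor=*/2> VTParams;
//   static_assert(VTParams::TileSizeDimNC == 32, "NCWindow / NCFactor");
//   static_assert(VTParams::TileSizeDimC == 4 * 16 * 16, "CFactor * LocalThreadSizeNC * LocalThreadSizeC");
//   static_assert(VTParams::WorkLoadPerThreadNC == 2, "TileSizeDimNC / LocalThreadSizeNC");
//   static_assert(VTParams::WorkLoadPerThreadC == 64, "TileSizeDimC / LocalThreadSizeC");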
60
78template <typename Scalar, typename StorageIndex, StorageIndex REG_SIZE_M, StorageIndex REG_SIZE_N, StorageIndex TSDK>
79struct TTPanelSize {
80 // TileSizeDimK: determines the tile size for dimension K. It is assumed to already account for the packet size.
81 static EIGEN_CONSTEXPR StorageIndex TileSizeDimK = TSDK;
82 // WorkLoadPerThreadM: determines the workload per thread for loading the M dimension. This can be varied based on
83 // the registers available on the chosen device (controlled by the EIGEN_SYCL_REG_M macro).
84#ifndef EIGEN_SYCL_REG_M
85 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = REG_SIZE_M;
86#else
87 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = EIGEN_SYCL_REG_M;
88#endif
89// WorkLoadPerThreadN: determines the workload per thread for loading the N dimension. This can be varied based on
90// the registers available on the chosen device (controlled by the EIGEN_SYCL_REG_N macro).
91#ifndef EIGEN_SYCL_REG_N
92 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = REG_SIZE_N;
93#else
94 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = EIGEN_SYCL_REG_N;
95#endif
96 // LocalThreadSizeM: determines the total number of threads per workgroup for the m dimension
97 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeM = EIGEN_SYCL_LOCAL_THREAD_DIM0;
98 // LocalThreadSizeN: determines the total number of threads per workgroup for the n dimension
99 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeN = EIGEN_SYCL_LOCAL_THREAD_DIM1;
100 // TileSizeDimM: determines the tile size for the m dimension
101 static EIGEN_CONSTEXPR StorageIndex TileSizeDimM = LocalThreadSizeM * WorkLoadPerThreadM;
102 // TileSizeDimN: determines the tile size for the n dimension
103 static EIGEN_CONSTEXPR StorageIndex TileSizeDimN = LocalThreadSizeN * WorkLoadPerThreadN;
104 // LoadPerThreadLhs: determines the workload per thread for loading the Lhs tensor. This must be divisible by the packet size.
105 static EIGEN_CONSTEXPR StorageIndex LoadPerThreadLhs =
106 ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN));
107 // LoadPerThreadRhs: determines the workload per thread for loading the Rhs tensor. This must be divisible by the packet size.
108 static EIGEN_CONSTEXPR StorageIndex LoadPerThreadRhs =
109 ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM));
110 // BC: determines whether padding to avoid bank conflicts is required
111 static EIGEN_CONSTEXPR bool BC = true;
112 // DoubleBuffer: determines whether the double buffering technique should be used (this can be disabled by the
113 // EIGEN_SYCL_DISABLE_DOUBLE_BUFFER macro when the device does not have sufficient local memory)
114 static EIGEN_CONSTEXPR bool DoubleBuffer =
115#ifdef EIGEN_SYCL_DISABLE_DOUBLE_BUFFER
116 false;
117#else
118 true;
119#endif
120};
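// A minimal sketch (illustrative, not part of the original source) of the derived TTPanelSize constants for the
// <REG_SIZE_M=4, REG_SIZE_N=4, TSDK=16> configuration used by the local-memory launcher further below, assuming
// 16x16 local threads and no EIGEN_SYCL_REG_M / EIGEN_SYCL_REG_N overrides:
//
//   typedef TTPanelSize<float, int, 4, 4, 16> TTParams;
//   static_assert(TTParams::TileSizeDimM == 16 * 4, "LocalThreadSizeM * WorkLoadPerThreadM == 64");
//   static_assert(TTParams::TileSizeDimN == 16 * 4, "LocalThreadSizeN * WorkLoadPerThreadN == 64");
//   static_assert(TTParams::LoadPerThreadLhs == (16 * 4 * 4) / 64, "each thread loads 4 LHS elements per tile");
//   static_assert(TTParams::LoadPerThreadRhs == (16 * 4 * 4) / 64, "each thread loads 4 RHS elements per tile");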
121
122/*!
123 * \brief contraction_type: an enum class representing the Tensor Contraction implementation algorithm. This is used to
124 * specialize the contraction algorithm based on device support for dedicated local memory.
125 */
126enum class contraction_type { local, no_local };
127/*!
128 * \brief data_source: an enum class determining the location of the data in the memory hierarchy (global, local, private).
129 */
129 */
130enum class data_source { global_mem, local_mem, private_mem };
131
157template <bool PacketLoad, bool is_coalesced_layout, bool, typename PacketType, typename TensorMapper,
158 typename StorageIndex>
159static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<PacketLoad, PacketType>::type read(
160 const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &ld) {
161 const StorageIndex row = (is_coalesced_layout) ? NCIndex : CIndex;
162 const StorageIndex col = (is_coalesced_layout) ? CIndex : NCIndex;
163 return tensorMapper.get_tensor().template packet<Unaligned>(row + (col * ld));
164}
165
188template <bool PacketLoad, bool, bool IsRhs, typename PacketType, typename TensorMapper, typename StorageIndex>
189static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!PacketLoad, PacketType>::type read(
190 const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &) {
191 const StorageIndex row = (IsRhs) ? CIndex : NCIndex;
192 const StorageIndex col = (IsRhs) ? NCIndex : CIndex;
193 return tensorMapper(row, col);
194}
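// Illustrative index arithmetic (not part of the original source) for the two read() overloads above. For a
// coalesced, column-major LHS block with leading dimension ld == M, a hypothetical call with NCIndex = m = 8,
// CIndex = k = 3 and M = 128 loads the packet starting at linear offset
//   row + col * ld = 8 + 3 * 128 = 392,
// i.e. rows 8..8+PacketSize-1 of column 3. When PacketLoad is false the second overload is selected instead and a
// single scalar is fetched through tensorMapper(row, col), with row/col swapped for the RHS.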
195
217template <typename StorageIndex, StorageIndex ld, data_source dt, typename PacketType, typename DataScalar>
218static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
219 typename ::Eigen::internal::enable_if<dt != data_source::global_mem, void>::type
220 write(PacketType &packet_data, DataScalar ptr) {
221 EIGEN_CONSTEXPR int PacketSize = Eigen::internal::unpacket_traits<PacketType>::size;
222 EIGEN_UNROLL_LOOP
223 for (int i = 0; i < PacketSize; i++) {
224 *ptr = PacketWrapper<PacketType, PacketSize>::scalarize(i, packet_data);
225 ptr += ld;
226 }
227}
228
244template <data_source dt, typename PacketType, typename DataScalar>
245static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<
246 Eigen::internal::unpacket_traits<PacketType>::size != 1 && dt == data_source::global_mem, void>::type
247write(PacketType &packet_data, DataScalar *ptr) {
248 ::Eigen::internal::pstoreu<DataScalar, PacketType>(ptr, packet_data);
249}
250
264template <data_source dt, typename PacketType, typename DataScalar>
265static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<
266 Eigen::internal::unpacket_traits<PacketType>::size == 1 && dt == data_source::global_mem, void>::type
267write(PacketType &packet_data, DataScalar *ptr) {
268 *ptr = packet_data;
269}
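// Illustrative behaviour of the three write() overloads above (not part of the original source), assuming a
// 4-lane packet:
//   - local/private memory, ld = 8: the packet is scalarized and its lanes land at ptr[0], ptr[8], ptr[16], ptr[24]
//     (strided by the tile leading dimension);
//   - global memory, packet size > 1: a single unaligned packet store (pstoreu);
//   - global memory, packet size == 1: a plain scalar assignment.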
270
276template <bool is_internal>
277EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary(bool) {
278 return true;
279}
280
286template <>
287EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary<false>(bool cond) {
288 return cond;
289}
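// Usage sketch (illustrative): for a fully interior tile the predicate folds to a compile-time 'true' and the
// guarded branch can be optimised away, while edge tiles keep the runtime bound check, e.g.
//   if (check_boundary<is_internal_block>(m + PacketSize - 1 < triple_dim.M)) { ... }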
290
317template <bool is_transposed, bool is_rhs_, bool packet_load_, typename PacketType>
318struct BlockProperties {
319 static EIGEN_CONSTEXPR bool packet_load = packet_load_;
320 typedef typename Eigen::internal::unpacket_traits<PacketType>::type OutScalar;
321 static EIGEN_CONSTEXPR bool is_rhs = is_rhs_;
322 typedef typename Eigen::internal::conditional<packet_load, PacketType, OutScalar>::type OutType;
323 static EIGEN_CONSTEXPR int elements_per_access = Eigen::internal::unpacket_traits<OutType>::size;
324 static EIGEN_CONSTEXPR bool is_coalesced_layout = !(is_transposed ^ is_rhs);
325 static EIGEN_CONSTEXPR int nc_stride = (is_coalesced_layout ? elements_per_access : 1);
326 static EIGEN_CONSTEXPR int c_stride = (is_coalesced_layout ? 1 : elements_per_access);
327};
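// Illustrative truth table (not part of the original source) for the strides derived above, where
// e == elements_per_access (the packet size when packet_load is true, otherwise 1) and
// is_coalesced_layout == !(is_transposed ^ is_rhs):
//   LHS, not transposed : coalesced     -> nc_stride = e, c_stride = 1
//   LHS, transposed     : not coalesced -> nc_stride = 1, c_stride = e
//   RHS, not transposed : not coalesced -> nc_stride = 1, c_stride = e
//   RHS, transposed     : coalesced     -> nc_stride = e, c_stride = 1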
328
368template <typename StorageIndex>
369struct ThreadProperties {
370 const StorageIndex linearLocalThreadId;
371 const StorageIndex kGroupId;
372 const StorageIndex mGroupOffset;
373 const StorageIndex nGroupOffset;
374 const StorageIndex kGroupOffset;
375 const StorageIndex mLocalOffset;
376 const StorageIndex nLocalOffset;
377 const StorageIndex mGlobalOffset;
378 const StorageIndex nGlobalOffset;
379 StorageIndex kSize;
380 const bool is_internal;
381 // this is used to adjust the last block
382 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ThreadProperties(
383 const StorageIndex linearLocalThreadId_, const StorageIndex kGroupId_, const StorageIndex mGroupOffset_,
384 const StorageIndex nGroupOffset_, const StorageIndex kGroupOffset_, const StorageIndex mLocalOffset_,
385 const StorageIndex nLocalOffset_, const StorageIndex mGlobalOffset_, const StorageIndex nGlobalOffset_,
386 StorageIndex kSize_, const bool is_internal_)
387 : linearLocalThreadId(linearLocalThreadId_),
388 kGroupId(kGroupId_),
389 mGroupOffset(mGroupOffset_),
390 nGroupOffset(nGroupOffset_),
391 kGroupOffset(kGroupOffset_),
392 mLocalOffset(mLocalOffset_),
393 nLocalOffset(nLocalOffset_),
394 mGlobalOffset(mGlobalOffset_),
395 nGlobalOffset(nGlobalOffset_),
396 kSize(kSize_),
397 is_internal(is_internal_) {}
398};
399
450template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
451 typename RhsMapper, typename StorageIndex, typename Properties, typename TripleDim, bool Vectorizable,
452 typename input_mapper_properties, bool IsFinal, contraction_type contraction_tp>
453class TensorContractionKernel {
454 public:
455 typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
456 PacketReturnType;
457 static EIGEN_CONSTEXPR int PacketSize =
458 Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
459 static EIGEN_CONSTEXPR bool is_lhs_transposed =
460 !::Eigen::internal::TensorContractionInputMapperTrait<LhsMapper>::inner_dim_contiguous;
461 static EIGEN_CONSTEXPR bool is_rhs_transposed =
462 !::Eigen::internal::TensorContractionInputMapperTrait<RhsMapper>::inner_dim_contiguous;
463
464 typedef BlockProperties<is_lhs_transposed, false, input_mapper_properties::is_lhs_matrix && Vectorizable,
465 PacketReturnType>
466 LHSBlockProperties;
467
468 typedef BlockProperties<is_rhs_transposed, true, input_mapper_properties::is_rhs_matrix && Vectorizable,
469 PacketReturnType>
470 RHSBlockProperties;
471
472 static EIGEN_CONSTEXPR StorageIndex NStride =
473 contraction_tp == contraction_type::local ? Properties::WorkLoadPerThreadN : RHSBlockProperties::nc_stride;
474
475 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
476 typedef cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::local_space> local_ptr;
477 typedef OutScalar * /*cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::private_space>*/ private_ptr;
478 typedef
479 typename ::Eigen::internal::conditional<contraction_tp == contraction_type::local, local_ptr, private_ptr>::type
480 tile_ptr;
481 static EIGEN_CONSTEXPR StorageIndex LSDL = contraction_tp == contraction_type::local
482 ? Properties::TileSizeDimM + Properties::BC
483 : Properties::WorkLoadPerThreadM;
484 static EIGEN_CONSTEXPR StorageIndex LSDR = contraction_tp == contraction_type::local
485 ? Properties::TileSizeDimN + Properties::BC
486 : Properties::WorkLoadPerThreadN;
487 static EIGEN_CONSTEXPR StorageIndex LocalOffset = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
488
501 template <contraction_type, StorageIndex>
502 struct MemHolder {
503 tile_ptr ptr;
504 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MemHolder(local_ptr block_start_ptr) : ptr(block_start_ptr) {}
505 };
509 template <StorageIndex MemSize>
510 struct MemHolder<contraction_type::no_local, MemSize> {
511 OutScalar ptr[MemSize] = {OutScalar{0}};
512 };
535 struct TiledMemory {
536 MemHolder<contraction_tp, Properties::WorkLoadPerThreadM * Properties::TileSizeDimK> lhs_scratch_extract;
537 MemHolder<contraction_tp, Properties::WorkLoadPerThreadN * Properties::TileSizeDimK> rhs_scratch_extract;
538 tile_ptr lhs_scratch_ptr_compute;
539 tile_ptr rhs_scratch_ptr_compute;
540 const std::pair<StorageIndex, StorageIndex> lhs_extract_index;
541 const std::pair<StorageIndex, StorageIndex> rhs_extract_index;
542 template <contraction_type tp = contraction_tp>
543 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
544 TiledMemory(const ThreadProperties<StorageIndex> &, local_ptr,
545 typename ::Eigen::internal::enable_if<tp == contraction_type::no_local>::type * = 0)
546 : lhs_scratch_extract{},
547 rhs_scratch_extract{},
548 lhs_scratch_ptr_compute(lhs_scratch_extract.ptr),
549 rhs_scratch_ptr_compute(rhs_scratch_extract.ptr),
550 lhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})),
551 rhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})) {}
552
553 template <contraction_type tp = contraction_tp>
554 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
555 TiledMemory(const ThreadProperties<StorageIndex> &thread_properties, local_ptr block_start_ptr,
556 typename ::Eigen::internal::enable_if<tp == contraction_type::local>::type * = 0)
557 : lhs_scratch_extract{block_start_ptr},
558 rhs_scratch_extract{lhs_scratch_extract.ptr +
559 ((Properties::DoubleBuffer + 1) * LSDL * Properties::TileSizeDimK)},
560 lhs_scratch_ptr_compute(lhs_scratch_extract.ptr + thread_properties.mLocalOffset),
561 rhs_scratch_ptr_compute(rhs_scratch_extract.ptr + thread_properties.nLocalOffset),
562 lhs_extract_index(
563 local_id_extract<LHSBlockProperties, Properties::TileSizeDimM>(thread_properties.linearLocalThreadId)),
564 rhs_extract_index(
565 local_id_extract<RHSBlockProperties, Properties::TileSizeDimN>(thread_properties.linearLocalThreadId)) {}
566 };
567
568 Scratch scratch;
569 const LhsMapper lhs;
570 const RhsMapper rhs;
571 OutAccessor out_res;
572 const StorageIndex groupSizeM;
573 const StorageIndex groupSizeN;
574 const StorageIndex numTiles;
575 const TripleDim triple_dim;
576
577 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
578 const RhsMapper rhs_, OutAccessor out_res_,
579 const StorageIndex groupSizeM_,
580 const StorageIndex groupSizeN_,
581 const StorageIndex numTiles_,
582 const TripleDim triple_dim_)
583 : scratch(scratch_),
584 lhs(lhs_),
585 rhs(rhs_),
586 out_res(out_res_),
587 groupSizeM(groupSizeM_),
588 groupSizeN(groupSizeN_),
589 numTiles(numTiles_),
590 triple_dim(triple_dim_) {}
591
592 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
593 const RhsMapper rhs_, OutAccessor out_res_,
594 const StorageIndex groupSizeM_,
595 const StorageIndex numTiles_,
596 const TripleDim triple_dim_)
597 : TensorContractionKernel(scratch_, lhs_, rhs_, out_res_, groupSizeM_, 1, numTiles_, triple_dim_) {}
598
599 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
600 const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
601 const StorageIndex nLocalThreadId = linearLocalThreadId / Properties::LocalThreadSizeM;
602 const StorageIndex mLocalThreadId = linearLocalThreadId % Properties::LocalThreadSizeM;
603 const StorageIndex mGroupId = itemID.get_group(0) % groupSizeM;
604 const StorageIndex tmp = itemID.get_group(0) / groupSizeM;
605 const StorageIndex nGroupId = IsFinal ? tmp : tmp % groupSizeN;
606 const StorageIndex kGroupId = IsFinal ? 0 : tmp / groupSizeN;
607 const StorageIndex mGroupOffset = mGroupId * Properties::TileSizeDimM;
608 const StorageIndex nGroupOffset = nGroupId * Properties::TileSizeDimN;
609 const StorageIndex mLocalOffset = PacketSize * mLocalThreadId;
610 const StorageIndex nLocalOffset = NStride * nLocalThreadId;
611 const StorageIndex mGlobalOffset = mGroupOffset + mLocalOffset;
612 const StorageIndex nGlobalOffset = nGroupOffset + nLocalOffset;
613
614 const StorageIndex kSizePerWG = IsFinal ? triple_dim.K : numTiles * Properties::TileSizeDimK;
615 StorageIndex kGroupOffset = kGroupId * kSizePerWG;
616 const bool is_internal = triple_dim.M - mGroupOffset >= Properties::TileSizeDimM &&
617 triple_dim.N - nGroupOffset >= Properties::TileSizeDimN &&
618 triple_dim.K - kGroupOffset >= kSizePerWG;
619 // this is used to adjust the last block
620 StorageIndex kSize = IsFinal ? triple_dim.K : std::min(kSizePerWG, triple_dim.K - kGroupOffset);
621 // This is used to find the last K offset so that kGroupOffset - kSize computes the cOffset for loading into the
622 // tile
623 kGroupOffset += kSize;
624
625 auto thread_properties =
626 ThreadProperties<StorageIndex>(linearLocalThreadId, kGroupId, mGroupOffset, nGroupOffset, kGroupOffset,
627 mLocalOffset, nLocalOffset, mGlobalOffset, nGlobalOffset, kSize, is_internal);
628
629 auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : thread_properties.kGroupId * triple_dim.M * triple_dim.N);
630
631 (thread_properties.is_internal) ? compute_panel<true>(itemID, thread_properties, out_ptr)
632 : compute_panel<false>(itemID, thread_properties, out_ptr);
633 }
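// A hedged numeric walkthrough (not from the original source) of the flat work-group decomposition in operator()
// above, for groupSizeM = 2, groupSizeN = 3 and a non-final (IsFinal == false) kernel:
//   group_id = 7            // itemID.get_group(0)
//   mGroupId = 7 % 2 = 1
//   tmp      = 7 / 2 = 3
//   nGroupId = 3 % 3 = 0
//   kGroupId = 3 / 3 = 1    // this work-group computes the second K-slice of tile (m = 1, n = 0)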
634 // The compute block computes the contraction for each thread's private block and stores the result in the
635 // privateRes register memory. The compute block function is independent of the local/no-local concepts, as
636 // it only computes the block in each thread's private memory space.
637 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_block_per_tile(OutScalar *lhs_block_ptr, OutScalar *rhs_block_ptr,
638 PacketReturnType *privateRes) {
639 StorageIndex idx = 0;
640 EIGEN_CONSTEXPR StorageIndex lhs_stride =
641 contraction_tp == contraction_type::local ? (PacketSize * Properties::LocalThreadSizeM) : 1;
642 EIGEN_UNROLL_LOOP
643 for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN; wLPTN++) {
644 auto rhsPacket = PacketReturnType{*(rhs_block_ptr + wLPTN)};
645 StorageIndex lhs_index = 0;
646 EIGEN_UNROLL_LOOP
647 for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
648 PacketReturnType lhsPack{};
649 Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::set_packet(lhsPack,
650 lhs_block_ptr + lhs_index);
651 privateRes[idx] = ::Eigen::internal::pmadd(lhsPack, rhsPacket, privateRes[idx]);
652
653 lhs_index += lhs_stride;
654 idx++;
655 }
656 }
657 }
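// Illustrative register-blocking figures (not part of the original source), assuming WorkLoadPerThreadM =
// WorkLoadPerThreadN = 4 and PacketSize = 4: each call above performs a rank-1 update of a 4x4 private output tile
// held as (WorkLoadPerThreadM / PacketSize) * WorkLoadPerThreadN = 4 packets in privateRes, broadcasting one RHS
// scalar per wLPTN and reading one LHS packet per wLPTM, with lhs_stride = PacketSize * LocalThreadSizeM (= 64 for
// 16 local threads) in the local-memory path and 1 in the no-local path.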
658 // The store function writes the contraction result computed in each thread's private memory out to global
659 // memory. It is independent of the local/no-local concepts so that it can be abstracted out in the base
660 // class.
661 template <bool is_internal_block, StorageIndex PrivateNStride, typename OutPtr>
662 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void store(OutPtr *out_ptr, PacketReturnType *privateRes,
663 StorageIndex mGlobalOffset, StorageIndex nGlobalOffset) {
664 auto chk_bound = [&](const StorageIndex &mIndex, const StorageIndex &nIndex) EIGEN_DEVICE_FUNC {
665 return (mIndex + PacketSize - 1 < triple_dim.M && nGlobalOffset + nIndex < triple_dim.N);
666 };
667 // When local memory is not used, M and N are both accessed in a coalesced way. However, when local memory is
668 // available, the K*N block is transposed to N*K in local memory; therefore, each block operates on a blockId *
669 // WorkLoadPerThreadN slice of N.
670 EIGEN_CONSTEXPR StorageIndex GlobalNStride =
671 contraction_tp == contraction_type::local ? 1 : Properties::LocalThreadSizeN;
672 EIGEN_UNROLL_LOOP
673 for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN / PrivateNStride; wLPTN++) {
674 // output leading dimension
675 StorageIndex outputLD = 0;
676 // When local memory is used, PrivateNStride is always 1 because the coalesced access on N is loaded into local
677 // memory, and extracting from local to global is the same as in the non-transposed version. However, when local memory is
678 // not used and the RHS is transposed, we packetize the load for the RHS.
679 EIGEN_UNROLL_LOOP
680 for (StorageIndex nId = 0; nId < PrivateNStride; nId++) {
681 StorageIndex globalRow = mGlobalOffset;
682 EIGEN_UNROLL_LOOP
683 for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
684 PacketReturnType privateOut = privateRes[wLPTM];
685 if (check_boundary<is_internal_block>(chk_bound(globalRow, nId))) {
686 // Store the final results in C. The C matrix always has M as its first StorageIndex and N as its second
687 // StorageIndex; therefore it always has a coalesced layout.
688 write<data_source::global_mem>(privateOut, out_ptr + outputLD + globalRow);
689 } else {
690 EIGEN_UNROLL_LOOP
691 for (StorageIndex mId = 0; mId < PacketSize; mId++) {
692 StorageIndex mOffset = globalRow + mId;
693 if (mOffset < triple_dim.M && (nGlobalOffset + nId < triple_dim.N)) {
694 out_ptr[mOffset + outputLD] =
695 Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::scalarize(mId, privateOut);
696 }
697 }
698 }
699 globalRow += (PacketSize * Properties::LocalThreadSizeM);
700 }
701 outputLD += triple_dim.M;
702 privateRes += Properties::WorkLoadPerThreadM / PacketSize;
703 }
704 out_ptr += (GlobalNStride * outputLD);
705
706 nGlobalOffset += (PrivateNStride * GlobalNStride);
707 }
708 }
709 // when no local memory is used the following extract_block will be enabled
710 template <typename InputBlockProperties, bool is_internal_block, typename Input, typename PrivateReg,
711 contraction_type contract_tp = contraction_tp>
712 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
713 typename ::Eigen::internal::enable_if<contract_tp == contraction_type::no_local>::type
714 extract_block(const Input &inpt, PrivateReg private_ptr, const std::pair<StorageIndex, StorageIndex> &,
715 const StorageIndex &ncOffset, const StorageIndex cOffset) {
716 EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC =
717 InputBlockProperties::is_rhs ? Properties::LocalThreadSizeN : Properties::LocalThreadSizeM;
718 EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC =
719 InputBlockProperties::is_rhs ? Properties::WorkLoadPerThreadN : Properties::WorkLoadPerThreadM;
720 const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
721
722 auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
723 return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
724 (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
725 };
726 const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
727 StorageIndex cIndex = cOffset;
728
729 EIGEN_UNROLL_LOOP
730 for (StorageIndex cId = 0; cId < Properties::TileSizeDimK / InputBlockProperties::c_stride; cId++) {
731 StorageIndex ncIndex = ncOffset;
732 EIGEN_UNROLL_LOOP
733 for (StorageIndex ncId = 0; ncId < WorkLoadPerThreadNC / InputBlockProperties::nc_stride; ncId++) {
734 if (check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex))) {
735 auto val =
736 read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
737 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, ncIndex, cIndex, ld);
738
739 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
740 data_source::private_mem>(val, private_ptr);
741 } else {
742 EIGEN_UNROLL_LOOP
743 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
744 const StorageIndex ncInd = ncIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
745 const StorageIndex cInd = cIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
746 OutScalar val =
747 (ncInd < NC && cInd < triple_dim.K)
748 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
749 inpt, ncInd, cInd, ld)
750 : OutScalar(0);
751 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
752 data_source::private_mem>(
753 val, private_ptr + (InputBlockProperties::is_coalesced_layout ? i : 0) +
754 ((InputBlockProperties::is_coalesced_layout ? 0 : i) * WorkLoadPerThreadNC));
755 }
756 }
757
758 // If it is the LHS we have to load it packetised when the packet size is > 1, because the output is coalesced. So
759 // even if M is not accessed in a coalesced mode, we have to load packet_size M elements per thread.
760 ncIndex = (!InputBlockProperties::is_rhs && InputBlockProperties::nc_stride == 1 && PacketSize != 1)
761 ? ncOffset + (ncId + 1) % PacketSize + ((ncId + 1) / PacketSize) * LocalThreadSizeNC
762 : (ncIndex + InputBlockProperties::nc_stride * LocalThreadSizeNC);
763 private_ptr += InputBlockProperties::nc_stride;
764 }
765 // The previous for loop (private_ptr += (ncId * nc_stride)) has already advanced the pointer by one WorkLoadPerThreadNC.
766 private_ptr += (InputBlockProperties::c_stride - 1) * WorkLoadPerThreadNC;
767 cIndex += InputBlockProperties::c_stride;
768 }
769 }
770 template <typename InputBlockProperties, StorageIndex TileSizeDimNC>
771 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::pair<StorageIndex, StorageIndex> local_id_extract(
772 const StorageIndex &linearLocalThreadId) {
773 const StorageIndex localThreadNC =
774 (InputBlockProperties::is_coalesced_layout)
775 ? linearLocalThreadId % (TileSizeDimNC / InputBlockProperties::nc_stride)
776 : linearLocalThreadId / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
777 const StorageIndex localThreadC =
778 (InputBlockProperties::is_coalesced_layout)
779 ? linearLocalThreadId / (TileSizeDimNC / InputBlockProperties::nc_stride)
780 : linearLocalThreadId % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
781 return std::pair<StorageIndex, StorageIndex>(localThreadNC, localThreadC);
782 }
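// Illustrative mapping (not part of the original source) for the coalesced branch of local_id_extract above,
// assuming TileSizeDimNC = 64 and nc_stride = 4 (so 64 / 4 = 16 extraction columns): linearLocalThreadId = 37 yields
//   localThreadNC = 37 % 16 = 5   -> the thread starts extracting at non-contracting element 5 * 4 = 20
//   localThreadC  = 37 / 16 = 2   -> and at contracting element 2 * c_stride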
783
784 template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
785 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
786 typename ::Eigen::internal::enable_if<db && ctp == contraction_type::local>::type
787 sync_mem(const cl::sycl::nd_item<1> &, bool &db_offset) noexcept {
788 db_offset = !db_offset;
789 }
790
791 template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
792 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
793 typename ::Eigen::internal::enable_if<!db && ctp == contraction_type::local>::type
794 sync_mem(const cl::sycl::nd_item<1> &itemID, bool &) noexcept {
795 itemID.barrier(cl::sycl::access::fence_space::local_space);
796 }
797
798 template <contraction_type ctp = contraction_tp>
799 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
800 typename ::Eigen::internal::enable_if<ctp == contraction_type::no_local>::type
801 sync_mem(const cl::sycl::nd_item<1> &, bool &) noexcept {
802 return;
803 }
804
805 template <bool need_sync, contraction_type ctp = contraction_tp>
806 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
807 typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::no_local>::type
808 sync_thread(const cl::sycl::nd_item<1> &
809#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
810 itemID
811#endif
812 ) noexcept {
813#ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
814 itemID.barrier(cl::sycl::access::fence_space::local_space);
815#else
816 return;
817#endif
818 }
819 template <bool need_sync, contraction_type ctp = contraction_tp>
820 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
821 typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::local>::type
822 sync_thread(const cl::sycl::nd_item<1> &itemID) {
823 itemID.barrier(cl::sycl::access::fence_space::local_space);
824 }
825 template <bool need_sync>
826 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!need_sync>::type sync_thread(
827 const cl::sycl::nd_item<1> &) {
828 return;
829 }
830
831 template <bool is_internal_block>
832 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_tile_per_panel(const cl::sycl::nd_item<1> &itemID,
833 ThreadProperties<StorageIndex> &thread_properties,
834 TiledMemory &tiled_input_block,
835 PacketReturnType *privateRes, bool &db_offset) {
836 // Tiling the Rhs block from global to local memory
837 extract_block<RHSBlockProperties, is_internal_block>(
838 rhs, tiled_input_block.rhs_scratch_extract.ptr + (db_offset * Properties::TileSizeDimK * LSDR),
839 tiled_input_block.rhs_extract_index,
840 contraction_tp == contraction_type::local ? thread_properties.nGroupOffset : thread_properties.nGlobalOffset,
841 thread_properties.kGroupOffset - thread_properties.kSize);
842
843 sync_thread<contraction_tp == contraction_type::no_local>(itemID);
844
845 // Tiling the Lhs block from global to local memory
846 extract_block<LHSBlockProperties, is_internal_block>(
847 lhs, tiled_input_block.lhs_scratch_extract.ptr + (db_offset * LSDL * Properties::TileSizeDimK),
848 tiled_input_block.lhs_extract_index,
849 contraction_tp == contraction_type::local ? thread_properties.mGroupOffset : thread_properties.mGlobalOffset,
850 thread_properties.kGroupOffset - thread_properties.kSize);
851
852 // itemID.barrier(cl::sycl::access::fence_space::local_space);
853 sync_thread<contraction_tp == contraction_type::local>(itemID);
854 // switch to compute mode
855 StorageIndex lhs_offset = (db_offset * LSDL * Properties::TileSizeDimK);
856 StorageIndex rhs_offset = (db_offset * Properties::TileSizeDimK * LSDR);
857 // Loop over the values of a single tile
858 for (StorageIndex k = 0; k < Properties::TileSizeDimK; k++) {
859 compute_block_per_tile(tiled_input_block.lhs_scratch_ptr_compute + lhs_offset,
860 tiled_input_block.rhs_scratch_ptr_compute + rhs_offset, privateRes);
861 lhs_offset += LSDL;
862 rhs_offset += LSDR;
863 }
864 // computing the K index for the next tile
865 thread_properties.kSize -= Properties::TileSizeDimK;
866 sync_mem(itemID, db_offset);
867 }
868
869 // when local memory is available the following compute_panel will be enabled
870 template <bool is_internal_block, typename OutPtr>
871 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(const cl::sycl::nd_item<1> &itemID,
872 ThreadProperties<StorageIndex> &thread_properties,
873 OutPtr out_ptr) {
874 auto tiled_input_block = TiledMemory{thread_properties, scratch.get_pointer()};
875 // Allocate register space
876 PacketReturnType privateRes[Properties::WorkLoadPerThreadM * Properties::WorkLoadPerThreadN / PacketSize] = {
877 PacketReturnType{0}};
878 bool db_offset = 0;
879
880 while (thread_properties.kSize >= Properties::TileSizeDimK) {
881 compute_tile_per_panel<is_internal_block>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
882 }
883 if (thread_properties.kSize > 0) {
884 compute_tile_per_panel<false>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
885 }
886
887 // Storing the final results in the output
888 store<is_internal_block,
889 contraction_tp == contraction_type::local ? static_cast<StorageIndex>(1) : RHSBlockProperties::nc_stride>(
890 out_ptr + thread_properties.nGlobalOffset * triple_dim.M, privateRes, thread_properties.mGlobalOffset,
891 thread_properties.nGlobalOffset);
892 }
893 // When local memory is available the following extract_block will be enabled
894 template <typename InputBlockProperties, bool is_internal_block, typename Input, typename Local,
895 contraction_type contract_tp = contraction_tp>
896 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
897 typename ::Eigen::internal::enable_if<contract_tp == contraction_type::local>::type
898 extract_block(const Input &inpt, Local local_ptr, const std::pair<StorageIndex, StorageIndex>& local_index,
899 const StorageIndex &ncOffset, const StorageIndex cOffset) {
900 EIGEN_CONSTEXPR StorageIndex TileSizeDimNC =
901 InputBlockProperties::is_rhs ? Properties::TileSizeDimN : Properties::TileSizeDimM;
902 EIGEN_CONSTEXPR StorageIndex LoadPerThread =
903 InputBlockProperties::is_rhs ? Properties::LoadPerThreadRhs : Properties::LoadPerThreadLhs;
904 EIGEN_CONSTEXPR StorageIndex LSD = InputBlockProperties::is_rhs ? LSDR : LSDL;
905 static_assert(((LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) == 0) &&
906 (LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride) == 0)),
907 " LocalOffset must be divisable by stride");
908 const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
909 StorageIndex localThreadNC = local_index.first;
910 StorageIndex localThreadC = local_index.second;
911 auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
912 return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
913 (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
914 };
915 EIGEN_UNROLL_LOOP
916 for (StorageIndex lPT = 0; lPT < LoadPerThread / InputBlockProperties::elements_per_access; lPT++) {
917 const StorageIndex CIndex = cOffset + (InputBlockProperties::c_stride * localThreadC);
918 const StorageIndex NCIndex = ncOffset + (InputBlockProperties::nc_stride * localThreadNC);
919 const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
920 if (check_boundary<is_internal_block>(chk_bound(CIndex, NCIndex))) {
921 auto val =
922 read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
923 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, NCIndex, CIndex, ld);
924 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
925 val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
926 (InputBlockProperties::c_stride * localThreadC * LSD));
927 } else {
928 EIGEN_UNROLL_LOOP
929 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
930 const StorageIndex nCInd = NCIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
931 const StorageIndex cInd = CIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
932 OutScalar val =
933 (nCInd < NC && cInd < triple_dim.K)
934 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
935 inpt, nCInd, cInd, ld)
936 : OutScalar(0);
937
938 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
939 val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
940 (InputBlockProperties::is_coalesced_layout ? i : 0) +
941 ((InputBlockProperties::c_stride * localThreadC +
942 (InputBlockProperties::is_coalesced_layout ? 0 : i)) *
943 LSD));
944 }
945 }
946 localThreadNC += (InputBlockProperties::is_coalesced_layout)
947 ? LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride)
948 : LocalOffset / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
949 localThreadC += (InputBlockProperties::is_coalesced_layout)
950 ? LocalOffset / (TileSizeDimNC / InputBlockProperties::nc_stride)
951 : LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
952 }
953 }
954};
955
956#ifndef EIGEN_SYCL_DISABLE_GEMV
957
999template <typename OutScalar, typename OutAccessor, typename VectorMapper, typename TensorMapper, typename StorageIndex,
1000 typename Properties, StorageIndex KFactor, bool Vectorizable, bool is_lhs_vec, bool IsFinal>
1001struct GeneralVectorTensor {
1002 typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
1003 PacketReturnType;
1004 static EIGEN_CONSTEXPR int PacketSize =
1005 Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
1006 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1007
1008 static EIGEN_CONSTEXPR StorageIndex OutScratchOffset =
1009 KFactor * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
1010
1011 // Since the access layout for a vector can always be coalesced, when the LHS is a vector we pass false and false to make
1012 // sure that the !^ is true. When the RHS is a vector, we pass true and true to make sure that the !^ is true.
1013 typedef BlockProperties<is_lhs_vec ? false : true, is_lhs_vec ? false : true, Vectorizable, PacketReturnType>
1014 VecBlockProperties;
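// Worked out (illustrative): when the LHS is the vector, BlockProperties<false, false, ...> gives
// is_coalesced_layout = !(false ^ false) = true; when the RHS is the vector, BlockProperties<true, true, ...> gives
// !(true ^ true) = true. Either way the vector is read with a coalesced layout, which is what the comment above
// relies on.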
1015
1016 Scratch scratch;
1017 const VectorMapper vec;
1018 const TensorMapper mat;
1019 OutAccessor out_res;
1020 const StorageIndex nonContractGroupSize;
1021 const StorageIndex nonContractDim;
1022 const StorageIndex contractDim;
1023
1024 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE GeneralVectorTensor(Scratch scratch_, const VectorMapper vec_,
1025 const TensorMapper mat_, OutAccessor out_res_,
1026 const StorageIndex nonContractGroupSize_,
1027 const StorageIndex nonContractDim_,
1028 const StorageIndex contractDim_)
1029 : scratch(scratch_),
1030 vec(vec_),
1031 mat(mat_),
1032 out_res(out_res_),
1033 nonContractGroupSize(nonContractGroupSize_),
1034 nonContractDim(nonContractDim_),
1035 contractDim(contractDim_) {}
1036
1037 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
1038 auto scratch_ptr = scratch.get_pointer();
1039 const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
1040 StorageIndex nonContractId = is_lhs_vec ? linearLocalThreadId / Properties::LocalThreadSizeC
1041 : linearLocalThreadId % Properties::LocalThreadSizeNC;
1042 StorageIndex contractId = is_lhs_vec ? linearLocalThreadId % Properties::LocalThreadSizeC
1043 : linearLocalThreadId / Properties::LocalThreadSizeNC;
1044 const StorageIndex cGroupSize = itemID.get_group_range(0) / nonContractGroupSize;
1045 const StorageIndex nonContractGroupId =
1046 is_lhs_vec ? itemID.get_group(0) / cGroupSize : itemID.get_group(0) % nonContractGroupSize;
1047 const StorageIndex contractGroupId =
1048 is_lhs_vec ? itemID.get_group(0) % cGroupSize : itemID.get_group(0) / nonContractGroupSize;
1049 auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : contractGroupId * nonContractDim);
1050
1051 const StorageIndex nonContractGroupOffset = nonContractGroupId * Properties::TileSizeDimNC;
1052 const StorageIndex contractGroupOffset = contractGroupId * Properties::TileSizeDimC;
1053 auto outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1054 const StorageIndex globalNonContractDimOffset = nonContractGroupOffset + nonContractId;
1055 const StorageIndex globalContractDimOffset = contractGroupOffset + contractId;
1056 auto local_output = scratch_ptr + OutScratchOffset;
1057 const bool is_internal = nonContractDim - nonContractGroupOffset >= Properties::TileSizeDimNC &&
1058 contractDim - contractGroupOffset >= Properties::TileSizeDimC;
1059 is_internal
1060 ? compute_panel<true>(itemID, vec, mat, local_output, out_ptr,
1061#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1062 scratch_ptr, contractGroupOffset,
1063#endif
1064 nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1065 nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex)
1066 : compute_panel<false>(itemID, vec, mat, local_output, out_ptr,
1067#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1068 scratch_ptr, contractGroupOffset,
1069#endif
1070 nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1071 nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex);
1072 }
1073 template <bool is_internal_block, typename OutPtr>
1074 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(
1075 const cl::sycl::nd_item<1> &itemID, const VectorMapper &vec, const TensorMapper &mat, OutScalar *local_output,
1076 OutPtr out_ptr,
1077#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1078 OutScalar *scratch_ptr, const StorageIndex contractGroupOffset,
1079#endif
1080 const StorageIndex nonContractGroupOffset, const StorageIndex linearLocalThreadId, StorageIndex contractDim,
1081 StorageIndex nonContractDim, StorageIndex contractId, StorageIndex nonContractId,
1082 StorageIndex globalContractDimOffset, StorageIndex globalNonContractDimOffset, StorageIndex outScratchIndex) {
1083 OutScalar outScalar[Properties::WorkLoadPerThreadNC] = {OutScalar(0)};
1084 // Reading the vector
1085#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1086 const StorageIndex vectorOffset = contractGroupOffset + linearLocalThreadId;
1087 extract_block<VecBlockProperties, is_internal_block, KFactor,
1088 Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC>(vec, scratch_ptr, linearLocalThreadId,
1089 vectorOffset, contractDim);
1090
1091 itemID.barrier(cl::sycl::access::fence_space::local_space);
1092 auto in_scratch_ptr = scratch_ptr + contractId;
1093#endif
1094
1095 StorageIndex privateOffsetC = 0;
1096 EIGEN_UNROLL_LOOP
1097 for (StorageIndex i = 0; i < Properties::WorkLoadPerThreadC; i++) {
1098 StorageIndex privateOffsetNC = 0;
1099 bool contract_conds = ((globalContractDimOffset + privateOffsetC) < contractDim);
1100#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1101 auto vecScalar = *in_scratch_ptr;
1102#else
1103 auto vecScalar = (check_boundary<is_internal_block>(contract_conds))
1104 ? vec(is_lhs_vec ? StorageIndex(0) : globalContractDimOffset + privateOffsetC,
1105 is_lhs_vec ? globalContractDimOffset + privateOffsetC : StorageIndex(0))
1106 : OutScalar(0);
1107#endif
1108 EIGEN_UNROLL_LOOP
1109 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1110 auto matScalar = (check_boundary<is_internal_block>(
1111 contract_conds && ((globalNonContractDimOffset + privateOffsetNC) < nonContractDim)))
1112 ? mat(is_lhs_vec ? globalContractDimOffset + privateOffsetC
1113 : globalNonContractDimOffset + privateOffsetNC,
1114 is_lhs_vec ? globalNonContractDimOffset + privateOffsetNC
1115 : globalContractDimOffset + privateOffsetC)
1116 : OutScalar(0);
1117
1118 outScalar[j] = cl::sycl::mad(matScalar, vecScalar, outScalar[j]);
1119 privateOffsetNC += Properties::LocalThreadSizeNC;
1120 }
1121 privateOffsetC += Properties::LocalThreadSizeC;
1122#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1123 in_scratch_ptr += Properties::LocalThreadSizeC;
1124#endif
1125 }
1126
1127 auto out_scratch_ptr = local_output + outScratchIndex;
1128 // Each block of 16*16 elements in shared memory should reduce to 16*1
1129 EIGEN_UNROLL_LOOP
1130 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1131 *out_scratch_ptr = outScalar[j];
1132
1133 out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1134 }
1135 if (is_lhs_vec) {
1136 nonContractId = linearLocalThreadId % Properties::LocalThreadSizeNC;
1137 contractId = linearLocalThreadId / Properties::LocalThreadSizeNC;
1138 outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1139 }
1140
1141 out_scratch_ptr = local_output + outScratchIndex;
1142 EIGEN_UNROLL_LOOP
1143 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1144 EIGEN_UNROLL_LOOP
1145 for (StorageIndex offset = Properties::LocalThreadSizeC >> 1; offset > 0; offset >>= 1) {
1146 itemID.barrier(cl::sycl::access::fence_space::local_space);
1147 if (contractId < offset) {
1148 StorageIndex myNeighbourId = (Properties::LocalThreadSizeNC * offset);
1149 *out_scratch_ptr += out_scratch_ptr[myNeighbourId];
1150 }
1151 }
1152 // moving to next 16 by 16 block
1153 out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1154 }
1155
1156 if (contractId == 0) {
1157 out_scratch_ptr = local_output + nonContractId;
1158 StorageIndex global_final_offset = nonContractGroupOffset + nonContractId;
1159 out_ptr += global_final_offset;
1160 EIGEN_UNROLL_LOOP
1161 for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1162 if (check_boundary<is_internal_block>(global_final_offset < nonContractDim)) {
1163 auto res = *out_scratch_ptr;
1164
1165 *out_ptr = res;
1166 out_ptr += Properties::LocalThreadSizeNC;
1167 }
1168 // moving to the next 16 by 16 block to get the next 16 reduced elements
1169 out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1170 if (!(is_internal_block)) global_final_offset += Properties::LocalThreadSizeNC;
1171 }
1172 }
1173 }
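// A hedged walkthrough (not from the original source) of the tree reduction in compute_panel above, assuming
// LocalThreadSizeNC = LocalThreadSizeC = 16: each 16x16 scratch block is reduced with offsets 8, 4, 2, 1; at every
// step the threads with contractId < offset add the partial sum located offset * LocalThreadSizeNC elements away,
// so after log2(16) = 4 barriers the contractId == 0 row holds the 16 reduced outputs that are then written to
// global memory.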
1174
1175 template <typename InputBlockProperties, bool is_internal_block, int CFactor, int GroupSize, typename Input,
1176 typename Local>
1177 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void extract_block(const Input &inpt, Local *local_ptr,
1178 const StorageIndex &linearLocalThreadId,
1179 const StorageIndex &cOffset, const StorageIndex &C) {
1180 local_ptr += InputBlockProperties::c_stride * linearLocalThreadId;
1181 StorageIndex cIndex = cOffset;
1182 for (StorageIndex cId = 0; cId < CFactor / InputBlockProperties::c_stride; cId++) {
1183 if (check_boundary<is_internal_block>(cIndex + InputBlockProperties::c_stride - 1 < C)) {
1184 auto val = read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
1185 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, StorageIndex(0),
1186 cIndex, StorageIndex(1));
1187 write<StorageIndex, 1, data_source::local_mem>(val, local_ptr);
1188 } else {
1189 EIGEN_UNROLL_LOOP
1190 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
1191 OutScalar val =
1192 (cIndex + i < C)
1193 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
1194 inpt, StorageIndex(0), cIndex + i, StorageIndex(1))
1195 : OutScalar(0);
1196 write<StorageIndex, 1, data_source::local_mem>(val, local_ptr + i);
1197 }
1198 }
1199 local_ptr += InputBlockProperties::c_stride * GroupSize;
1200 cIndex += InputBlockProperties::c_stride * GroupSize;
1201 }
1202 }
1203};
1204#endif
1205
1206#ifndef EIGEN_SYCL_DISABLE_SCALAR
1207
1239template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
1240 typename RhsMapper, typename StorageIndex, bool Vectorizable>
1241struct GeneralScalarContraction {
1242 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1243 Scratch scratch;
1244 const LhsMapper lhs;
1245 const RhsMapper rhs;
1246 OutAccessor out_res;
1247 const StorageIndex rng;
1248
1249 EIGEN_DEVICE_FUNC
1250 GeneralScalarContraction(Scratch scratch_, const LhsMapper lhs_, const RhsMapper rhs_, OutAccessor out_res_,
1251 const StorageIndex rng_)
1252 : scratch(scratch_), lhs(lhs_), rhs(rhs_), out_res(out_res_), rng(rng_) {}
1253
1254 EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) {
1255 auto out_ptr = out_res.get_pointer();
1256 auto scratch_ptr = scratch.get_pointer().get();
1257
1258 StorageIndex globalid = itemID.get_global_id(0);
1259 StorageIndex localid = itemID.get_local_id(0);
1260 OutScalar accumulator = OutScalar(0);
1261 for (StorageIndex i = globalid; i < rng; i += itemID.get_global_range(0)) {
1262 accumulator = cl::sycl::mad(lhs(0, i), rhs(i, 0), accumulator);
1263 }
1264 auto out_scratch_ptr = scratch_ptr + localid;
1265 *out_scratch_ptr = accumulator;
1266 for (StorageIndex offset = itemID.get_local_range(0) >> 1; offset > 0; offset >>= 1) {
1267 itemID.barrier(cl::sycl::access::fence_space::local_space);
1268 if (localid < offset) {
1269 *out_scratch_ptr = (accumulator += out_scratch_ptr[offset]);
1270 }
1271 }
1272 if (localid == 0) {
1273 out_ptr[itemID.get_group(0)] = accumulator;
1274 }
1275 }
1276};
1277#endif
1278
1279} // namespace internal
1280} // namespace TensorSycl
1281
1282template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
1283struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
1284 Eigen::SyclDevice>
1285 : public TensorContractionEvaluatorBase<TensorEvaluator<
1286 const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>> {
1287 static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
1288 "SYCL tensor contraction does not support output kernels.");
1289
1290 typedef Eigen::SyclDevice Device;
1291
1292 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
1293 typedef TensorContractionEvaluatorBase<Self> Base;
1294 typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
1295 typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
1296 typedef typename XprType::Index StorageIndex;
1297 typedef typename XprType::CoeffReturnType CoeffReturnType;
1298 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
1299 typedef typename Base::Storage Storage;
1300 typedef typename Base::EvaluatorPointerType EvaluatorPointerType;
1301 struct TripleDim {
1302 const StorageIndex M;
1303 const StorageIndex N;
1304 const StorageIndex K;
1305 TripleDim(const StorageIndex M_, const StorageIndex N_, const StorageIndex K_) : M(M_), N(N_), K(K_) {}
1306 };
1307 enum {
1308 Layout = TensorEvaluator<LeftArgType, Device>::Layout,
1309 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
1310 BlockAccess = false,
1311 };
1312
1313 static EIGEN_CONSTEXPR int LDims = Base::LDims;
1314 static EIGEN_CONSTEXPR int RDims = Base::RDims;
1315 static EIGEN_CONSTEXPR int ContractDims = Base::ContractDims;
1316
1317 typedef array<StorageIndex, LDims> left_dim_mapper_t;
1318 typedef array<StorageIndex, RDims> right_dim_mapper_t;
1319
1320 typedef array<StorageIndex, ContractDims> contract_t;
1321 typedef array<StorageIndex, LDims - ContractDims> left_nocontract_t;
1322 typedef array<StorageIndex, RDims - ContractDims> right_nocontract_t;
1323
1324 static const int NumDims = LDims + RDims - 2 * ContractDims;
1325
1326 typedef DSizes<StorageIndex, NumDims> Dimensions;
1327
1328 typedef TensorEvaluator<typename Base::EvalLeftArgType, Device> LeftEvaluator;
1329 typedef TensorEvaluator<typename Base::EvalRightArgType, Device> RightEvaluator;
1330 typedef typename Eigen::internal::remove_const<typename LeftEvaluator::CoeffReturnType>::type LhsScalar;
1331 typedef typename Eigen::internal::remove_const<typename RightEvaluator::CoeffReturnType>::type RhsScalar;
1332
1333 typedef typename LeftEvaluator::Dimensions LeftDimensions;
1334 typedef typename RightEvaluator::Dimensions RightDimensions;
1335
1336 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
1337 struct input_mapper_properties {
1338 static EIGEN_CONSTEXPR bool is_lhs_matrix = (LDims == 2 && ContractDims == 1) || lhs_inner_dim_contiguous;
1339 static EIGEN_CONSTEXPR bool is_rhs_matrix =
1340 (RDims == 2 && ContractDims == 1) || (rhs_inner_dim_contiguous && !rhs_inner_dim_reordered);
1341 };
1342
1343 TensorEvaluator(const XprType &op, const Device &device) : Base(op, device) {}
1344
1345 // We need to redefine this method to make nvcc happy
1346 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(typename Base::EvaluatorPointerType data) {
1347 this->m_leftImpl.evalSubExprsIfNeeded(NULL);
1348 this->m_rightImpl.evalSubExprsIfNeeded(NULL);
1349 if (!data) {
1350 this->m_result = this->m_device.get(
1351 static_cast<Scalar *>(this->m_device.allocate_temp(this->dimensions().TotalSize() * sizeof(Scalar))));
1352 data = this->m_result;
1353 }
1354 evalToSycl(data);
1355 return (this->m_result != NULL);
1356 }
1357 const Eigen::SyclDevice &device() const { return this->m_device; }
1358 void evalToSycl(typename Base::EvaluatorPointerType buffer) const {
1359 if (this->m_lhs_inner_dim_contiguous) {
1360 if (this->m_rhs_inner_dim_contiguous) {
1361 if (this->m_rhs_inner_dim_reordered) {
1362 evalTyped<true, true, true, Unaligned>(buffer);
1363 } else {
1364 evalTyped<true, true, false, Unaligned>(buffer);
1365 }
1366 } else {
1367 if (this->m_rhs_inner_dim_reordered) {
1368 evalTyped<true, false, true, Unaligned>(buffer);
1369 } else {
1370 evalTyped<true, false, false, Unaligned>(buffer);
1371 }
1372 }
1373 } else {
1374 if (this->m_rhs_inner_dim_contiguous) {
1375 if (this->m_rhs_inner_dim_reordered) {
1376 evalTyped<false, true, true, Unaligned>(buffer);
1377 } else {
1378 evalTyped<false, true, false, Unaligned>(buffer);
1379 }
1380 } else {
1381 if (this->m_rhs_inner_dim_reordered) {
1382 evalTyped<false, false, true, Unaligned>(buffer);
1383 } else {
1384 evalTyped<false, false, false, Unaligned>(buffer);
1385 }
1386 }
1387 }
1388 }
1389
1390 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
1391 void evalTyped(typename Base::EvaluatorPointerType buffer) const {
1392 const auto triple_dim = TripleDim{this->m_i_size, this->m_j_size, this->m_k_size};
1393 typedef internal::TensorContractionInputMapper<
1394 LhsScalar, StorageIndex, internal::Lhs, LeftEvaluator, left_nocontract_t, contract_t,
1395 PacketType<CoeffReturnType, Device>::size, lhs_inner_dim_contiguous, false, Unaligned, MakeSYCLPointer>
1396 LhsMapper;
1397
1398 typedef internal::TensorContractionInputMapper<RhsScalar, StorageIndex, internal::Rhs, RightEvaluator,
1399 right_nocontract_t, contract_t,
1400 PacketType<CoeffReturnType, Device>::size, rhs_inner_dim_contiguous,
1401 rhs_inner_dim_reordered, Unaligned, MakeSYCLPointer>
1402 RhsMapper;
1403
1404 // initialize data mappers
1405 LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
1406 this->m_left_contracting_strides, this->m_k_strides);
1407
1408 RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
1409 this->m_right_contracting_strides, this->m_k_strides);
1410
1411#ifndef EIGEN_SYCL_DISABLE_SCALAR
1412 if (triple_dim.M == 1 && triple_dim.N == 1) {
1413 launchSC(buffer, lhs, rhs, triple_dim.K);
1414 } else
1415#endif
1416#ifndef EIGEN_SYCL_DISABLE_GEMV
1417 if (triple_dim.M != 1 && triple_dim.N == 1) {
1418 LaunchVT<false>(buffer, rhs, lhs, triple_dim.M, triple_dim.K);
1419 } else if (triple_dim.M == 1 && triple_dim.N != 1) {
1420 LaunchVT<true>(buffer, lhs, rhs, triple_dim.N, triple_dim.K);
1421 } else // This is equivalent to if (m!=1 && n!=1)
1422#endif
1423 {
1424 typedef input_mapper_properties<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>
1425 inpt_mapper_properties;
1426#ifndef EIGEN_SYCL_DISABLE_SKINNY
1427 bool skinny = false;
1428 auto platform_name = this->device().getPlatformName();
1429 // This is based on empirical calculation for AMD r9-nano and Fiji
1430 if (platform_name.find("AMD") == 0) {
1431 skinny = (triple_dim.M < triple_dim.K || triple_dim.N < triple_dim.K) &&
1432 ((triple_dim.M < 1024 && triple_dim.N < 1024) ||
1433 (uint64_t(triple_dim.M * triple_dim.N) < uint64_t(triple_dim.K)));
1434 } else {
1435 skinny = (((std::max(triple_dim.K, triple_dim.N) / std::min(triple_dim.K, triple_dim.N)) > 100) ||
1436 ((std::max(triple_dim.K, triple_dim.M) / std::min(triple_dim.K, triple_dim.M)) > 100) ||
1437 ((std::max(triple_dim.N, triple_dim.M) / std::min(triple_dim.N, triple_dim.M)) > 100));
1438 }
1439 if (skinny)
1440 adjustTT<true, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
1441 else
1442#endif // EIGEN_SYCL_DISABLE_SKINNY
1443 adjustTT<false, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
1444 }
1445 }
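// Illustrative application of the "skinny" heuristic above (not from the original source), for the non-AMD branch:
// with M = N = 64 and K = 10000, max(K, N) / min(K, N) = 10000 / 64 = 156 > 100, so the contraction is treated as
// skinny; groupSizeK may then exceed 1 in launchTT and the per-group K partial results are written to a temporary
// buffer instead of directly to the output.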
1446
1447 template <bool skinny, typename input_mapper_properties, typename LhsMapper, typename RhsMapper>
1448 void EIGEN_ALWAYS_INLINE adjustTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1449 const TripleDim &triple_dim) const {
1450#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1451 if (device().has_local_memory()) {
1452 typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> PanelParameters;
1453 launchTT<TensorSycl::internal::contraction_type::local, skinny, input_mapper_properties, PanelParameters>(
1454 buffer, lhs, rhs, triple_dim);
1455 }
1456#endif
1457#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF
1458 if (!(device().has_local_memory())) {
1459 typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 4> PanelParameters;
1460 launchTT<TensorSycl::internal::contraction_type::no_local, skinny, input_mapper_properties, PanelParameters>(
1461 buffer, lhs, rhs, triple_dim);
1462 }
1463#endif
1464 }
1465
1466 template <TensorSycl::internal::contraction_type ct, bool skinny, typename input_mapper_properties,
1467 typename Properties, typename LhsMapper, typename RhsMapper>
1468 void launchTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1469 const TripleDim &triple_dim) const {
1470 const StorageIndex roundUpM = Eigen::TensorSycl::internal::roundUp(triple_dim.M, Properties::TileSizeDimM);
1471 const StorageIndex roundUpN = Eigen::TensorSycl::internal::roundUp(triple_dim.N, Properties::TileSizeDimN);
1472 const StorageIndex groupSizeM = roundUpM / Properties::TileSizeDimM;
1473 const StorageIndex groupSizeN = roundUpN / Properties::TileSizeDimN;
1474
1475 const StorageIndex roundUpK = Eigen::TensorSycl::internal::roundUp(triple_dim.K, Properties::TileSizeDimK);
1476 StorageIndex totalTilesK = roundUpK / Properties::TileSizeDimK;
1477 StorageIndex groupSizeK =
1478 skinny
1479 ? std::max(std::min(totalTilesK,
1480 (StorageIndex)(device().getPowerOfTwo(device().getNumSyclMultiProcessors(), true) * 4) /
1481 (groupSizeM * groupSizeN)),
1482 StorageIndex(1))
1483 : StorageIndex(1);
1484
1485 const StorageIndex numTilesPerGroup = Eigen::TensorSycl::internal::roundUp(totalTilesK, groupSizeK) / groupSizeK;
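// For skinny contractions the K tiles are spread over up to roughly four work-groups per compute unit
// (capped by totalTilesK and divided by the size of the M*N group grid); each group then processes
// numTilesPerGroup K-tiles, and when groupSizeK > 1 the per-group partial results are summed by
// SecondStepPartialReduction below.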
1486
1487 const StorageIndex totalGroupSize = groupSizeM * groupSizeN * groupSizeK;
1488
1489 const StorageIndex localRange = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
1490 const StorageIndex globalRange = totalGroupSize * localRange;
1491
1492 const StorageIndex scratchSize = (ct == TensorSycl::internal::contraction_type::local)
1493 ? ((Properties::DoubleBuffer + 1) *
1494 (Properties::TileSizeDimM + Properties::BC) * (Properties::TileSizeDimK)) +
1495 ((Properties::DoubleBuffer + 1) * (Properties::TileSizeDimK) *
1496 (Properties::TileSizeDimN + Properties::BC))
1497 : StorageIndex(1);
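// Local-memory kernels stage one TileSizeDimM x TileSizeDimK LHS tile and one TileSizeDimK x TileSizeDimN
// RHS tile in scratch memory; DoubleBuffer doubles the allocation so the next tile can be loaded while the
// current one is consumed, and BC pads the leading dimension to reduce bank conflicts.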
1498
1499 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
1500 if (groupSizeK == 1) {
1501 typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
1502 LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
1503 PacketAccess, input_mapper_properties, true, ct>
1504 ContractKernelName;
1505 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1506 lhs, rhs, buffer, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, triple_dim);
1507 } else {
1508 typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
1509 LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
1510 PacketAccess, input_mapper_properties, false, ct>
1511 ContractKernelName;
1512 CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>(
1513 device().allocate_temp(triple_dim.M * triple_dim.N * groupSizeK * sizeof(CoeffReturnType)));
1514 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1515
1516 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1517 lhs, rhs, tmp_global_accessor, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup,
1518 triple_dim);
1519
1520 typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1521 auto op = Op();
1522 typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
1523 EvaluatorPointerType, Op>
1524 ReductionKernel;
1525
1526 device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
1527 tmp_global_accessor, buffer,
1528 cl::sycl::nd_range<1>(cl::sycl::range<1>(StorageIndex(
1529 Eigen::TensorSycl::internal::roundUp(triple_dim.M * triple_dim.N, localRange))),
1530 cl::sycl::range<1>(localRange)),
1531 StorageIndex(1), op, StorageIndex(triple_dim.M * triple_dim.N), groupSizeK);
1532
1533 device().deallocate_temp(temp_pointer);
1534 }
1535 }
1536
1537#ifndef EIGEN_SYCL_DISABLE_GEMV
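// LaunchVT: handles the matrix-vector cases dispatched from evalTyped. When N == 1 the RHS is the vector
// (is_lhs_vec == false); when M == 1 the LHS is the vector (is_lhs_vec == true). NC is the surviving
// non-contracting extent and C the contracting extent.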
1538 template <bool is_lhs_vec, typename VectorMapper, typename TensorMapper, typename StorageIndex>
1539 void EIGEN_ALWAYS_INLINE LaunchVT(EvaluatorPointerType buffer, const VectorMapper &vec, const TensorMapper &mat,
1540 StorageIndex NC, StorageIndex C) const {
1541 const StorageIndex nonContractDim = NC;
1542 EIGEN_CONSTEXPR StorageIndex NCFactor = 1;
1543 EIGEN_CONSTEXPR StorageIndex CFactor = 1;
1544 EIGEN_CONSTEXPR StorageIndex NCWindow = 16;
1545 typedef Eigen::TensorSycl::internal::TVPanelSize<CoeffReturnType, StorageIndex, NCWindow, CFactor, NCFactor>
1546 Properties;
1547 const StorageIndex roundUpC = Eigen::TensorSycl::internal::roundUp(C, Properties::TileSizeDimC);
1548 const StorageIndex cNumGroups = roundUpC / (Properties::LocalThreadSizeC * Properties::WorkLoadPerThreadC);
1549 const StorageIndex roundUpNC = Eigen::TensorSycl::internal::roundUp(nonContractDim, Properties::TileSizeDimNC);
1550 const StorageIndex nCNumGroups = roundUpNC / (Properties::LocalThreadSizeNC * Properties::WorkLoadPerThreadNC);
1551 const StorageIndex globalRange =
1552 (roundUpNC / (Properties::WorkLoadPerThreadNC)) * (roundUpC / (Properties::WorkLoadPerThreadC));
1553 const StorageIndex localRange = Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC;
1554 const StorageIndex scratchSize =
1555 (Properties::WorkLoadPerThreadNC + CFactor) * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
1556 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
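// If the contracting dimension is split over more than one work-group (cNumGroups > 1), partial results of
// size nonContractDim * cNumGroups are written to a temporary buffer and then summed by
// SecondStepPartialReduction; otherwise the kernel accumulates directly into the output buffer.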
1557 if (cNumGroups > 1) {
1558 typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
1559 TensorMapper, StorageIndex, Properties, CFactor, false,
1560 is_lhs_vec, false>
1561 ContractKernelName;
1562 CoeffReturnType *temp_pointer =
1563 static_cast<CoeffReturnType *>(device().allocate_temp(nonContractDim * cNumGroups * sizeof(CoeffReturnType)));
1564 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1565
1566 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1567 vec, mat, tmp_global_accessor, thread_range, scratchSize, nCNumGroups, nonContractDim, C);
1568
1569 typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1570 typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
1571 EvaluatorPointerType, Op>
1572 ReductionKernel;
1573
1574 device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
1575 tmp_global_accessor, buffer,
1576 cl::sycl::nd_range<1>(cl::sycl::range<1>(Eigen::TensorSycl::internal::roundUp(nonContractDim, localRange)),
1577 cl::sycl::range<1>(localRange)),
1578 StorageIndex(1), Op(), nonContractDim, cNumGroups);
1579
1580 device().deallocate_temp(temp_pointer);
1581 } else {
1582 typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
1583 TensorMapper, StorageIndex, Properties, CFactor, false,
1584 is_lhs_vec, true>
1585 ContractKernelName;
1586 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1587 vec, mat, buffer, thread_range, scratchSize, nCNumGroups, nonContractDim, C);
1588 }
1589 }
1590#endif
1591
1592#ifndef EIGEN_SYCL_DISABLE_SCALAR
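// launchSC: handles the case where both non-contracting dimensions are 1, i.e. a dot product of length K.
// It reduces in at most two steps (a per-work-group partial reduction followed by SecondStepFullReducer)
// and therefore requires the work-group size to be a power of two.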
1593 template <typename LhsMapper, typename RhsMapper>
1594 EIGEN_ALWAYS_INLINE void launchSC(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1595 StorageIndex K) const {
1596 EIGEN_STATIC_ASSERT(!((EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1) &
1597 (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)),
1598 "The Local thread size must be a power of 2 for the reduction "
1599 "operation");
1600 EIGEN_CONSTEXPR StorageIndex local_range = EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1;
1601
1602 // Here we restrict the code to at most a two-step reduction: our empirical research shows that we get better
1603 // performance when each thread reduces at least 512 elements on its own.
1604 const StorageIndex num_work_group = ((K + (512 * local_range - 1)) / (512 * local_range) > 1 ? local_range : 1);
1605 const StorageIndex global_range = num_work_group * local_range;
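// Worked example with a hypothetical local_range of 256: a multi-work-group launch is chosen only when
// K > 512 * 256 = 131072; e.g. K = 1,000,000 yields num_work_group = 256 and global_range = 65536, after
// which the 256 per-group partial sums are combined by SecondStepFullReducer below.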
1606
1607 typedef Eigen::TensorSycl::internal::GeneralScalarContraction<
1608 CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, LhsMapper, RhsMapper, StorageIndex, false>
1609 ContractKernelName;
1610 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
1611 if (num_work_group > 1) {
1612 CoeffReturnType *temp_pointer =
1613 static_cast<CoeffReturnType *>(device().allocate_temp(num_work_group * sizeof(CoeffReturnType)));
1614 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1615 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, tmp_global_accessor,
1616 thread_range, local_range, K);
1617 typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1618 typedef TensorSycl::internal::SecondStepFullReducer<CoeffReturnType, Op, EvaluatorPointerType,
1619 EvaluatorPointerType, StorageIndex, local_range>
1620 GenericRKernel;
1621 device().template unary_kernel_launcher<CoeffReturnType, GenericRKernel>(
1622 tmp_global_accessor, buffer,
1623 cl::sycl::nd_range<1>(cl::sycl::range<1>(local_range), cl::sycl::range<1>(local_range)), local_range, Op());
1624
1625 device().deallocate_temp(temp_pointer);
1626 } else {
1627 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, buffer, thread_range,
1628 local_range, K);
1629 }
1630 }
1631#endif
1632
1633 EIGEN_STRONG_INLINE void cleanup() {
1634 this->m_leftImpl.cleanup();
1635 this->m_rightImpl.cleanup();
1636
1637 if (this->m_result) {
1638 this->m_device.deallocate_temp(this->m_result);
1639 this->m_result = NULL;
1640 }
1641 }
1642 // The placeholder accessors must be bound to a command group handler for SYCL
1643 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
1644 this->m_leftImpl.bind(cgh);
1645 this->m_rightImpl.bind(cgh);
1646 this->m_result.bind(cgh);
1647 }
1648};
1649} // namespace Eigen
1650#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H