TensorScanSycl.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli Codeplay Software Ltd.
// Ralph Potter Codeplay Software Ltd.
// Luke Iwanski Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

/*****************************************************************
 * TensorScanSycl.h
 *
 * \brief:
 * TensorScanSycl implements an extended version of
 * "Efficient parallel scan algorithms for GPUs" for Tensor operations.
 * The algorithm requires up to 3 stages (and consequently 3 kernels),
 * depending on the size of the tensor. In the first kernel
 * (ScanKernelFunctor), each thread within a work-group individually scans
 * the elements allocated to it, in order to reduce the total number of
 * blocks; all threads within the work-group then reduce their associated
 * blocks into a temporary buffer. In the next kernel (ScanKernelFunctor
 * launched with scan_step::second), the temporary buffer is given as input
 * and all threads within a work-group scan the per-block totals generated
 * by the previous kernel, writing the result back to the temporary buffer.
 * If the second kernel is required, the third and final kernel
 * (ScanAdjustmentKernelFunctor) adds the block offsets to produce the final
 * result in the output buffer.
 * The original algorithm for the parallel prefix sum can be found here:
 *
 * Sengupta, Shubhabrata, Mark Harris, and Michael Garland. "Efficient parallel
 * scan algorithms for GPUs." NVIDIA, Santa Clara, CA, Tech. Rep. NVR-2008-003
 * 1, no. 1 (2008): 1-17.
 *****************************************************************/
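/*
 * Illustrative usage (a minimal sketch, not part of this header): the SYCL
 * scan path below is normally reached through the Tensor scan expressions,
 * e.g. cumsum()/cumprod(), evaluated on an Eigen::SyclDevice. The variable
 * names and device setup here are assumptions for demonstration only:
 *
 *   Eigen::QueueInterface queue_interface{cl::sycl::default_selector{}};
 *   Eigen::SyclDevice sycl_device(&queue_interface);
 *   float* gpu_in = static_cast<float*>(sycl_device.allocate(64 * 128 * sizeof(float)));
 *   float* gpu_out = static_cast<float*>(sycl_device.allocate(64 * 128 * sizeof(float)));
 *   Eigen::TensorMap<Eigen::Tensor<float, 2>> in(gpu_in, 64, 128), out(gpu_out, 64, 128);
 *   // Dispatches to the ScanLauncher specialisation defined at the end of
 *   // this file; dimension 1 is the scan ("consume") dimension.
 *   out.device(sycl_device) = in.cumsum(1);
 */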

#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP
#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP

namespace Eigen {
namespace TensorSycl {
namespace internal {

#ifndef EIGEN_SYCL_MAX_GLOBAL_RANGE
#define EIGEN_SYCL_MAX_GLOBAL_RANGE (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 * 4)
#endif

template <typename index_t>
struct ScanParameters {
  // must be power of 2
  static EIGEN_CONSTEXPR index_t ScanPerThread = 8;
  const index_t total_size;
  const index_t non_scan_size;
  const index_t scan_size;
  const index_t non_scan_stride;
  const index_t scan_stride;
  const index_t panel_threads;
  const index_t group_threads;
  const index_t block_threads;
  const index_t elements_per_group;
  const index_t elements_per_block;
  const index_t loop_range;

  ScanParameters(index_t total_size_, index_t non_scan_size_, index_t scan_size_, index_t non_scan_stride_,
                 index_t scan_stride_, index_t panel_threads_, index_t group_threads_, index_t block_threads_,
                 index_t elements_per_group_, index_t elements_per_block_, index_t loop_range_)
      : total_size(total_size_),
        non_scan_size(non_scan_size_),
        scan_size(scan_size_),
        non_scan_stride(non_scan_stride_),
        scan_stride(scan_stride_),
        panel_threads(panel_threads_),
        group_threads(group_threads_),
        block_threads(block_threads_),
        elements_per_group(elements_per_group_),
        elements_per_block(elements_per_block_),
        loop_range(loop_range_) {}
};
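// Relations between the fields above, as computed by ScanInfo further down in
// this file (descriptive note, not enforced here): every work-item handles
// ScanPerThread consecutive scan elements, so
//   block_threads == elements_per_block / ScanPerThread,
//   group_threads == elements_per_group / ScanPerThread,
// and panel_threads covers one whole panel (elements_per_group * non_scan_size
// elements) at ScanPerThread elements per work-item.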

enum class scan_step { first, second };
template <typename Evaluator, typename CoeffReturnType, typename OutAccessor, typename Op, typename Index,
          scan_step stp>
struct ScanKernelFunctor {
  typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      LocalAccessor;
  static EIGEN_CONSTEXPR int PacketSize = ScanParameters<Index>::ScanPerThread / 2;

  LocalAccessor scratch;
  Evaluator dev_eval;
  OutAccessor out_accessor;
  OutAccessor temp_accessor;
  const ScanParameters<Index> scanParameters;
  Op accumulator;
  const bool inclusive;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScanKernelFunctor(LocalAccessor scratch_, const Evaluator dev_eval_,
                                                          OutAccessor out_accessor_, OutAccessor temp_accessor_,
                                                          const ScanParameters<Index> scanParameters_, Op accumulator_,
                                                          const bool inclusive_)
      : scratch(scratch_),
        dev_eval(dev_eval_),
        out_accessor(out_accessor_),
        temp_accessor(temp_accessor_),
        scanParameters(scanParameters_),
        accumulator(accumulator_),
        inclusive(inclusive_) {}

  template <scan_step sst = stp, typename Input>
  typename ::Eigen::internal::enable_if<sst == scan_step::first, CoeffReturnType>::type EIGEN_DEVICE_FUNC
      EIGEN_STRONG_INLINE
      read(const Input &inpt, Index global_id) {
    return inpt.coeff(global_id);
  }

  template <scan_step sst = stp, typename Input>
  typename ::Eigen::internal::enable_if<sst != scan_step::first, CoeffReturnType>::type EIGEN_DEVICE_FUNC
      EIGEN_STRONG_INLINE
      read(const Input &inpt, Index global_id) {
    return inpt[global_id];
  }

  template <scan_step sst = stp, typename InclusiveOp>
  typename ::Eigen::internal::enable_if<sst == scan_step::first>::type EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  first_step_inclusive_Operation(InclusiveOp inclusive_op) {
    inclusive_op();
  }

  template <scan_step sst = stp, typename InclusiveOp>
  typename ::Eigen::internal::enable_if<sst != scan_step::first>::type EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  first_step_inclusive_Operation(InclusiveOp) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
    auto out_ptr = out_accessor.get_pointer();
    auto tmp_ptr = temp_accessor.get_pointer();
    auto scratch_ptr = scratch.get_pointer().get();

    for (Index loop_offset = 0; loop_offset < scanParameters.loop_range; loop_offset++) {
      Index data_offset = (itemID.get_global_id(0) + (itemID.get_global_range(0) * loop_offset));
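      // Decompose the flat work-item index into (panel, group, block, local)
      // coordinates: each panel covers one scan_size * non_scan_size slice,
      // each group one non-scan index within the panel, and each block a
      // chunk of elements_per_block consecutive scan elements.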
      Index tmp = data_offset % scanParameters.panel_threads;
      const Index panel_id = data_offset / scanParameters.panel_threads;
      const Index group_id = tmp / scanParameters.group_threads;
      tmp = tmp % scanParameters.group_threads;
      const Index block_id = tmp / scanParameters.block_threads;
      const Index local_id = tmp % scanParameters.block_threads;
      // we put one element per packet in scratch_mem
      const Index scratch_stride = scanParameters.elements_per_block / PacketSize;
      const Index scratch_offset = (itemID.get_local_id(0) / scanParameters.block_threads) * scratch_stride;
      CoeffReturnType private_scan[ScanParameters<Index>::ScanPerThread];
      CoeffReturnType inclusive_scan;
      // the actual panel size is scan_size * non_scan_size.
      // elements_per_panel is rounded up to a power of 2 for the binary tree
      const Index panel_offset = panel_id * scanParameters.scan_size * scanParameters.non_scan_size;
      const Index group_offset = group_id * scanParameters.non_scan_stride;
      // This will be effective when the size is bigger than elements_per_block
      const Index block_offset = block_id * scanParameters.elements_per_block * scanParameters.scan_stride;
      const Index thread_offset = (ScanParameters<Index>::ScanPerThread * local_id * scanParameters.scan_stride);
      const Index global_offset = panel_offset + group_offset + block_offset + thread_offset;
      Index next_elements = 0;
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
        Index global_id = global_offset + next_elements;
        private_scan[i] = ((((block_id * scanParameters.elements_per_block) +
                             (ScanParameters<Index>::ScanPerThread * local_id) + i) < scanParameters.scan_size) &&
                           (global_id < scanParameters.total_size))
                              ? read(dev_eval, global_id)
                              : accumulator.initialize();
        next_elements += scanParameters.scan_stride;
      }
      first_step_inclusive_Operation([&]() EIGEN_DEVICE_FUNC {
        if (inclusive) {
          inclusive_scan = private_scan[ScanParameters<Index>::ScanPerThread - 1];
        }
      });
      // This loop runs exactly twice (ScanPerThread / PacketSize == 2)
      EIGEN_UNROLL_LOOP
      for (int packetIndex = 0; packetIndex < ScanParameters<Index>::ScanPerThread; packetIndex += PacketSize) {
        Index private_offset = 1;
        // build sum in place up the tree
        EIGEN_UNROLL_LOOP
        for (Index d = PacketSize >> 1; d > 0; d >>= 1) {
          EIGEN_UNROLL_LOOP
          for (Index l = 0; l < d; l++) {
            Index ai = private_offset * (2 * l + 1) - 1 + packetIndex;
            Index bi = private_offset * (2 * l + 2) - 1 + packetIndex;
            CoeffReturnType accum = accumulator.initialize();
            accumulator.reduce(private_scan[ai], &accum);
            accumulator.reduce(private_scan[bi], &accum);
            private_scan[bi] = accumulator.finalize(accum);
          }
          private_offset *= 2;
        }
        scratch_ptr[2 * local_id + (packetIndex / PacketSize) + scratch_offset] =
            private_scan[PacketSize - 1 + packetIndex];
        private_scan[PacketSize - 1 + packetIndex] = accumulator.initialize();
        // traverse down tree & build scan
        EIGEN_UNROLL_LOOP
        for (Index d = 1; d < PacketSize; d *= 2) {
          private_offset >>= 1;
          EIGEN_UNROLL_LOOP
          for (Index l = 0; l < d; l++) {
            Index ai = private_offset * (2 * l + 1) - 1 + packetIndex;
            Index bi = private_offset * (2 * l + 2) - 1 + packetIndex;
            CoeffReturnType accum = accumulator.initialize();
            accumulator.reduce(private_scan[ai], &accum);
            accumulator.reduce(private_scan[bi], &accum);
            private_scan[ai] = private_scan[bi];
            private_scan[bi] = accumulator.finalize(accum);
          }
        }
      }
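      // At this point each PacketSize-wide packet of private_scan holds its
      // exclusive scan (e.g. inputs [a, b, c, d] become [0, a, a+b, a+b+c]
      // for a sum accumulator), and the packet totals have been written to
      // scratch_ptr for the work-group level scan below.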

      Index offset = 1;
      // build sum in place up the tree
      for (Index d = scratch_stride >> 1; d > 0; d >>= 1) {
        // Synchronise
        itemID.barrier(cl::sycl::access::fence_space::local_space);
        if (local_id < d) {
          Index ai = offset * (2 * local_id + 1) - 1 + scratch_offset;
          Index bi = offset * (2 * local_id + 2) - 1 + scratch_offset;
          CoeffReturnType accum = accumulator.initialize();
          accumulator.reduce(scratch_ptr[ai], &accum);
          accumulator.reduce(scratch_ptr[bi], &accum);
          scratch_ptr[bi] = accumulator.finalize(accum);
        }
        offset *= 2;
      }
      // Synchronise
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      // store this block's total in the temporary buffer for the next kernel
      if (local_id == 0) {
        if (((scanParameters.elements_per_group / scanParameters.elements_per_block) > 1)) {
          const Index temp_id = panel_id * (scanParameters.elements_per_group / scanParameters.elements_per_block) *
                                    scanParameters.non_scan_size +
                                group_id * (scanParameters.elements_per_group / scanParameters.elements_per_block) +
                                block_id;
          tmp_ptr[temp_id] = scratch_ptr[scratch_stride - 1 + scratch_offset];
        }
        // clear the last element
        scratch_ptr[scratch_stride - 1 + scratch_offset] = accumulator.initialize();
      }
      // traverse down tree & build scan
      for (Index d = 1; d < scratch_stride; d *= 2) {
        offset >>= 1;
        // Synchronise
        itemID.barrier(cl::sycl::access::fence_space::local_space);
        if (local_id < d) {
          Index ai = offset * (2 * local_id + 1) - 1 + scratch_offset;
          Index bi = offset * (2 * local_id + 2) - 1 + scratch_offset;
          CoeffReturnType accum = accumulator.initialize();
          accumulator.reduce(scratch_ptr[ai], &accum);
          accumulator.reduce(scratch_ptr[bi], &accum);
          scratch_ptr[ai] = scratch_ptr[bi];
          scratch_ptr[bi] = accumulator.finalize(accum);
        }
      }
      // Synchronise
      itemID.barrier(cl::sycl::access::fence_space::local_space);
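      // scratch_ptr now holds the exclusive scan of the packet totals across
      // the block; each work-item folds its packet's prefix back into its
      // private results below.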
      // This loop runs exactly twice (ScanPerThread / PacketSize == 2)
      EIGEN_UNROLL_LOOP
      for (int packetIndex = 0; packetIndex < ScanParameters<Index>::ScanPerThread; packetIndex += PacketSize) {
        EIGEN_UNROLL_LOOP
        for (Index i = 0; i < PacketSize; i++) {
          CoeffReturnType accum = private_scan[packetIndex + i];
          accumulator.reduce(scratch_ptr[2 * local_id + (packetIndex / PacketSize) + scratch_offset], &accum);
          private_scan[packetIndex + i] = accumulator.finalize(accum);
        }
      }
      first_step_inclusive_Operation([&]() EIGEN_DEVICE_FUNC {
        if (inclusive) {
          accumulator.reduce(private_scan[ScanParameters<Index>::ScanPerThread - 1], &inclusive_scan);
          private_scan[0] = accumulator.finalize(inclusive_scan);
        }
      });
      next_elements = 0;
      // write out this thread's private results
      EIGEN_UNROLL_LOOP
      for (Index i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
        Index global_id = global_offset + next_elements;
        if ((((block_id * scanParameters.elements_per_block) + (ScanParameters<Index>::ScanPerThread * local_id) + i) <
             scanParameters.scan_size) &&
            (global_id < scanParameters.total_size)) {
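          // For an exclusive scan element i keeps private_scan[i]; for an
          // inclusive scan it takes private_scan[i + 1], with the wrapped
          // index 0 holding the inclusive total of this thread's chunk
          // (stored there just above).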
          Index private_id = (i * !inclusive) + (((i + 1) % ScanParameters<Index>::ScanPerThread) * (inclusive));
          out_ptr[global_id] = private_scan[private_id];
        }
        next_elements += scanParameters.scan_stride;
      }
    }  // end for loop
  }
};

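// ScanAdjustmentKernelFunctor adds the scanned per-block totals (produced by
// the second ScanKernelFunctor pass over the temporary buffer) back onto each
// block's partial results, completing the scan when scan_size does not fit in
// a single block.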
template <typename CoeffReturnType, typename InAccessor, typename OutAccessor, typename Op, typename Index>
struct ScanAdjustmentKernelFunctor {
  typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
      LocalAccessor;
  static EIGEN_CONSTEXPR int PacketSize = ScanParameters<Index>::ScanPerThread / 2;
  InAccessor in_accessor;
  OutAccessor out_accessor;
  const ScanParameters<Index> scanParameters;
  Op accumulator;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScanAdjustmentKernelFunctor(LocalAccessor, InAccessor in_accessor_,
                                                                    OutAccessor out_accessor_,
                                                                    const ScanParameters<Index> scanParameters_,
                                                                    Op accumulator_)
      : in_accessor(in_accessor_),
        out_accessor(out_accessor_),
        scanParameters(scanParameters_),
        accumulator(accumulator_) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
    auto in_ptr = in_accessor.get_pointer();
    auto out_ptr = out_accessor.get_pointer();

    for (Index loop_offset = 0; loop_offset < scanParameters.loop_range; loop_offset++) {
      Index data_offset = (itemID.get_global_id(0) + (itemID.get_global_range(0) * loop_offset));
      Index tmp = data_offset % scanParameters.panel_threads;
      const Index panel_id = data_offset / scanParameters.panel_threads;
      const Index group_id = tmp / scanParameters.group_threads;
      tmp = tmp % scanParameters.group_threads;
      const Index block_id = tmp / scanParameters.block_threads;
      const Index local_id = tmp % scanParameters.block_threads;

      // the actual panel size is scan_size * non_scan_size.
      // elements_per_panel is rounded up to a power of 2 for the binary tree
      const Index panel_offset = panel_id * scanParameters.scan_size * scanParameters.non_scan_size;
      const Index group_offset = group_id * scanParameters.non_scan_stride;
      // This will be effective when the size is bigger than elements_per_block
      const Index block_offset = block_id * scanParameters.elements_per_block * scanParameters.scan_stride;
      const Index thread_offset = ScanParameters<Index>::ScanPerThread * local_id * scanParameters.scan_stride;

      const Index global_offset = panel_offset + group_offset + block_offset + thread_offset;
      const Index block_size = scanParameters.elements_per_group / scanParameters.elements_per_block;
      const Index in_id = (panel_id * block_size * scanParameters.non_scan_size) + (group_id * block_size) + block_id;
      CoeffReturnType adjust_val = in_ptr[in_id];

      Index next_elements = 0;
      EIGEN_UNROLL_LOOP
      for (Index i = 0; i < ScanParameters<Index>::ScanPerThread; i++) {
        Index global_id = global_offset + next_elements;
        if ((((block_id * scanParameters.elements_per_block) + (ScanParameters<Index>::ScanPerThread * local_id) + i) <
             scanParameters.scan_size) &&
            (global_id < scanParameters.total_size)) {
          CoeffReturnType accum = adjust_val;
          accumulator.reduce(out_ptr[global_id], &accum);
          out_ptr[global_id] = accumulator.finalize(accum);
        }
        next_elements += scanParameters.scan_stride;
      }
    }
  }
};

template <typename Index>
struct ScanInfo {
  const Index &total_size;
  const Index &scan_size;
  const Index &panel_size;
  const Index &non_scan_size;
  const Index &scan_stride;
  const Index &non_scan_stride;

  Index max_elements_per_block;
  Index block_size;
  Index panel_threads;
  Index group_threads;
  Index block_threads;
  Index elements_per_group;
  Index elements_per_block;
  Index loop_range;
  Index global_range;
  Index local_range;
  const Eigen::SyclDevice &dev;
  EIGEN_STRONG_INLINE ScanInfo(const Index &total_size_, const Index &scan_size_, const Index &panel_size_,
                               const Index &non_scan_size_, const Index &scan_stride_, const Index &non_scan_stride_,
                               const Eigen::SyclDevice &dev_)
      : total_size(total_size_),
        scan_size(scan_size_),
        panel_size(panel_size_),
        non_scan_size(non_scan_size_),
        scan_stride(scan_stride_),
        non_scan_stride(non_scan_stride_),
        dev(dev_) {
    // must be power of 2
    local_range = std::min(Index(dev.getNearestPowerOfTwoWorkGroupSize()),
                           Index(EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1));

    max_elements_per_block = local_range * ScanParameters<Index>::ScanPerThread;

    elements_per_group =
        dev.getPowerOfTwo(Index(roundUp(Index(scan_size), ScanParameters<Index>::ScanPerThread)), true);
    const Index elements_per_panel = elements_per_group * non_scan_size;
    elements_per_block = std::min(Index(elements_per_group), Index(max_elements_per_block));
    panel_threads = elements_per_panel / ScanParameters<Index>::ScanPerThread;
    group_threads = elements_per_group / ScanParameters<Index>::ScanPerThread;
    block_threads = elements_per_block / ScanParameters<Index>::ScanPerThread;
    block_size = elements_per_group / elements_per_block;
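    // Illustrative example (a rough sketch of the sizes, assuming a 256-wide
    // work-group and ScanPerThread == 8): scan_size = 10000 gives
    // elements_per_group = 16384 and max_elements_per_block = 2048, hence
    // elements_per_block = 2048 and block_size = 8, so the second scan pass
    // and the adjustment kernel will be launched.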
#ifdef EIGEN_SYCL_MAX_GLOBAL_RANGE
    const Index max_threads = std::min(Index(panel_threads * panel_size), Index(EIGEN_SYCL_MAX_GLOBAL_RANGE));
#else
    const Index max_threads = panel_threads * panel_size;
#endif
    global_range = roundUp(max_threads, local_range);
    loop_range = Index(
        std::ceil(double(elements_per_panel * panel_size) / (global_range * ScanParameters<Index>::ScanPerThread)));
  }
  inline ScanParameters<Index> get_scan_parameter() {
    return ScanParameters<Index>(total_size, non_scan_size, scan_size, non_scan_stride, scan_stride, panel_threads,
                                 group_threads, block_threads, elements_per_group, elements_per_block, loop_range);
  }
  inline cl::sycl::nd_range<1> get_thread_range() {
    return cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
  }
};

template <typename EvaluatorPointerType, typename CoeffReturnType, typename Reducer, typename Index>
struct SYCLAdjustBlockOffset {
  EIGEN_STRONG_INLINE static void adjust_scan_block_offset(EvaluatorPointerType in_ptr, EvaluatorPointerType out_ptr,
                                                           Reducer &accumulator, const Index total_size,
                                                           const Index scan_size, const Index panel_size,
                                                           const Index non_scan_size, const Index scan_stride,
                                                           const Index non_scan_stride, const Eigen::SyclDevice &dev) {
    auto scan_info =
        ScanInfo<Index>(total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride, dev);

    typedef ScanAdjustmentKernelFunctor<CoeffReturnType, EvaluatorPointerType, EvaluatorPointerType, Reducer, Index>
        AdjustFunctor;
    dev.template unary_kernel_launcher<CoeffReturnType, AdjustFunctor>(in_ptr, out_ptr, scan_info.get_thread_range(),
                                                                       scan_info.max_elements_per_block,
                                                                       scan_info.get_scan_parameter(), accumulator);
  }
};

template <typename CoeffReturnType, scan_step stp>
struct ScanLauncher_impl {
  template <typename Input, typename EvaluatorPointerType, typename Reducer, typename Index>
  EIGEN_STRONG_INLINE static void scan_block(Input in_ptr, EvaluatorPointerType out_ptr, Reducer &accumulator,
                                             const Index total_size, const Index scan_size, const Index panel_size,
                                             const Index non_scan_size, const Index scan_stride,
                                             const Index non_scan_stride, const bool inclusive,
                                             const Eigen::SyclDevice &dev) {
    auto scan_info =
        ScanInfo<Index>(total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride, dev);
    const Index temp_pointer_size = scan_info.block_size * non_scan_size * panel_size;
    const Index scratch_size = scan_info.max_elements_per_block / (ScanParameters<Index>::ScanPerThread / 2);
    CoeffReturnType *temp_pointer =
        static_cast<CoeffReturnType *>(dev.allocate_temp(temp_pointer_size * sizeof(CoeffReturnType)));
    EvaluatorPointerType tmp_global_accessor = dev.get(temp_pointer);

    typedef ScanKernelFunctor<Input, CoeffReturnType, EvaluatorPointerType, Reducer, Index, stp> ScanFunctor;
    dev.template binary_kernel_launcher<CoeffReturnType, ScanFunctor>(
        in_ptr, out_ptr, tmp_global_accessor, scan_info.get_thread_range(), scratch_size,
        scan_info.get_scan_parameter(), accumulator, inclusive);

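    // If the scan spans more than one block, scan the per-block totals held
    // in the temporary buffer (a second pass of this same launcher), then add
    // them back onto the partial results with the adjustment kernel.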
    if (scan_info.block_size > 1) {
      ScanLauncher_impl<CoeffReturnType, scan_step::second>::scan_block(
          tmp_global_accessor, tmp_global_accessor, accumulator, temp_pointer_size, scan_info.block_size, panel_size,
          non_scan_size, Index(1), scan_info.block_size, false, dev);

      SYCLAdjustBlockOffset<EvaluatorPointerType, CoeffReturnType, Reducer, Index>::adjust_scan_block_offset(
          tmp_global_accessor, out_ptr, accumulator, total_size, scan_size, panel_size, non_scan_size, scan_stride,
          non_scan_stride, dev);
    }
    dev.deallocate_temp(temp_pointer);
  }
};

}  // namespace internal
}  // namespace TensorSycl
namespace internal {
template <typename Self, typename Reducer, bool vectorize>
struct ScanLauncher<Self, Reducer, Eigen::SyclDevice, vectorize> {
  typedef typename Self::Index Index;
  typedef typename Self::CoeffReturnType CoeffReturnType;
  typedef typename Self::Storage Storage;
  typedef typename Self::EvaluatorPointerType EvaluatorPointerType;
  void operator()(Self &self, EvaluatorPointerType data) {
    const Index total_size = internal::array_prod(self.dimensions());
    const Index scan_size = self.size();
    const Index scan_stride = self.stride();
    // this is the scan op (can be sum or ...)
    auto accumulator = self.accumulator();
    auto inclusive = !self.exclusive();
    auto consume_dim = self.consume_dim();
    auto dev = self.device();

    auto dims = self.inner().dimensions();

    Index non_scan_size = 1;
    Index panel_size = 1;
    if (static_cast<int>(Self::Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < consume_dim; i++) {
        non_scan_size *= dims[i];
      }
      for (int i = consume_dim + 1; i < Self::NumDims; i++) {
        panel_size *= dims[i];
      }
    } else {
      for (int i = Self::NumDims - 1; i > consume_dim; i--) {
        non_scan_size *= dims[i];
      }
      for (int i = consume_dim - 1; i >= 0; i--) {
        panel_size *= dims[i];
      }
    }
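    // For example (illustrative only): a ColMajor tensor with dimensions
    // (d0, d1, d2) scanned along dimension 1 yields non_scan_size = d0 and
    // panel_size = d2; for RowMajor the roles of the inner and outer
    // dimensions are swapped.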
    const Index non_scan_stride = (scan_stride > 1) ? 1 : scan_size;
    auto eval_impl = self.inner();
    TensorSycl::internal::ScanLauncher_impl<CoeffReturnType, TensorSycl::internal::scan_step::first>::scan_block(
        eval_impl, data, accumulator, total_size, scan_size, panel_size, non_scan_size, scan_stride, non_scan_stride,
        inclusive, dev);
  }
};
}  // namespace internal
}  // namespace Eigen

#endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_SYCL_SYCL_HPP