Please help us get to know our user community better by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
Eigen  3.4.0
 
Loading...
Searching...
No Matches
SyclMemoryModel.h
1/***************************************************************************
2 * Copyright (C) 2017 Codeplay Software Limited
3 * This Source Code Form is subject to the terms of the Mozilla
4 * Public License v. 2.0. If a copy of the MPL was not distributed
5 * with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 *
8 * SyclMemoryModel.h
9 *
10 * Description:
11 * Interface for SYCL buffers to behave as a non-dereferenceable pointer
12 * Interface for Placeholder accessor to behave as a pointer on both host
13 * and device
14 *
15 * Authors:
16 *
17 * Ruyman Reyes Codeplay Software Ltd.
18 * Mehdi Goli Codeplay Software Ltd.
19 * Vanya Yaneva Codeplay Software Ltd.
20 *
21 **************************************************************************/
22
23#if defined(EIGEN_USE_SYCL) && \
24 !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
25#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
26
27#include <CL/sycl.hpp>
28#ifdef EIGEN_EXCEPTIONS
29#include <stdexcept>
30#endif
31#include <cstddef>
32#include <queue>
33#include <set>
34#include <unordered_map>
35
36namespace Eigen {
37namespace TensorSycl {
38namespace internal {
39
40using sycl_acc_target = cl::sycl::access::target;
41using sycl_acc_mode = cl::sycl::access::mode;
42
46using buffer_data_type_t = uint8_t;
47const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
48const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
49
55class PointerMapper {
56 public:
57 using base_ptr_t = std::intptr_t;
58
59 /* Structure of a virtual pointer
60 *
61 * |================================================|
62 * | POINTER ADDRESS |
63 * |================================================|
64 */
65 struct virtual_pointer_t {
66 /* Type for the pointers
67 */
68 base_ptr_t m_contents;
69
73 operator void *() const { return reinterpret_cast<void *>(m_contents); }
74
78 operator base_ptr_t() const { return m_contents; }
79
84 virtual_pointer_t operator+(size_t off) { return m_contents + off; }
85
86 /* Numerical order for sorting pointers in containers. */
87 bool operator<(virtual_pointer_t rhs) const {
88 return (static_cast<base_ptr_t>(m_contents) <
89 static_cast<base_ptr_t>(rhs.m_contents));
90 }
91
92 bool operator>(virtual_pointer_t rhs) const {
93 return (static_cast<base_ptr_t>(m_contents) >
94 static_cast<base_ptr_t>(rhs.m_contents));
95 }
96
100 bool operator==(virtual_pointer_t rhs) const {
101 return (static_cast<base_ptr_t>(m_contents) ==
102 static_cast<base_ptr_t>(rhs.m_contents));
103 }
104
108 bool operator!=(virtual_pointer_t rhs) const {
109 return !(this->operator==(rhs));
110 }
111
118 virtual_pointer_t(const void *ptr)
119 : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
120
125 virtual_pointer_t(base_ptr_t u) : m_contents(u){};
126 };
127
128 /* Definition of a null pointer
129 */
130 const virtual_pointer_t null_virtual_ptr = nullptr;
131
136 static inline bool is_nullptr(virtual_pointer_t ptr) {
137 return (static_cast<void *>(ptr) == nullptr);
138 }
139
140 /* basic type for all buffers
141 */
142 using buffer_t = cl::sycl::buffer_mem;
143
149 struct pMapNode_t {
150 buffer_t m_buffer;
151 size_t m_size;
152 bool m_free;
153
154 pMapNode_t(buffer_t b, size_t size, bool f)
155 : m_buffer{b}, m_size{size}, m_free{f} {
156 m_buffer.set_final_data(nullptr);
157 }
158
159 bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
160 };
161
164 using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
165
171 typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
172 typename pointerMap_t::iterator retVal;
173 bool reuse = false;
174 if (!m_freeList.empty()) {
175 // try to re-use an existing block
176 for (auto freeElem : m_freeList) {
177 if (freeElem->second.m_size >= requiredSize) {
178 retVal = freeElem;
179 reuse = true;
180 // Element is not going to be free anymore
181 m_freeList.erase(freeElem);
182 break;
183 }
184 }
185 }
186 if (!reuse) {
187 retVal = std::prev(m_pointerMap.end());
188 }
189 return retVal;
190 }
191
202 typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
203 if (this->count() == 0) {
204 m_pointerMap.clear();
205 EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
206
207 }
208 if (is_nullptr(ptr)) {
209 m_pointerMap.clear();
210 EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
211 }
212 // The previous element to the lower bound is the node that
213 // holds this memory address
214 auto node = m_pointerMap.lower_bound(ptr);
215 // If the value of the pointer is not the one of the node
216 // then we return the previous one
217 if (node == std::end(m_pointerMap)) {
218 --node;
219 } else if (node->first != ptr) {
220 if (node == std::begin(m_pointerMap)) {
221 m_pointerMap.clear();
222 EIGEN_THROW_X(
223 std::out_of_range("The pointer is not registered in the map\n"));
224
225 }
226 --node;
227 }
228
229 return node;
230 }
231
232 /* get_buffer.
233 * Returns a buffer from the map using the pointer address
234 */
235 template <typename buffer_data_type = buffer_data_type_t>
236 cl::sycl::buffer<buffer_data_type, 1> get_buffer(
237 const virtual_pointer_t ptr) {
238 using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
239
240 // get_node() returns a `buffer_mem`, so we need to cast it to a `buffer<>`.
241 // We can do this without the `buffer_mem` being a pointer, as we
242 // only declare member variables in the base class (`buffer_mem`) and not in
243 // the child class (`buffer<>).
244 auto node = get_node(ptr);
245 eigen_assert(node->first == ptr || node->first < ptr);
246 eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
247 node->first));
248 return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
249 }
250
257 template <sycl_acc_mode access_mode = default_acc_mode,
258 sycl_acc_target access_target = default_acc_target,
259 typename buffer_data_type = buffer_data_type_t>
260 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
261 get_access(const virtual_pointer_t ptr) {
262 auto buf = get_buffer<buffer_data_type>(ptr);
263 return buf.template get_access<access_mode, access_target>();
264 }
265
274 template <sycl_acc_mode access_mode = default_acc_mode,
275 sycl_acc_target access_target = default_acc_target,
276 typename buffer_data_type = buffer_data_type_t>
277 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
278 get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
279 auto buf = get_buffer<buffer_data_type>(ptr);
280 return buf.template get_access<access_mode, access_target>(cgh);
281 }
282
283 /*
284 * Returns the offset from the base address of this pointer.
285 */
286 inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
287 // The previous element to the lower bound is the node that
288 // holds this memory address
289 auto node = get_node(ptr);
290 auto start = node->first;
291 eigen_assert(start == ptr || start < ptr);
292 eigen_assert(ptr < start + node->second.m_size);
293 return (ptr - start);
294 }
295
296 /*
297 * Returns the number of elements by which the given pointer is offset from
298 * the base address.
299 */
300 template <typename buffer_data_type>
301 inline size_t get_element_offset(const virtual_pointer_t ptr) {
302 return get_offset(ptr) / sizeof(buffer_data_type);
303 }
304
308 PointerMapper(base_ptr_t baseAddress = 4096)
309 : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
310 if (m_baseAddress == 0) {
311 EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
312 }
313 };
314
318 PointerMapper(const PointerMapper &) = delete;
319
323 inline void clear() {
324 m_freeList.clear();
325 m_pointerMap.clear();
326 }
327
328 /* add_pointer.
329 * Adds an existing pointer to the map and returns the virtual pointer id.
330 */
331 inline virtual_pointer_t add_pointer(const buffer_t &b) {
332 return add_pointer_impl(b);
333 }
334
335 /* add_pointer.
336 * Adds a pointer to the map and returns the virtual pointer id.
337 */
338 inline virtual_pointer_t add_pointer(buffer_t &&b) {
339 return add_pointer_impl(b);
340 }
341
348 void fuse_forward(typename pointerMap_t::iterator &node) {
349 while (node != std::prev(m_pointerMap.end())) {
350 // if following node is free
351 // remove it and extend the current node with its size
352 auto fwd_node = std::next(node);
353 if (!fwd_node->second.m_free) {
354 break;
355 }
356 auto fwd_size = fwd_node->second.m_size;
357 m_freeList.erase(fwd_node);
358 m_pointerMap.erase(fwd_node);
359
360 node->second.m_size += fwd_size;
361 }
362 }
363
370 void fuse_backward(typename pointerMap_t::iterator &node) {
371 while (node != m_pointerMap.begin()) {
372 // if previous node is free, extend it
373 // with the size of the current one
374 auto prev_node = std::prev(node);
375 if (!prev_node->second.m_free) {
376 break;
377 }
378 prev_node->second.m_size += node->second.m_size;
379
380 // remove the current node
381 m_freeList.erase(node);
382 m_pointerMap.erase(node);
383
384 // point to the previous node
385 node = prev_node;
386 }
387 }
388
389 /* remove_pointer.
390 * Removes the given pointer from the map.
391 * The pointer is allowed to be reused only if ReUse if true.
392 */
393 template <bool ReUse = true>
394 void remove_pointer(const virtual_pointer_t ptr) {
395 if (is_nullptr(ptr)) {
396 return;
397 }
398 auto node = this->get_node(ptr);
399
400 node->second.m_free = true;
401 m_freeList.emplace(node);
402
403 // Fuse the node
404 // with free nodes before and after it
405 fuse_forward(node);
406 fuse_backward(node);
407
408 // If after fusing the node is the last one
409 // simply remove it (since it is free)
410 if (node == std::prev(m_pointerMap.end())) {
411 m_freeList.erase(node);
412 m_pointerMap.erase(node);
413 }
414 }
415
416 /* count.
417 * Return the number of active pointers (i.e, pointers that
418 * have been malloc but not freed).
419 */
420 size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }
421
422 private:
423 /* add_pointer_impl.
424 * Adds a pointer to the map and returns the virtual pointer id.
425 * BufferT is either a const buffer_t& or a buffer_t&&.
426 */
427 template <class BufferT>
428 virtual_pointer_t add_pointer_impl(BufferT b) {
429 virtual_pointer_t retVal = nullptr;
430 size_t bufSize = b.get_count();
431 pMapNode_t p{b, bufSize, false};
432 // If this is the first pointer:
433 if (m_pointerMap.empty()) {
434 virtual_pointer_t initialVal{m_baseAddress};
435 m_pointerMap.emplace(initialVal, p);
436 return initialVal;
437 }
438
439 auto lastElemIter = get_insertion_point(bufSize);
440 // We are recovering an existing free node
441 if (lastElemIter->second.m_free) {
442 lastElemIter->second.m_buffer = b;
443 lastElemIter->second.m_free = false;
444
445 // If the recovered node is bigger than the inserted one
446 // add a new free node with the remaining space
447 if (lastElemIter->second.m_size > bufSize) {
448 // create a new node with the remaining space
449 auto remainingSize = lastElemIter->second.m_size - bufSize;
450 pMapNode_t p2{b, remainingSize, true};
451
452 // update size of the current node
453 lastElemIter->second.m_size = bufSize;
454
455 // add the new free node
456 auto newFreePtr = lastElemIter->first + bufSize;
457 auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
458 m_freeList.emplace(freeNode);
459 }
460
461 retVal = lastElemIter->first;
462 } else {
463 size_t lastSize = lastElemIter->second.m_size;
464 retVal = lastElemIter->first + lastSize;
465 m_pointerMap.emplace(retVal, p);
466 }
467 return retVal;
468 }
469
474 struct SortBySize {
475 bool operator()(typename pointerMap_t::iterator a,
476 typename pointerMap_t::iterator b) const {
477 return ((a->first < b->first) && (a->second <= b->second)) ||
478 ((a->first < b->first) && (b->second <= a->second));
479 }
480 };
481
482 /* Maps the pointer addresses to buffer and size pairs.
483 */
484 pointerMap_t m_pointerMap;
485
486 /* List of free nodes available for re-using
487 */
488 std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
489
490 /* Base address used when issuing the first virtual pointer, allows users
491 * to specify alignment. Cannot be zero. */
492 std::intptr_t m_baseAddress;
493};
494
495/* remove_pointer.
496 * Removes the given pointer from the map.
497 * The pointer is allowed to be reused only if ReUse if true.
498 */
499template <>
500inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
501 if (is_nullptr(ptr)) {
502 return;
503 }
504 m_pointerMap.erase(this->get_node(ptr));
505}
506
514inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
515 if (size == 0) {
516 return nullptr;
517 }
518 // Create a generic buffer of the given size
519 using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
520 auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
521 // Store the buffer on the global list
522 return static_cast<void *>(thePointer);
523}
524
/* SYCLfree.
 * Releases the virtual pointer `ptr` by forwarding it to
 * `pMap.remove_pointer<ReUse>()`. ReUse is passed through unchanged as the
 * template argument of remove_pointer.
 */
template <bool ReUse = true, typename MapperT>
inline void SYCLfree(void *ptr, MapperT &pMap) {
  pMap.template remove_pointer<ReUse>(ptr);
}
536
/* SYCLfreeAll.
 * Releases everything held by the mapper by calling its clear() member.
 */
template <typename MapperT>
inline void SYCLfreeAll(MapperT &pMap) {
  pMap.clear();
}
544
545template <cl::sycl::access::mode AcMd, typename T>
546struct RangeAccess {
547 static const auto global_access = cl::sycl::access::target::global_buffer;
548 static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
549 typedef T scalar_t;
550 typedef scalar_t &ref_t;
551 typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
552
553 // the accessor type does not necessarily the same as T
554 typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
555 accessor;
556
557 typedef RangeAccess<AcMd, T> self_t;
558 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
559 size_t offset,
560 std::intptr_t virtual_ptr)
561 : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
562
563 RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
564 cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
565 : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
566
567 // This should be only used for null constructor on the host side
568 RangeAccess(std::nullptr_t) : RangeAccess() {}
569 // This template parameter must be removed and scalar_t should be replaced
570 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const {
571 return (access_.get_pointer().get() + offset_);
572 }
573 template <typename Index>
574 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
575 offset_ += (offset);
576 return *this;
577 }
578 template <typename Index>
579 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const {
580 return self_t(access_, offset_ + offset, virtual_ptr_);
581 }
582 template <typename Index>
583 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const {
584 return self_t(access_, offset_ - offset, virtual_ptr_);
585 }
586 template <typename Index>
587 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
588 offset_ -= offset;
589 return *this;
590 }
591
592 // THIS IS FOR NULL COMPARISON ONLY
593 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
594 const RangeAccess &lhs, std::nullptr_t) {
595 return ((lhs.virtual_ptr_ == -1));
596 }
597 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
598 const RangeAccess &lhs, std::nullptr_t i) {
599 return !(lhs == i);
600 }
601
602 // THIS IS FOR NULL COMPARISON ONLY
603 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
604 std::nullptr_t, const RangeAccess &rhs) {
605 return ((rhs.virtual_ptr_ == -1));
606 }
607 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
608 std::nullptr_t i, const RangeAccess &rhs) {
609 return !(i == rhs);
610 }
611 // Prefix operator (Increment and return value)
612 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
613 offset_++;
614 return (*this);
615 }
616
617 // Postfix operator (Return value and increment)
618 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) {
619 EIGEN_UNUSED_VARIABLE(i);
620 self_t temp_iterator(*this);
621 offset_++;
622 return temp_iterator;
623 }
624
625 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const {
626 return (access_.get_count() - offset_);
627 }
628
629 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const {
630 return offset_;
631 }
632
633 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) {
634 offset_ = offset;
635 }
636
637 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const {
638 return *get_pointer();
639 }
640
641 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
642 return *get_pointer();
643 }
644
645 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete;
646
647 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) {
648 return *(get_pointer() + x);
649 }
650
651 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const {
652 return *(get_pointer() + x);
653 }
654
655 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const {
656 return reinterpret_cast<scalar_t *>(virtual_ptr_ +
657 (offset_ * sizeof(scalar_t)));
658 }
659
660 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const {
661 return (virtual_ptr_ != -1);
662 }
663
664 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess<AcMd, const T>() {
665 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
666 }
667
668 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
669 operator RangeAccess<AcMd, const T>() const {
670 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
671 }
672 // binding placeholder accessors to a command group handler for SYCL
673 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
674 cl::sycl::handler &cgh) const {
675 cgh.require(access_);
676 }
677
678 private:
679 accessor access_;
680 size_t offset_;
681 std::intptr_t virtual_ptr_; // the location of the buffer in the map
682};
683
684template <cl::sycl::access::mode AcMd, typename T>
685struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
686 typedef RangeAccess<AcMd, T> Base;
687 using Base::Base;
688};
689
690} // namespace internal
691} // namespace TensorSycl
692} // namespace Eigen
693
694#endif // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
Namespace containing all symbols from the Eigen library.
Definition: Core:141