23#if defined(EIGEN_USE_SYCL) && \
24 !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
25#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
28#ifdef EIGEN_EXCEPTIONS
34#include <unordered_map>
// Short aliases for the SYCL accessor target and access-mode enumerations.
40using sycl_acc_target = cl::sycl::access::target;
41using sycl_acc_mode = cl::sycl::access::mode;
// Default element type of the buffers handled in this file (used as the
// default template argument of get_buffer()/get_access() and by SYCLmalloc).
46using buffer_data_type_t = uint8_t;
// Default accessor target and mode used by the get_access() overloads.
47const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
48const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
/// Integer type used as the raw representation of a virtual pointer.
using base_ptr_t = std::intptr_t;

/**
 * Value type wrapping an integer address (`base_ptr_t`).
 *
 * It converts implicitly from and to both `void *` and `base_ptr_t`, so it
 * can be passed through APIs that traffic in raw pointers, and it provides
 * the comparison operators needed to use it as an ordered-container key.
 * Both constructors are intentionally non-explicit: code in this file relies
 * on the implicit conversions (e.g. initialising from `nullptr`).
 */
struct virtual_pointer_t {
  /// The wrapped address value.
  base_ptr_t m_contents;

  /// Reinterprets the stored integer as a `void *`.
  operator void *() const { return reinterpret_cast<void *>(m_contents); }

  /// Returns the stored integer value.
  operator base_ptr_t() const { return m_contents; }

  /**
   * Returns a new virtual pointer advanced by `off`.
   *
   * Left non-const so overload resolution for existing callers is unchanged:
   * on a const object, `p + n` resolves to built-in integer arithmetic
   * through `operator base_ptr_t()` instead of this member.
   */
  virtual_pointer_t operator+(size_t off) { return m_contents + off; }

  /// Numerical ordering on the stored address.
  bool operator<(virtual_pointer_t rhs) const {
    return m_contents < rhs.m_contents;
  }

  /// Numerical ordering on the stored address.
  bool operator>(virtual_pointer_t rhs) const {
    return m_contents > rhs.m_contents;
  }

  /// Two virtual pointers are equal when their stored addresses are equal.
  bool operator==(virtual_pointer_t rhs) const {
    return m_contents == rhs.m_contents;
  }

  /// Negation of operator==.
  bool operator!=(virtual_pointer_t rhs) const {
    return !(this->operator==(rhs));
  }

  /// Builds a virtual pointer from the bit pattern of a raw pointer.
  virtual_pointer_t(const void *ptr)
      : m_contents(reinterpret_cast<base_ptr_t>(ptr)) {}

  /// Builds a virtual pointer from an integer address.
  virtual_pointer_t(base_ptr_t u) : m_contents(u) {}
};
// Null virtual pointer, initialised from nullptr through the implicit
// const void * constructor of virtual_pointer_t.
130 const virtual_pointer_t null_virtual_ptr =
nullptr;
// Returns true when the virtual pointer, converted to void *, compares
// equal to nullptr.
136 static inline bool is_nullptr(virtual_pointer_t ptr) {
137 return (
static_cast<void *
>(ptr) ==
nullptr);
// Buffer type stored in every map node; get_buffer() casts the stored
// object to a typed cl::sycl::buffer before returning it.
142 using buffer_t = cl::sycl::buffer_mem;
// Builds a map node from a buffer, its size and its free flag.
// Callers in this file pass the buffer's get_count() as the size.
154 pMapNode_t(buffer_t b,
size_t size,
bool f)
155 : m_buffer{b}, m_size{size}, m_free{f} {
// NOTE(review): a nullptr final-data destination presumably disables the
// copy-back SYCL performs when the buffer is destroyed — confirm.
156 m_buffer.set_final_data(
nullptr);
159 bool operator<=(
const pMapNode_t &rhs) {
return (m_size <= rhs.m_size); }
// Ordered map from the start address of an allocation to its node.
164 using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
// Selects the map node a new allocation of requiredSize is placed at/after.
// NOTE(review): the embedded line numbering skips lines inside this function,
// so some statements are missing from this listing; the comments below cover
// only what is visible.
171 typename pointerMap_t::iterator get_insertion_point(
size_t requiredSize) {
172 typename pointerMap_t::iterator retVal;
174 if (!m_freeList.empty()) {
// Look through the free list for a node that is large enough.
176 for (
auto freeElem : m_freeList) {
177 if (freeElem->second.m_size >= requiredSize) {
// The chosen node leaves the free list.
181 m_freeList.erase(freeElem);
// The last node of the map is the other candidate result.
187 retVal = std::prev(m_pointerMap.end());
// Looks up the map node associated with a virtual pointer.
// Raises std::out_of_range through EIGEN_THROW_X when no pointer is
// allocated, when ptr is null, or when ptr is not registered; the map is
// cleared before each of those errors is raised.
// NOTE(review): the embedded line numbering skips lines inside this function,
// so some statements (including the return) are missing from this listing.
202 typename pointerMap_t::iterator get_node(
const virtual_pointer_t ptr) {
203 if (this->count() == 0) {
204 m_pointerMap.clear();
205 EIGEN_THROW_X(std::out_of_range(
"There are no pointers allocated\n"));
208 if (is_nullptr(ptr)) {
209 m_pointerMap.clear();
210 EIGEN_THROW_X(std::out_of_range(
"Cannot access null pointer\n"));
// First node whose start address is not less than ptr.
214 auto node = m_pointerMap.lower_bound(ptr);
217 if (node == std::end(m_pointerMap)) {
219 }
else if (node->first != ptr) {
// ptr lies before the first registered address.
220 if (node == std::begin(m_pointerMap)) {
221 m_pointerMap.clear();
223 std::out_of_range(
"The pointer is not registered in the map\n"));
// Returns the buffer of the node that get_node() finds for ptr, viewed as a
// one-dimensional cl::sycl::buffer of buffer_data_type.
235 template <
typename buffer_data_type = buffer_data_type_t>
236 cl::sycl::buffer<buffer_data_type, 1> get_buffer(
237 const virtual_pointer_t ptr) {
238 using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
244 auto node = get_node(ptr);
// The node must start at or before ptr.
245 eigen_assert(node->first == ptr || node->first < ptr);
// NOTE(review): the remainder of this assertion is missing from the listing.
246 eigen_assert(ptr <
static_cast<virtual_pointer_t
>(node->second.m_size +
// The stored buffer object is cast to the typed buffer and returned by value.
248 return *(
static_cast<sycl_buffer_t *
>(&node->second.m_buffer));
// Returns an accessor, created without a command-group handler, on the
// buffer that holds ptr.
// Template parameters select access mode, access target and element type.
257 template <sycl_acc_mode access_mode = default_acc_mode,
258 sycl_acc_target access_target = default_acc_target,
259 typename buffer_data_type = buffer_data_type_t>
260 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
261 get_access(
const virtual_pointer_t ptr) {
262 auto buf = get_buffer<buffer_data_type>(ptr);
263 return buf.template get_access<access_mode, access_target>();
// Returns an accessor on the buffer that holds ptr, created with the given
// command-group handler.
// Template parameters select access mode, access target and element type.
274 template <sycl_acc_mode access_mode = default_acc_mode,
275 sycl_acc_target access_target = default_acc_target,
276 typename buffer_data_type = buffer_data_type_t>
277 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
278 get_access(
const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
279 auto buf = get_buffer<buffer_data_type>(ptr);
280 return buf.template get_access<access_mode, access_target>(cgh);
// Returns the distance between ptr and the start address of the node that
// get_node() finds for it.
286 inline std::ptrdiff_t get_offset(
const virtual_pointer_t ptr) {
289 auto node = get_node(ptr);
290 auto start = node->first;
// ptr must fall inside [start, start + size) of the node.
291 eigen_assert(start == ptr || start < ptr);
292 eigen_assert(ptr < start + node->second.m_size);
// Neither operand defines operator-, so the subtraction is done on the
// converted integer values.
293 return (ptr - start);
// Returns get_offset(ptr) divided by sizeof(buffer_data_type), i.e. the
// offset expressed in elements of buffer_data_type (integer division).
300 template <
typename buffer_data_type>
301 inline size_t get_element_offset(
const virtual_pointer_t ptr) {
302 return get_offset(ptr) /
sizeof(buffer_data_type);
// Creates an empty mapper. baseAddress (default 4096) is the virtual address
// given to the first allocation; a zero base address is rejected with
// std::invalid_argument through EIGEN_THROW_X.
308 PointerMapper(base_ptr_t baseAddress = 4096)
309 : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
310 if (m_baseAddress == 0) {
311 EIGEN_THROW_X(std::invalid_argument(
"Base address cannot be zero\n"));
// Copy construction is disabled.
318 PointerMapper(
const PointerMapper &) =
delete;
// Empties the pointer map.
// NOTE(review): the embedded numbering skips a line inside this body, so a
// statement may be missing from this listing.
323 inline void clear() {
325 m_pointerMap.clear();
// Registers a buffer and returns the virtual pointer assigned to it.
// Forwards to add_pointer_impl().
331 inline virtual_pointer_t add_pointer(
const buffer_t &b) {
332 return add_pointer_impl(b);
// Rvalue overload; forwards to add_pointer_impl(). b is a named parameter
// here, so it is passed on as an lvalue (copied, not moved).
338 inline virtual_pointer_t add_pointer(buffer_t &&b) {
339 return add_pointer_impl(b);
// Merges free nodes that follow `node` into `node`.
// NOTE(review): the embedded numbering skips lines inside the loop, so some
// statements are missing from this listing.
348 void fuse_forward(
typename pointerMap_t::iterator &node) {
// Loop while node is not the last element of the map.
349 while (node != std::prev(m_pointerMap.end())) {
352 auto fwd_node = std::next(node);
// Branch taken when the following node is still in use.
353 if (!fwd_node->second.m_free) {
// Remove the following node from both containers and absorb its size.
356 auto fwd_size = fwd_node->second.m_size;
357 m_freeList.erase(fwd_node);
358 m_pointerMap.erase(fwd_node);
360 node->second.m_size += fwd_size;
// Merges `node` into free nodes that precede it.
// NOTE(review): the embedded numbering skips lines inside the loop, so some
// statements are missing from this listing.
370 void fuse_backward(
typename pointerMap_t::iterator &node) {
// Loop while node is not the first element of the map.
371 while (node != m_pointerMap.begin()) {
374 auto prev_node = std::prev(node);
// Branch taken when the preceding node is still in use.
375 if (!prev_node->second.m_free) {
// The preceding node absorbs this node's size; this node is then removed
// from both containers.
378 prev_node->second.m_size += node->second.m_size;
381 m_freeList.erase(node);
382 m_pointerMap.erase(node);
// Releases the allocation that ptr belongs to. With ReUse == true (the
// default) the node is flagged free and added to the free list.
// NOTE(review): the embedded numbering skips lines inside this function, so
// some statements are missing from this listing.
393 template <
bool ReUse = true>
394 void remove_pointer(
const virtual_pointer_t ptr) {
// Null pointers get dedicated handling (body not visible in this listing).
395 if (is_nullptr(ptr)) {
398 auto node = this->get_node(ptr);
400 node->second.m_free =
true;
401 m_freeList.emplace(node);
// When the node is the last element of the map it is removed from both the
// free list and the map.
410 if (node == std::prev(m_pointerMap.end())) {
411 m_freeList.erase(node);
412 m_pointerMap.erase(node);
// Number of map nodes minus the number of free-list entries, i.e. the
// number of nodes that are not on the free list.
420 size_t count()
const {
return (m_pointerMap.size() - m_freeList.size()); }
// Registers buffer b and returns the virtual pointer assigned to it.
// NOTE(review): the embedded numbering skips lines inside this function
// (closing braces, else branches, the return), so the exact control flow
// cannot be read from this listing; comments cover the visible statements.
427 template <
class BufferT>
428 virtual_pointer_t add_pointer_impl(BufferT b) {
429 virtual_pointer_t retVal =
nullptr;
// The node size is the buffer's get_count().
430 size_t bufSize = b.get_count();
431 pMapNode_t p{b, bufSize,
false};
// Empty map: the allocation is stored at the base address.
433 if (m_pointerMap.empty()) {
434 virtual_pointer_t initialVal{m_baseAddress};
435 m_pointerMap.emplace(initialVal, p);
439 auto lastElemIter = get_insertion_point(bufSize);
// A free node was selected: reuse it for this buffer.
441 if (lastElemIter->second.m_free) {
442 lastElemIter->second.m_buffer = b;
443 lastElemIter->second.m_free =
false;
// The reused node is larger than needed: split off the remainder as a new
// free node that starts right after the used part.
447 if (lastElemIter->second.m_size > bufSize) {
449 auto remainingSize = lastElemIter->second.m_size - bufSize;
450 pMapNode_t p2{b, remainingSize,
true};
453 lastElemIter->second.m_size = bufSize;
// lastElemIter->first is a const map key and operator+ is non-const, so this
// sum is computed on the converted integer value.
456 auto newFreePtr = lastElemIter->first + bufSize;
457 auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
458 m_freeList.emplace(freeNode);
461 retVal = lastElemIter->first;
// Otherwise the new node is placed directly after the selected node.
463 size_t lastSize = lastElemIter->second.m_size;
464 retVal = lastElemIter->first + lastSize;
465 m_pointerMap.emplace(retVal, p);
475 bool operator()(
typename pointerMap_t::iterator a,
476 typename pointerMap_t::iterator b)
const {
477 return ((a->first < b->first) && (a->second <= b->second)) ||
478 ((a->first < b->first) && (b->second <= a->second));
// All registered nodes (in use and free), keyed by start address.
484 pointerMap_t m_pointerMap;
// Iterators to the nodes of m_pointerMap that are currently free, ordered
// by the SortBySize predicate.
488 std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
// Address given to the first allocation; the constructor rejects zero.
492 std::intptr_t m_baseAddress;
// ReUse == false variant: the node found for ptr is erased from the map
// instead of being kept for reuse.
// NOTE(review): the embedded numbering skips lines inside this function, so
// some statements are missing from this listing.
500inline void PointerMapper::remove_pointer<false>(
const virtual_pointer_t ptr) {
// Null pointers get dedicated handling (body not visible in this listing).
501 if (is_nullptr(ptr)) {
504 m_pointerMap.erase(this->get_node(ptr));
// Creates a one-dimensional buffer of `size` elements of buffer_data_type_t,
// registers it with pMap and returns the resulting virtual pointer as void *.
// NOTE(review): the embedded numbering skips lines after the signature, so
// some statements may be missing from this listing.
514inline void *SYCLmalloc(
size_t size, PointerMapper &pMap) {
519 using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
520 auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
522 return static_cast<void *
>(thePointer);
// Releases ptr in pMap by calling pMap.remove_pointer<ReUse>(ptr).
// ReUse defaults to true.
532template <
bool ReUse = true,
typename Po
interMapper>
533inline void SYCLfree(
void *ptr, PointerMapper &pMap) {
534 pMap.template remove_pointer<ReUse>(ptr);
// NOTE(review): only the signature is visible in this listing; judging by
// the name it releases everything held by pMap — confirm against the body.
540template <
typename Po
interMapper>
541inline void SYCLfreeAll(PointerMapper &pMap) {
// Pointer-like wrapper around a placeholder accessor, an element offset and
// a virtual pointer value. A virtual pointer value of -1 marks the null
// state.
// NOTE(review): the embedded numbering skips many lines inside this class
// (class header, some typedefs, several operator bodies, closing braces), so
// the comments below cover only what is visible in this listing.
545template <cl::sycl::access::mode AcMd,
typename T>
547 static const auto global_access = cl::sycl::access::target::global_buffer;
548 static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
550 typedef scalar_t &ref_t;
551 typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
// One-dimensional placeholder accessor on global memory (typedef name is on
// a line missing from this listing).
554 typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
557 typedef RangeAccess<AcMd, T> self_t;
// Builds a range from an accessor, an offset and a virtual pointer value.
558 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
560 std::intptr_t virtual_ptr)
561 : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
// Default state: accessor on a one-element buffer, offset 0, virtual
// pointer -1 (null).
563 RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
564 cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
565 : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
// Construction from nullptr yields the default (null) state.
568 RangeAccess(std::nullptr_t) : RangeAccess() {}
// Accessor pointer advanced by the current offset.
570 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer()
const {
571 return (access_.get_pointer().get() + offset_);
573 template <
typename Index>
574 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
// Returns a copy whose offset is increased by `offset`.
578 template <
typename Index>
579 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset)
const {
580 return self_t(access_, offset_ + offset, virtual_ptr_);
// Returns a copy whose offset is decreased by `offset`.
582 template <
typename Index>
583 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset)
const {
584 return self_t(access_, offset_ - offset, virtual_ptr_);
586 template <
typename Index>
587 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
// Comparisons against nullptr test for the -1 virtual pointer value.
593 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
friend bool operator==(
594 const RangeAccess &lhs, std::nullptr_t) {
595 return ((lhs.virtual_ptr_ == -1));
597 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
friend bool operator!=(
598 const RangeAccess &lhs, std::nullptr_t i) {
603 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
friend bool operator==(
604 std::nullptr_t,
const RangeAccess &rhs) {
605 return ((rhs.virtual_ptr_ == -1));
607 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
friend bool operator!=(
608 std::nullptr_t i,
const RangeAccess &rhs) {
612 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
// Post-increment: returns a copy taken before the update (the update
// statement itself is on a line missing from this listing).
618 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(
int i) {
619 EIGEN_UNUSED_VARIABLE(i);
620 self_t temp_iterator(*
this);
622 return temp_iterator;
// Accessor element count minus the current offset.
625 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size()
const {
626 return (access_.get_count() - offset_);
629 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset()
const {
633 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void set_offset(std::ptrdiff_t offset) {
// Dereference: the element at the current offset.
637 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*()
const {
638 return *get_pointer();
641 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
642 return *get_pointer();
// Member access through -> is explicitly disabled.
645 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() =
delete;
// Indexing relative to the current offset.
647 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](
int x) {
648 return *(get_pointer() + x);
651 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](
int x)
const {
652 return *(get_pointer() + x);
// Virtual pointer value advanced by offset_ * sizeof(scalar_t) bytes,
// reinterpreted as scalar_t *.
655 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer()
const {
656 return reinterpret_cast<scalar_t *
>(virtual_ptr_ +
657 (offset_ *
sizeof(scalar_t)));
// True when the range is not in the null state.
660 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
explicit operator bool()
const {
661 return (virtual_ptr_ != -1);
// Conversions to the const-element variant keep accessor, offset and
// virtual pointer.
664 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
operator RangeAccess<AcMd, const T>() {
665 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
668 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
669 operator RangeAccess<AcMd, const T>()
const {
670 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
// Registers the placeholder accessor with the command-group handler.
673 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void bind(
674 cl::sycl::handler &cgh)
const {
675 cgh.require(access_);
// Virtual pointer value this range was created from; -1 means null.
681 std::intptr_t virtual_ptr_;
684template <cl::sycl::access::mode AcMd,
typename T>
685struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
686 typedef RangeAccess<AcMd, T> Base;
// Namespace containing all symbols from the Eigen library.
// Definition: Core:141