TensorContractionMapper.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H

namespace Eigen {

namespace internal {

enum {
  Rhs = 0,
  Lhs = 1
};

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */
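// The mappers below are layered: CoeffLoader abstracts coefficient access
// (with a specialization for tensors that expose a raw data pointer),
// SimpleTensorContractionMapper translates matrix-style (row, col)
// coordinates into linear tensor indices, BaseTensorContractionMapper adds
// packet loads on top of that, and TensorContractionSubMapper /
// TensorContractionInputMapper expose the block and sub-block views that
// the contraction kernels consume.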
template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
struct CoeffLoader;

template <typename Scalar, typename Index, int side, typename Tensor,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class BaseTensorContractionMapper;

template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
struct CoeffLoader {
  enum {
    DirectOffsets = false
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
    eigen_assert(false && "unsupported");
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
  data() const {
    eigen_assert(false && "unsupported");
    return NULL;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return m_tensor.template packet<LoadMode>(index);
  }

  #ifdef EIGEN_USE_SYCL
  // The placeholder accessors must be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_tensor.bind(cgh);
  }
  #endif

 private:
  const Tensor m_tensor;
};

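// Specialization for tensors with direct (raw) data access: coefficients are
// read straight from the underlying buffer, which also makes it possible to
// bake offsets into the pointer via offsetBuffer().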
template <typename Tensor, template <class> class MakePointer_>
struct CoeffLoader<Tensor, true, MakePointer_> {
  enum {
    DirectOffsets = true
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_data += offset;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
  data() const {
    return m_data;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
  }

  #ifdef EIGEN_USE_SYCL
  // The placeholder accessors must be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_data.bind(cgh);
  }
  #endif

 private:
  typedef typename Tensor::Scalar Scalar;

  typename MakePointer_<const Scalar>::Type m_data;
};

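// Note: loadConstant() and ploadt_ro() used above route reads through
// read-only load paths where the target architecture provides them (for
// example __ldg on CUDA devices), which is safe here because the mapper
// only ever reads the input tensors.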
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous, int Alignment, template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper {
 public:
  EIGEN_DEVICE_FUNC
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  enum {
    DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    EIGEN_UNROLL_LOOP
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }
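
  // Illustrative example (hypothetical values): suppose the Lhs has two
  // non-contracting dimensions with m_ij_strides = {1, 4} (the first one has
  // size 4 in the logical matrix space) and memory strides
  // m_nocontract_strides = {1, 20}. For row = 13 the loop above peels off
  // idx = 13 / 4 = 3, adding 3 * 20 = 60 to linidx and leaving a remainder
  // of 1, which then contributes 1 * 1 via the contiguous inner dimension;
  // the contracting coordinate in col is folded in the same way on top of
  // the resulting linidx = 61.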

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (i.e. when
    // we're dealing with the lhs with inner_dim_contiguous). This is because
    // the matrix-vector product relies on the stride when dealing with
    // aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

  #ifdef EIGEN_USE_SYCL
  // The placeholder accessors must be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_tensor.bind(cgh);
  }
  #endif

  const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const {
    return m_tensor;
  }

  const nocontract_t& nocontract_strides() const {
    return m_nocontract_strides;
  }
  const nocontract_t& ij_strides() const { return m_ij_strides; }
  const contract_t& contract_strides() const { return m_contract_strides; }
  const contract_t& k_strides() const { return m_k_strides; }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};

template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  template <typename PacketT,int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename internal::enable_if<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>::type
  load(Index i, Index j) const
  {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator() handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
      return this->m_tensor.template packet<AlignmentType>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess &&
        (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (lastIdx - first) == (packet_size - 1)) {
      return this->m_tensor.template packet<AlignmentType>(first);
    }

    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    EIGEN_UNROLL_LOOP
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }
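
  // Illustrative note: in the gather fallback above, with packet_size = 4
  // the loop runs once with k = 1, and computeIndexPair(i + 1, j, 1) yields
  // the linear indices for rows i + 1 and i + 2 in a single call; rows i and
  // i + 3 come from `first` and `lastIdx`. This pairwise gathering is why
  // packet_size must be a multiple of 2.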

  template <typename PacketT,int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename internal::enable_if<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>::type
  load(Index i, Index j) const
  {
    const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
    EIGEN_ALIGN_MAX Scalar data[requested_packet_size];

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < requested_packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  template <typename PacketT,int AlignmentType>
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
    return this->load<PacketT,AlignmentType>(i,j);
  }
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
  : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  template <typename PacketT,int> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<PacketT>(data);
  }
  template <typename PacketT,int> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<PacketT>(data);
  }
};
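
// Note: in the packet_size == 1 specialization above, a "packet" holds a
// single coefficient, so both load() and loadPacket() reduce to one coeff()
// read through computeIndex().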


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionSubMapper {
 public:

  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> ParentMapper;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Self;
  typedef Self LinearMapper;

  enum {
    // We can use direct offsets iff the parent mapper supports them and we can compute the strides.
    // TODO: we should also enable direct offsets for the Rhs case.
    UseDirectOffsets = ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
  };

  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
    // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
    // this offset every time we attempt to access a coefficient.
    if (UseDirectOffsets) {
      Index stride = m_base_mapper.stride();
      m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
    }
  }
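
  // Illustrative note: in the simplest case (a single non-contracting and a
  // single contracting dimension on the Lhs), coefficient (i, j) of this
  // sub-block lives at linear index (i + vert_offset) + (j + horiz_offset) *
  // stride. Shifting the buffer by vert_offset + horiz_offset * stride up
  // front therefore lets the accessors below pass (i, j) through unmodified.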

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, 0);
    }
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, j);
    }
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<PacketT,Alignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<PacketT,Alignment>(i, j);
    }
    return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template load<PacketT,AlignmentType>(i, j);
    }
    return m_base_mapper.template loadPacket<PacketT,AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
    if (UseDirectOffsets) {
      m_base_mapper.storePacket(i, 0, p);
      return; // avoid a second store at the un-offset location
    }
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    if (UseDirectOffsets) {
      return LinearMapper(m_base_mapper, i, j);
    }
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i + m_vert_offset, m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
    return false;
  }

  #ifdef EIGEN_USE_SYCL
  // The placeholder accessors must be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_base_mapper.bind(cgh);
  }
  #endif

  const ParentMapper& base_mapper() const { return m_base_mapper; }
  Index vert_offset() const { return m_vert_offset; }
  Index horiz_offset() const { return m_horiz_offset; }

 private:
  ParentMapper m_base_mapper;
  const Index m_vert_offset;
  const Index m_horiz_offset;
};


template<typename Scalar_, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionInputMapper
  : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {

 public:
  typedef Scalar_ Scalar;
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> SubMapper;
  typedef SubMapper VectorMapper;

  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
                                                 const nocontract_t& nocontract_strides,
                                                 const nocontract_t& ij_strides,
                                                 const contract_t& contract_strides,
                                                 const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& get_tensor() const {
    return Base::m_tensor;
  }
};
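
// Note: TensorContractionInputMapper is the entry point handed to the tensor
// contraction evaluators; getSubMapper() and getVectorMapper() give out
// offset views of the same tensor, mirroring the blas_data_mapper interface
// mentioned at the top of this file.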


template <typename T> struct TensorContractionInputMapperTrait;

template<typename Scalar_, typename Index_, int side_,
         typename Tensor_,
         typename nocontract_t_, typename contract_t_,
         int packet_size_,
         bool inner_dim_contiguous_, bool inner_dim_reordered_, int Alignment_, template <class> class MakePointer_>
struct TensorContractionInputMapperTrait<TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_,
                                         nocontract_t_, contract_t_, packet_size_, inner_dim_contiguous_,
                                         inner_dim_reordered_, Alignment_, MakePointer_> > {

  typedef Tensor_ XprType;
  static const bool inner_dim_contiguous = inner_dim_contiguous_;
  static const bool inner_dim_reordered = inner_dim_reordered_;
};


} // end namespace internal
} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H