Please, help us to better know about our user community by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
 
Loading...
Searching...
No Matches
TensorExpr.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
#define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H

namespace Eigen {

30namespace internal {
31template<typename NullaryOp, typename XprType>
32struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
33 : traits<XprType>
34{
35 typedef traits<XprType> XprTraits;
36 typedef typename XprType::Scalar Scalar;
37 typedef typename XprType::Nested XprTypeNested;
38 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
39 static const int NumDimensions = XprTraits::NumDimensions;
40 static const int Layout = XprTraits::Layout;
41 typedef typename XprTraits::PointerType PointerType;
42 enum {
43 Flags = 0
44 };
45};
46
47} // end namespace internal
48
49
50
51template<typename NullaryOp, typename XprType>
52class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
53{
54 public:
55 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
56 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
57 typedef typename XprType::CoeffReturnType CoeffReturnType;
58 typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
59 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
60 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
61
62 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
63 : m_xpr(xpr), m_functor(func) {}
64
65 EIGEN_DEVICE_FUNC
66 const typename internal::remove_all<typename XprType::Nested>::type&
67 nestedExpression() const { return m_xpr; }
68
69 EIGEN_DEVICE_FUNC
70 const NullaryOp& functor() const { return m_functor; }
71
72 protected:
73 typename XprType::Nested m_xpr;
74 const NullaryOp m_functor;
75};
76
77
78
79namespace internal {
80template<typename UnaryOp, typename XprType>
81struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
82 : traits<XprType>
83{
84 // TODO(phli): Add InputScalar, InputPacket. Check references to
85 // current Scalar/Packet to see if the intent is Input or Output.
86 typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
87 typedef traits<XprType> XprTraits;
88 typedef typename XprType::Nested XprTypeNested;
89 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
90 static const int NumDimensions = XprTraits::NumDimensions;
91 static const int Layout = XprTraits::Layout;
92 typedef typename TypeConversion<Scalar,
93 typename XprTraits::PointerType
94 >::type
95 PointerType;
96};
97
98template<typename UnaryOp, typename XprType>
99struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
100{
101 typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
102};
103
104template<typename UnaryOp, typename XprType>
105struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
106{
107 typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
108};
109
110} // end namespace internal
111
112
113
114template<typename UnaryOp, typename XprType>
115class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
116{
117 public:
118 // TODO(phli): Add InputScalar, InputPacket. Check references to
119 // current Scalar/Packet to see if the intent is Input or Output.
120 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
121 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
122 typedef Scalar CoeffReturnType;
123 typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
124 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
125 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
126
127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
128 : m_xpr(xpr), m_functor(func) {}
129
130 EIGEN_DEVICE_FUNC
131 const UnaryOp& functor() const { return m_functor; }
132
134 EIGEN_DEVICE_FUNC
135 const typename internal::remove_all<typename XprType::Nested>::type&
136 nestedExpression() const { return m_xpr; }
137
138 protected:
139 typename XprType::Nested m_xpr;
140 const UnaryOp m_functor;
141};
142
143
144namespace internal {
145template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
146struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
147{
148 // Type promotion to handle the case where the types of the lhs and the rhs
149 // are different.
150 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
151 // current Scalar/Packet to see if the intent is Inputs or Output.
152 typedef typename result_of<
153 BinaryOp(typename LhsXprType::Scalar,
154 typename RhsXprType::Scalar)>::type Scalar;
155 typedef traits<LhsXprType> XprTraits;
156 typedef typename promote_storage_type<
157 typename traits<LhsXprType>::StorageKind,
158 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
159 typedef typename promote_index_type<
160 typename traits<LhsXprType>::Index,
161 typename traits<RhsXprType>::Index>::type Index;
162 typedef typename LhsXprType::Nested LhsNested;
163 typedef typename RhsXprType::Nested RhsNested;
164 typedef typename remove_reference<LhsNested>::type _LhsNested;
165 typedef typename remove_reference<RhsNested>::type _RhsNested;
166 static const int NumDimensions = XprTraits::NumDimensions;
167 static const int Layout = XprTraits::Layout;
168 typedef typename TypeConversion<Scalar,
169 typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
170 typename traits<LhsXprType>::PointerType,
171 typename traits<RhsXprType>::PointerType>::type
172 >::type
173 PointerType;
174 enum {
175 Flags = 0
176 };
177};
178
179template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
180struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
181{
182 typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
183};
184
185template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
186struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
187{
188 typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
189};
190
191} // end namespace internal
192
193
194
195template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
196class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
197{
198 public:
199 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
200 // current Scalar/Packet to see if the intent is Inputs or Output.
201 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
202 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
203 typedef Scalar CoeffReturnType;
204 typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
205 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
206 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
207
208 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
209 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
210
211 EIGEN_DEVICE_FUNC
212 const BinaryOp& functor() const { return m_functor; }
213
215 EIGEN_DEVICE_FUNC
216 const typename internal::remove_all<typename LhsXprType::Nested>::type&
217 lhsExpression() const { return m_lhs_xpr; }
218
219 EIGEN_DEVICE_FUNC
220 const typename internal::remove_all<typename RhsXprType::Nested>::type&
221 rhsExpression() const { return m_rhs_xpr; }
222
223 protected:
224 typename LhsXprType::Nested m_lhs_xpr;
225 typename RhsXprType::Nested m_rhs_xpr;
226 const BinaryOp m_functor;
227};
228
229
230namespace internal {
231template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
232struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
233{
234 // Type promotion to handle the case where the types of the args are different.
235 typedef typename result_of<
236 TernaryOp(typename Arg1XprType::Scalar,
237 typename Arg2XprType::Scalar,
238 typename Arg3XprType::Scalar)>::type Scalar;
239 typedef traits<Arg1XprType> XprTraits;
240 typedef typename traits<Arg1XprType>::StorageKind StorageKind;
241 typedef typename traits<Arg1XprType>::Index Index;
242 typedef typename Arg1XprType::Nested Arg1Nested;
243 typedef typename Arg2XprType::Nested Arg2Nested;
244 typedef typename Arg3XprType::Nested Arg3Nested;
245 typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
246 typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
247 typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
248 static const int NumDimensions = XprTraits::NumDimensions;
249 static const int Layout = XprTraits::Layout;
250 typedef typename TypeConversion<Scalar,
251 typename conditional<Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val,
252 typename traits<Arg2XprType>::PointerType,
253 typename traits<Arg3XprType>::PointerType>::type
254 >::type
255 PointerType;
256 enum {
257 Flags = 0
258 };
259};
260
261template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
262struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
263{
264 typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
265};
266
267template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
268struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
269{
270 typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
271};
272
273} // end namespace internal
274
275
276
277template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
278class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
279{
280 public:
281 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
282 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
283 typedef Scalar CoeffReturnType;
284 typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
285 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
286 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
287
288 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
289 : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
290
291 EIGEN_DEVICE_FUNC
292 const TernaryOp& functor() const { return m_functor; }
293
295 EIGEN_DEVICE_FUNC
296 const typename internal::remove_all<typename Arg1XprType::Nested>::type&
297 arg1Expression() const { return m_arg1_xpr; }
298
299 EIGEN_DEVICE_FUNC
300 const typename internal::remove_all<typename Arg2XprType::Nested>::type&
301 arg2Expression() const { return m_arg2_xpr; }
302
303 EIGEN_DEVICE_FUNC
304 const typename internal::remove_all<typename Arg3XprType::Nested>::type&
305 arg3Expression() const { return m_arg3_xpr; }
306
307 protected:
308 typename Arg1XprType::Nested m_arg1_xpr;
309 typename Arg2XprType::Nested m_arg2_xpr;
310 typename Arg3XprType::Nested m_arg3_xpr;
311 const TernaryOp m_functor;
312};
313
314
315namespace internal {
316template<typename IfXprType, typename ThenXprType, typename ElseXprType>
317struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
318 : traits<ThenXprType>
319{
320 typedef typename traits<ThenXprType>::Scalar Scalar;
321 typedef traits<ThenXprType> XprTraits;
322 typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
323 typename traits<ElseXprType>::StorageKind>::ret StorageKind;
324 typedef typename promote_index_type<typename traits<ElseXprType>::Index,
325 typename traits<ThenXprType>::Index>::type Index;
326 typedef typename IfXprType::Nested IfNested;
327 typedef typename ThenXprType::Nested ThenNested;
328 typedef typename ElseXprType::Nested ElseNested;
329 static const int NumDimensions = XprTraits::NumDimensions;
330 static const int Layout = XprTraits::Layout;
331 typedef typename conditional<Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val,
332 typename traits<ThenXprType>::PointerType,
333 typename traits<ElseXprType>::PointerType>::type PointerType;
334};
335
336template<typename IfXprType, typename ThenXprType, typename ElseXprType>
337struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
338{
339 typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
340};
341
342template<typename IfXprType, typename ThenXprType, typename ElseXprType>
343struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
344{
345 typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
346};
347
348} // end namespace internal
349
350
351template<typename IfXprType, typename ThenXprType, typename ElseXprType>
352class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
353{
354 public:
355 typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
356 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
357 typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
358 typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
359 typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
360 typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
361 typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
362
363 EIGEN_DEVICE_FUNC
364 TensorSelectOp(const IfXprType& a_condition,
365 const ThenXprType& a_then,
366 const ElseXprType& a_else)
367 : m_condition(a_condition), m_then(a_then), m_else(a_else)
368 { }
369
370 EIGEN_DEVICE_FUNC
371 const IfXprType& ifExpression() const { return m_condition; }
372
373 EIGEN_DEVICE_FUNC
374 const ThenXprType& thenExpression() const { return m_then; }
375
376 EIGEN_DEVICE_FUNC
377 const ElseXprType& elseExpression() const { return m_else; }
378
379 protected:
380 typename IfXprType::Nested m_condition;
381 typename ThenXprType::Nested m_then;
382 typename ElseXprType::Nested m_else;
383};
384
385
} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index