Please help us get to know our user community better by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
 
Loading...
Searching...
No Matches
TensorCustomOp.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12
13namespace Eigen {
14
22namespace internal {
23template<typename CustomUnaryFunc, typename XprType>
24struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
25{
26 typedef typename XprType::Scalar Scalar;
27 typedef typename XprType::StorageKind StorageKind;
28 typedef typename XprType::Index Index;
29 typedef typename XprType::Nested Nested;
30 typedef typename remove_reference<Nested>::type _Nested;
31 static const int NumDimensions = traits<XprType>::NumDimensions;
32 static const int Layout = traits<XprType>::Layout;
33 typedef typename traits<XprType>::PointerType PointerType;
34};
35
36template<typename CustomUnaryFunc, typename XprType>
37struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
38{
39 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>EIGEN_DEVICE_REF type;
40};
41
42template<typename CustomUnaryFunc, typename XprType>
43struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
44{
45 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
46};
47
48} // end namespace internal
49
50
51
52template<typename CustomUnaryFunc, typename XprType>
53class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
54{
55 public:
56 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
57 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58 typedef typename XprType::CoeffReturnType CoeffReturnType;
59 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
60 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
61 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
62
63 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
64 : m_expr(expr), m_func(func) {}
65
66 EIGEN_DEVICE_FUNC
67 const CustomUnaryFunc& func() const { return m_func; }
68
69 EIGEN_DEVICE_FUNC
70 const typename internal::remove_all<typename XprType::Nested>::type&
71 expression() const { return m_expr; }
72
73 protected:
74 typename XprType::Nested m_expr;
75 const CustomUnaryFunc m_func;
76};
77
78
79// Eval as rvalue
80template<typename CustomUnaryFunc, typename XprType, typename Device>
81struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
82{
84 typedef typename internal::traits<ArgType>::Index Index;
85 static const int NumDims = internal::traits<ArgType>::NumDimensions;
86 typedef DSizes<Index, NumDims> Dimensions;
87 typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
88 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
89 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
90 static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
91 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
92 typedef StorageMemory<CoeffReturnType, Device> Storage;
93 typedef typename Storage::Type EvaluatorPointerType;
94
95 enum {
96 IsAligned = false,
97 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
98 BlockAccess = false,
99 PreferBlockAccess = false,
101 CoordAccess = false, // to be implemented
102 RawAccess = false
103 };
104
105 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
106 typedef internal::TensorBlockNotImplemented TensorBlock;
107 //===--------------------------------------------------------------------===//
108
109 EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
110 : m_op(op), m_device(device), m_result(NULL)
111 {
112 m_dimensions = op.func().dimensions(op.expression());
113 }
114
115 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
116
117 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
118 if (data) {
119 evalTo(data);
120 return false;
121 } else {
122 m_result = static_cast<EvaluatorPointerType>(m_device.get( (CoeffReturnType*)
123 m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar))));
124 evalTo(m_result);
125 return true;
126 }
127 }
128
129 EIGEN_STRONG_INLINE void cleanup() {
130 if (m_result) {
131 m_device.deallocate_temp(m_result);
132 m_result = NULL;
133 }
134 }
135
136 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
137 return m_result[index];
138 }
139
140 template<int LoadMode>
141 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
142 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
143 }
144
145 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
146 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
147 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
148 }
149
150 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
151
152#ifdef EIGEN_USE_SYCL
153 // binding placeholder accessors to a command group handler for SYCL
154 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
155 m_result.bind(cgh);
156 }
157#endif
158
159 protected:
160 void evalTo(EvaluatorPointerType data) {
161 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(m_device.get(data), m_dimensions);
162 m_op.func().eval(m_op.expression(), result, m_device);
163 }
164
165 Dimensions m_dimensions;
166 const ArgType m_op;
167 const Device EIGEN_DEVICE_REF m_device;
168 EvaluatorPointerType m_result;
169};
170
171
172
180namespace internal {
181template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
182struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
183{
184 typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
185 typename RhsXprType::Scalar>::ret Scalar;
186 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
187 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
188 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
189 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
190 typedef typename promote_index_type<typename traits<LhsXprType>::Index,
191 typename traits<RhsXprType>::Index>::type Index;
192 typedef typename LhsXprType::Nested LhsNested;
193 typedef typename RhsXprType::Nested RhsNested;
194 typedef typename remove_reference<LhsNested>::type _LhsNested;
195 typedef typename remove_reference<RhsNested>::type _RhsNested;
196 static const int NumDimensions = traits<LhsXprType>::NumDimensions;
197 static const int Layout = traits<LhsXprType>::Layout;
198 typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
199 typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
200};
201
202template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
203struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
204{
205 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
206};
207
208template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
209struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
210{
211 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
212};
213
214} // end namespace internal
215
216
217
218template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
219class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
220{
221 public:
222 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
223 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
224 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
225 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
226 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
227 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
228
229 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
230
231 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
232
233 EIGEN_DEVICE_FUNC
234 const CustomBinaryFunc& func() const { return m_func; }
235
236 EIGEN_DEVICE_FUNC
237 const typename internal::remove_all<typename LhsXprType::Nested>::type&
238 lhsExpression() const { return m_lhs_xpr; }
239
240 EIGEN_DEVICE_FUNC
241 const typename internal::remove_all<typename RhsXprType::Nested>::type&
242 rhsExpression() const { return m_rhs_xpr; }
243
244 protected:
245 typename LhsXprType::Nested m_lhs_xpr;
246 typename RhsXprType::Nested m_rhs_xpr;
247 const CustomBinaryFunc m_func;
248};
249
250
251// Eval as rvalue
252template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
253struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
254{
256 typedef typename internal::traits<XprType>::Index Index;
257 static const int NumDims = internal::traits<XprType>::NumDimensions;
258 typedef DSizes<Index, NumDims> Dimensions;
259 typedef typename XprType::Scalar Scalar;
260 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
261 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
262 static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
263
264 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
265 typedef StorageMemory<CoeffReturnType, Device> Storage;
266 typedef typename Storage::Type EvaluatorPointerType;
267
268 enum {
269 IsAligned = false,
270 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
271 BlockAccess = false,
272 PreferBlockAccess = false,
274 CoordAccess = false, // to be implemented
275 RawAccess = false
276 };
277
278 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
279 typedef internal::TensorBlockNotImplemented TensorBlock;
280 //===--------------------------------------------------------------------===//
281
282 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
283 : m_op(op), m_device(device), m_result(NULL)
284 {
285 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
286 }
287
288 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
289
290 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
291 if (data) {
292 evalTo(data);
293 return false;
294 } else {
295 m_result = static_cast<EvaluatorPointerType>(m_device.get( (CoeffReturnType*)
296 m_device.allocate_temp(dimensions().TotalSize() * sizeof(CoeffReturnType))));
297 evalTo(m_result);
298 return true;
299 }
300 }
301
302 EIGEN_STRONG_INLINE void cleanup() {
303 if (m_result != NULL) {
304 m_device.deallocate_temp(m_result);
305 m_result = NULL;
306 }
307 }
308
309 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
310 return m_result[index];
311 }
312
313 template<int LoadMode>
314 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
315 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
316 }
317
318 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
319 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
320 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
321 }
322
323 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
324
325#ifdef EIGEN_USE_SYCL
326 // binding placeholder accessors to a command group handler for SYCL
327 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
328 m_result.bind(cgh);
329 }
330#endif
331
332 protected:
333 void evalTo(EvaluatorPointerType data) {
334 TensorMap<Tensor<CoeffReturnType, NumDims, Layout> > result(m_device.get(data), m_dimensions);
335 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
336 }
337
338 Dimensions m_dimensions;
339 const XprType m_op;
340 const Device EIGEN_DEVICE_REF m_device;
341 EvaluatorPointerType m_result;
342};
343
344
345} // end namespace Eigen
346
347#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
The tensor base class.
Definition: TensorForwardDeclarations.h:56
Tensor custom class for a binary custom operation (TensorCustomBinaryOp).
Definition: TensorCustomOp.h:220
Tensor custom class for a unary custom operation (TensorCustomUnaryOp).
Definition: TensorCustomOp.h:54
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
A cost model used to limit the number of threads used for evaluating tensor expressions.
Definition: TensorEvaluator.h:29