Please, help us to better know about our user community by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
 
Loading...
Searching...
No Matches
GPU/SpecialFunctions.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_GPU_SPECIALFUNCTIONS_H
#define EIGEN_GPU_SPECIALFUNCTIONS_H

namespace Eigen {

namespace internal {

// Specializations of the packet-level special functions for the GPU vector
// types float4 and double2.
//
// Every overload below works the same way: it evaluates the matching scalar
// function independently on each lane (.x/.y/.z/.w for float4, .x/.y for
// double2) and packs the per-lane results back together with make_float4 /
// make_double2. Multi-argument functions pair up lanes with the same name
// across their arguments.
//
// The whole set is only compiled for GPU builds (EIGEN_GPUCC) that also opt
// in via EIGEN_USE_GPU, so these specializations never coexist with the
// packet specializations supplied by the host-side backends (SSE, AVX, ...).
#if defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)

// ---------------------------------------------------------------------------
// Gamma-family functions: lgamma, digamma, zeta, polygamma
// ---------------------------------------------------------------------------

// Lane-wise log-gamma. The float path calls the single-precision lgammaf
// directly instead of going through numext.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 plgamma<float4>(const float4& a) {
  return make_float4(lgammaf(a.x),
                     lgammaf(a.y),
                     lgammaf(a.z),
                     lgammaf(a.w));
}

// Lane-wise log-gamma in double precision, via numext::lgamma.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 plgamma<double2>(const double2& a) {
  return make_double2(numext::lgamma(a.x),
                      numext::lgamma(a.y));
}

// Lane-wise digamma (float), via numext::digamma.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pdigamma<float4>(const float4& a) {
  return make_float4(numext::digamma(a.x),
                     numext::digamma(a.y),
                     numext::digamma(a.z),
                     numext::digamma(a.w));
}

// Lane-wise digamma (double), via numext::digamma.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pdigamma<double2>(const double2& a) {
  return make_double2(numext::digamma(a.x),
                      numext::digamma(a.y));
}

// Lane-wise zeta(x, q) (float): lane i of the result is zeta(x_i, q_i).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pzeta<float4>(const float4& x, const float4& q) {
  return make_float4(numext::zeta(x.x, q.x),
                     numext::zeta(x.y, q.y),
                     numext::zeta(x.z, q.z),
                     numext::zeta(x.w, q.w));
}

// Lane-wise zeta(x, q) (double).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pzeta<double2>(const double2& x, const double2& q) {
  return make_double2(numext::zeta(x.x, q.x),
                      numext::zeta(x.y, q.y));
}

// Lane-wise polygamma(n, x) (float): lane i of the result is
// polygamma(n_i, x_i).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 ppolygamma<float4>(const float4& n, const float4& x) {
  return make_float4(numext::polygamma(n.x, x.x),
                     numext::polygamma(n.y, x.y),
                     numext::polygamma(n.z, x.z),
                     numext::polygamma(n.w, x.w));
}

// Lane-wise polygamma(n, x) (double).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 ppolygamma<double2>(const double2& n, const double2& x) {
  return make_double2(numext::polygamma(n.x, x.x),
                      numext::polygamma(n.y, x.y));
}

// ---------------------------------------------------------------------------
// Error-function family: erf, erfc, ndtri
// ---------------------------------------------------------------------------

// Lane-wise erf. The float path calls the single-precision erff directly
// instead of going through numext.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 perf<float4>(const float4& a) {
  return make_float4(erff(a.x),
                     erff(a.y),
                     erff(a.z),
                     erff(a.w));
}

// Lane-wise erf in double precision, via numext::erf.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 perf<double2>(const double2& a) {
  return make_double2(numext::erf(a.x),
                      numext::erf(a.y));
}

// Lane-wise erfc (float), via numext::erfc.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 perfc<float4>(const float4& a) {
  return make_float4(numext::erfc(a.x),
                     numext::erfc(a.y),
                     numext::erfc(a.z),
                     numext::erfc(a.w));
}

// Lane-wise erfc (double), via numext::erfc.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 perfc<double2>(const double2& a) {
  return make_double2(numext::erfc(a.x),
                      numext::erfc(a.y));
}

// Lane-wise ndtri (float), via numext::ndtri.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pndtri<float4>(const float4& a) {
  return make_float4(numext::ndtri(a.x),
                     numext::ndtri(a.y),
                     numext::ndtri(a.z),
                     numext::ndtri(a.w));
}

// Lane-wise ndtri (double), via numext::ndtri.
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pndtri<double2>(const double2& a) {
  return make_double2(numext::ndtri(a.x),
                      numext::ndtri(a.y));
}

// ---------------------------------------------------------------------------
// Incomplete gamma / beta family
// ---------------------------------------------------------------------------

// Lane-wise igamma(a, x) (float).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pigamma<float4>(const float4& a, const float4& x) {
  return make_float4(numext::igamma(a.x, x.x),
                     numext::igamma(a.y, x.y),
                     numext::igamma(a.z, x.z),
                     numext::igamma(a.w, x.w));
}

// Lane-wise igamma(a, x) (double).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pigamma<double2>(const double2& a, const double2& x) {
  return make_double2(numext::igamma(a.x, x.x),
                      numext::igamma(a.y, x.y));
}

// Lane-wise igamma_der_a(a, x) (float).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pigamma_der_a<float4>(const float4& a, const float4& x) {
  return make_float4(numext::igamma_der_a(a.x, x.x),
                     numext::igamma_der_a(a.y, x.y),
                     numext::igamma_der_a(a.z, x.z),
                     numext::igamma_der_a(a.w, x.w));
}

// Lane-wise igamma_der_a(a, x) (double).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pigamma_der_a<double2>(const double2& a, const double2& x) {
  return make_double2(numext::igamma_der_a(a.x, x.x),
                      numext::igamma_der_a(a.y, x.y));
}

// Lane-wise gamma_sample_der_alpha(alpha, sample) (float).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pgamma_sample_der_alpha<float4>(const float4& alpha,
                                       const float4& sample) {
  return make_float4(numext::gamma_sample_der_alpha(alpha.x, sample.x),
                     numext::gamma_sample_der_alpha(alpha.y, sample.y),
                     numext::gamma_sample_der_alpha(alpha.z, sample.z),
                     numext::gamma_sample_der_alpha(alpha.w, sample.w));
}

// Lane-wise gamma_sample_der_alpha(alpha, sample) (double).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pgamma_sample_der_alpha<double2>(const double2& alpha,
                                         const double2& sample) {
  return make_double2(numext::gamma_sample_der_alpha(alpha.x, sample.x),
                      numext::gamma_sample_der_alpha(alpha.y, sample.y));
}

// Lane-wise igammac(a, x) (float).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pigammac<float4>(const float4& a, const float4& x) {
  return make_float4(numext::igammac(a.x, x.x),
                     numext::igammac(a.y, x.y),
                     numext::igammac(a.z, x.z),
                     numext::igammac(a.w, x.w));
}

// Lane-wise igammac(a, x) (double).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pigammac<double2>(const double2& a, const double2& x) {
  return make_double2(numext::igammac(a.x, x.x),
                      numext::igammac(a.y, x.y));
}

// Lane-wise betainc(a, b, x) (float): lane i is betainc(a_i, b_i, x_i).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbetainc<float4>(const float4& a, const float4& b, const float4& x) {
  return make_float4(numext::betainc(a.x, b.x, x.x),
                     numext::betainc(a.y, b.y, x.y),
                     numext::betainc(a.z, b.z, x.z),
                     numext::betainc(a.w, b.w, x.w));
}

// Lane-wise betainc(a, b, x) (double).
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbetainc<double2>(const double2& a, const double2& b,
                          const double2& x) {
  return make_double2(numext::betainc(a.x, b.x, x.x),
                      numext::betainc(a.y, b.y, x.y));
}

// ---------------------------------------------------------------------------
// Bessel functions. Each one forwards lane-by-lane to the numext function
// of the same name.
// ---------------------------------------------------------------------------

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_i0e<float4>(const float4& x) {
  return make_float4(numext::bessel_i0e(x.x),
                     numext::bessel_i0e(x.y),
                     numext::bessel_i0e(x.z),
                     numext::bessel_i0e(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_i0e<double2>(const double2& x) {
  return make_double2(numext::bessel_i0e(x.x),
                      numext::bessel_i0e(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_i0<float4>(const float4& x) {
  return make_float4(numext::bessel_i0(x.x),
                     numext::bessel_i0(x.y),
                     numext::bessel_i0(x.z),
                     numext::bessel_i0(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_i0<double2>(const double2& x) {
  return make_double2(numext::bessel_i0(x.x),
                      numext::bessel_i0(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_i1e<float4>(const float4& x) {
  return make_float4(numext::bessel_i1e(x.x),
                     numext::bessel_i1e(x.y),
                     numext::bessel_i1e(x.z),
                     numext::bessel_i1e(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_i1e<double2>(const double2& x) {
  return make_double2(numext::bessel_i1e(x.x),
                      numext::bessel_i1e(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_i1<float4>(const float4& x) {
  return make_float4(numext::bessel_i1(x.x),
                     numext::bessel_i1(x.y),
                     numext::bessel_i1(x.z),
                     numext::bessel_i1(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_i1<double2>(const double2& x) {
  return make_double2(numext::bessel_i1(x.x),
                      numext::bessel_i1(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_k0e<float4>(const float4& x) {
  return make_float4(numext::bessel_k0e(x.x),
                     numext::bessel_k0e(x.y),
                     numext::bessel_k0e(x.z),
                     numext::bessel_k0e(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_k0e<double2>(const double2& x) {
  return make_double2(numext::bessel_k0e(x.x),
                      numext::bessel_k0e(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_k0<float4>(const float4& x) {
  return make_float4(numext::bessel_k0(x.x),
                     numext::bessel_k0(x.y),
                     numext::bessel_k0(x.z),
                     numext::bessel_k0(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_k0<double2>(const double2& x) {
  return make_double2(numext::bessel_k0(x.x),
                      numext::bessel_k0(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_k1e<float4>(const float4& x) {
  return make_float4(numext::bessel_k1e(x.x),
                     numext::bessel_k1e(x.y),
                     numext::bessel_k1e(x.z),
                     numext::bessel_k1e(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_k1e<double2>(const double2& x) {
  return make_double2(numext::bessel_k1e(x.x),
                      numext::bessel_k1e(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_k1<float4>(const float4& x) {
  return make_float4(numext::bessel_k1(x.x),
                     numext::bessel_k1(x.y),
                     numext::bessel_k1(x.z),
                     numext::bessel_k1(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_k1<double2>(const double2& x) {
  return make_double2(numext::bessel_k1(x.x),
                      numext::bessel_k1(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_j0<float4>(const float4& x) {
  return make_float4(numext::bessel_j0(x.x),
                     numext::bessel_j0(x.y),
                     numext::bessel_j0(x.z),
                     numext::bessel_j0(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_j0<double2>(const double2& x) {
  return make_double2(numext::bessel_j0(x.x),
                      numext::bessel_j0(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_j1<float4>(const float4& x) {
  return make_float4(numext::bessel_j1(x.x),
                     numext::bessel_j1(x.y),
                     numext::bessel_j1(x.z),
                     numext::bessel_j1(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_j1<double2>(const double2& x) {
  return make_double2(numext::bessel_j1(x.x),
                      numext::bessel_j1(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_y0<float4>(const float4& x) {
  return make_float4(numext::bessel_y0(x.x),
                     numext::bessel_y0(x.y),
                     numext::bessel_y0(x.z),
                     numext::bessel_y0(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_y0<double2>(const double2& x) {
  return make_double2(numext::bessel_y0(x.x),
                      numext::bessel_y0(x.y));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 pbessel_y1<float4>(const float4& x) {
  return make_float4(numext::bessel_y1(x.x),
                     numext::bessel_y1(x.y),
                     numext::bessel_y1(x.z),
                     numext::bessel_y1(x.w));
}

template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 pbessel_y1<double2>(const double2& x) {
  return make_double2(numext::bessel_y1(x.x),
                      numext::bessel_y1(x.y));
}

#endif  // defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU)

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_GPU_SPECIALFUNCTIONS_H
Namespace containing all symbols from the Eigen library.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_y1_op< typename Derived::Scalar >, const Derived > bessel_y1(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:278
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_k0e_op< typename Derived::Scalar >, const Derived > bessel_k0e(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:145
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_k0_op< typename Derived::Scalar >, const Derived > bessel_k0(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:122
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:90
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:51
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_k1_op< typename Derived::Scalar >, const Derived > bessel_k1(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:167
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_i1_op< typename Derived::Scalar >, const Derived > bessel_i1(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:77
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_i0e_op< typename Derived::Scalar >, const Derived > bessel_i0e(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:55
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_lgamma_op< typename Derived::Scalar >, const Derived > lgamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_i1e_op< typename Derived::Scalar >, const Derived > bessel_i1e(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:100
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_j1_op< typename Derived::Scalar >, const Derived > bessel_j1(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:256
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_y0_op< typename Derived::Scalar >, const Derived > bessel_y0(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:234
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_i0_op< typename Derived::Scalar >, const Derived > bessel_i0(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:32
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erf_op< typename Derived::Scalar >, const Derived > erf(const Eigen::ArrayBase< Derived > &x)
const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const ADerived &a, const BDerived &b, const XDerived &x)
Definition: TensorGlobalFunctions.h:24
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_k1e_op< typename Derived::Scalar >, const Derived > bessel_k1e(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:190
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
Definition: SpecialFunctionsArrayAPI.h:72
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_erfc_op< typename Derived::Scalar >, const Derived > erfc(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_ndtri_op< typename Derived::Scalar >, const Derived > ndtri(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_digamma_op< typename Derived::Scalar >, const Derived > digamma(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
Definition: SpecialFunctionsArrayAPI.h:112
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_bessel_j0_op< typename Derived::Scalar >, const Derived > bessel_j0(const Eigen::ArrayBase< Derived > &x)
Definition: BesselFunctionsArrayAPI.h:212
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Definition: SpecialFunctionsArrayAPI.h:28
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_zeta_op< typename DerivedX::Scalar >, const DerivedX, const DerivedQ > zeta(const Eigen::ArrayBase< DerivedX > &x, const Eigen::ArrayBase< DerivedQ > &q)
Definition: SpecialFunctionsArrayAPI.h:156