From b14187c368a1185a2a1dc0b6bcbcea1bab9bfef1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20=27Griwes=27=20Dominiak?= Date: Wed, 20 Mar 2019 19:19:42 +0100 Subject: [PATCH] Make thrust::/std::complex interop __device__ qualified for C++11+. The functions to construct, assign and compare thrust::complex values from and with std::complex values were marked __host__ since forever, because access to std::complex is performed using member functions. However, in C++11, an explicit permission has been given to reinterpret_cast std::complex values as arrays of two elements of its template parameter, allowing us to implement a __device__-compatible set of those interop functions, when compiling for C++11. For C++03, they are still only __host__-qualified. Bug 2502854 --- testing/complex.cu | 24 ++++++++ thrust/complex.h | 96 +++++++++++++++++++------------ thrust/detail/complex/complex.inl | 88 ++++++++++++++-------------- 3 files changed, 126 insertions(+), 82 deletions(-) diff --git a/testing/complex.cu b/testing/complex.cu index 91256fd6b..cf46a6e87 100644 --- a/testing/complex.cu +++ b/testing/complex.cu @@ -284,3 +284,27 @@ struct TestComplexStreamOperators }; SimpleUnitTest TestComplexStreamOperatorsInstance; + +#if THRUST_CPP_DIALECT >= 2011 +template +struct TestComplexStdComplexDeviceInterop +{ + void operator()() + { + thrust::host_vector data = unittest::random_samples(6); + std::vector > vec(10); + vec[0] = std::complex(data[0], data[1]); + vec[1] = std::complex(data[2], data[3]); + vec[2] = std::complex(data[4], data[5]); + + thrust::device_vector > device_vec = vec; + ASSERT_ALMOST_EQUAL(vec[0].real(), thrust::complex(device_vec[0]).real()); + ASSERT_ALMOST_EQUAL(vec[0].imag(), thrust::complex(device_vec[0]).imag()); + ASSERT_ALMOST_EQUAL(vec[1].real(), thrust::complex(device_vec[1]).real()); + ASSERT_ALMOST_EQUAL(vec[1].imag(), thrust::complex(device_vec[1]).imag()); + ASSERT_ALMOST_EQUAL(vec[2].real(), thrust::complex(device_vec[2]).real()); + ASSERT_ALMOST_EQUAL(vec[2].imag(), thrust::complex(device_vec[2]).imag()); + } +}; +SimpleUnitTest TestComplexStdComplexDeviceInteropInstance; +#endif diff --git a/thrust/complex.h b/thrust/complex.h index ae6182253..5f2730115 100644 --- a/thrust/complex.h +++ b/thrust/complex.h @@ -1,5 +1,5 @@ /* - * Copyright 2008-2018 NVIDIA Corporation + * Copyright 2008-2019 NVIDIA Corporation * Copyright 2013 Filipe RNC Maia * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,11 +28,27 @@ #include #include +#if THRUST_CPP_DIALECT >= 2011 +# define THRUST_STD_COMPLEX_REAL(z) \ + reinterpret_cast< \ + const typename thrust::detail::remove_reference::type::value_type (&)[2] \ + >(z)[0] +# define THRUST_STD_COMPLEX_IMAG(z) \ + reinterpret_cast< \ + const typename thrust::detail::remove_reference::type::value_type (&)[2] \ + >(z)[1] +# define THRUST_STD_COMPLEX_DEVICE __device__ +#else +# define THRUST_STD_COMPLEX_REAL(z) (z).real() +# define THRUST_STD_COMPLEX_IMAG(z) (z).imag() +# define THRUST_STD_COMPLEX_DEVICE +#endif + namespace thrust { /* - * Calls to the standard math library from inside the thrust namespace + * Calls to the standard math library from inside the thrust namespace * with real arguments require explicit scope otherwise they will fail * to resolve as it will find the equivalent complex function but then * fail to match the template, and give up looking for other scopes. @@ -112,7 +128,7 @@ struct complex * * \param z The \p complex to copy from. */ - __host__ + __host__ THRUST_STD_COMPLEX_DEVICE complex(const std::complex& z); /*! This converting copy constructor copies from a std::complex with @@ -122,8 +138,8 @@ struct complex * * \tparam U is convertible to \c value_type. */ - template - __host__ + template + __host__ THRUST_STD_COMPLEX_DEVICE complex(const std::complex& z); @@ -162,7 +178,7 @@ struct complex * * \param z The \p complex to copy from. */ - __host__ + __host__ THRUST_STD_COMPLEX_DEVICE complex& operator=(const std::complex& z); /*! Assign `z.real()` and `z.imag()` to the real and imaginary parts of this @@ -172,8 +188,8 @@ struct complex * * \tparam U is convertible to \c value_type. */ - template - __host__ + template + __host__ THRUST_STD_COMPLEX_DEVICE complex& operator=(const std::complex& z); @@ -184,7 +200,7 @@ struct complex * \p complex. * * \param z The \p complex to be added. - * + * * \tparam U is convertible to \c value_type. */ template @@ -248,7 +264,7 @@ struct complex /*! Multiplies this \p complex by a scalar and assigns the result * to this \p complex. - * + * * \param z The scalar to be multiplied. * * \tparam U is convertible to \c value_type. @@ -259,7 +275,7 @@ struct complex /*! Divides this \p complex by a scalar and assigns the result to * this \p complex. - * + * * \param z The scalar to be divided. * * \tparam U is convertible to \c value_type. @@ -270,7 +286,7 @@ struct complex - /* --- Getter functions --- + /* --- Getter functions --- * The volatile ones are there to help for example * with certain reductions optimizations */ @@ -297,7 +313,7 @@ struct complex - /* --- Setter functions --- + /* --- Setter functions --- * The volatile ones are there to help for example * with certain reductions optimizations */ @@ -409,8 +425,8 @@ complex::type> polar(const T0& m, const T1& theta = T1()); /*! Returns the projection of a \p complex on the Riemann sphere. - * For all finite \p complex it returns the argument. For \p complexs - * with a non finite part returns (INFINITY,+/-0) where the sign of + * For all finite \p complex it returns the argument. For \p complexs + * with a non finite part returns (INFINITY,+/-0) where the sign of * the zero matches the sign of the imaginary part of the argument. * * \param z The \p complex argument. @@ -424,7 +440,7 @@ complex proj(const T& z); /* --- Binary Arithmetic operators --- */ /*! Adds two \p complex numbers. - * + * * The value types of the two \p complex types should be compatible and the * type of the returned \p complex is the promoted type of the two arguments. * @@ -437,7 +453,7 @@ complex::type> operator+(const complex& x, const complex& y); /*! Adds a scalar to a \p complex number. - * + * * The value type of the \p complex should be compatible with the scalar and * the type of the returned \p complex is the promoted type of the two arguments. * @@ -450,7 +466,7 @@ complex::type> operator+(const complex& x, const T1& y); /*! Adds a \p complex number to a scalar. - * + * * The value type of the \p complex should be compatible with the scalar and * the type of the returned \p complex is the promoted type of the two arguments. * @@ -463,7 +479,7 @@ complex::type> operator+(const T0& x, const complex& y); /*! Subtracts two \p complex numbers. - * + * * The value types of the two \p complex types should be compatible and the * type of the returned \p complex is the promoted type of the two arguments. * @@ -476,7 +492,7 @@ complex::type> operator-(const complex& x, const complex& y); /*! Subtracts a scalar from a \p complex number. - * + * * The value type of the \p complex should be compatible with the scalar and * the type of the returned \p complex is the promoted type of the two arguments. * @@ -489,7 +505,7 @@ complex::type> operator-(const complex& x, const T1& y); /*! Subtracts a \p complex number from a scalar. - * + * * The value type of the \p complex should be compatible with the scalar and * the type of the returned \p complex is the promoted type of the two arguments. * @@ -502,7 +518,7 @@ complex::type> operator-(const T0& x, const complex& y); /*! Multiplies two \p complex numbers. - * + * * The value types of the two \p complex types should be compatible and the * type of the returned \p complex is the promoted type of the two arguments. * @@ -525,7 +541,7 @@ complex::type> operator*(const complex& x, const T1& y); /*! Multiplies a scalar by a \p complex number. - * + * * The value type of the \p complex should be compatible with the scalar and * the type of the returned \p complex is the promoted type of the two arguments. * @@ -538,7 +554,7 @@ complex::type> operator*(const T0& x, const complex& y); /*! Divides two \p complex numbers. - * + * * The value types of the two \p complex types should be compatible and the * type of the returned \p complex is the promoted type of the two arguments. * @@ -551,7 +567,7 @@ complex::type> operator/(const complex& x, const complex& y); /*! Divides a \p complex number by a scalar. - * + * * The value type of the \p complex should be compatible with the scalar and * the type of the returned \p complex is the promoted type of the two arguments. * @@ -564,7 +580,7 @@ complex::type> operator/(const complex& x, const T1& y); /*! Divides a scalar by a \p complex number. - * + * * The value type of the \p complex should be compatible with the scalar and * the type of the returned \p complex is the promoted type of the two arguments. * @@ -632,7 +648,7 @@ complex log10(const complex& z); /* --- Power Functions --- */ /*! Returns a \p complex number raised to another. - * + * * The value types of the two \p complex types should be compatible and the * type of the returned \p complex is the promoted type of the two arguments. * @@ -739,7 +755,7 @@ complex tanh(const complex& z); /*! Returns the complex arc cosine of a \p complex number. * - * The range of the real part of the result is [0, Pi] and + * The range of the real part of the result is [0, Pi] and * the range of the imaginary part is [-inf, +inf] * * \param z The \p complex argument. @@ -750,7 +766,7 @@ complex acos(const complex& z); /*! Returns the complex arc sine of a \p complex number. * - * The range of the real part of the result is [-Pi/2, Pi/2] and + * The range of the real part of the result is [-Pi/2, Pi/2] and * the range of the imaginary part is [-inf, +inf] * * \param z The \p complex argument. @@ -761,7 +777,7 @@ complex asin(const complex& z); /*! Returns the complex arc tangent of a \p complex number. * - * The range of the real part of the result is [-Pi/2, Pi/2] and + * The range of the real part of the result is [-Pi/2, Pi/2] and * the range of the imaginary part is [-inf, +inf] * * \param z The \p complex argument. @@ -776,7 +792,7 @@ complex atan(const complex& z); /*! Returns the complex inverse hyperbolic cosine of a \p complex number. * - * The range of the real part of the result is [0, +inf] and + * The range of the real part of the result is [0, +inf] and * the range of the imaginary part is [-Pi, Pi] * * \param z The \p complex argument. @@ -787,7 +803,7 @@ complex acosh(const complex& z); /*! Returns the complex inverse hyperbolic sine of a \p complex number. * - * The range of the real part of the result is [-inf, +inf] and + * The range of the real part of the result is [-inf, +inf] and * the range of the imaginary part is [-Pi/2, Pi/2] * * \param z The \p complex argument. @@ -798,7 +814,7 @@ complex asinh(const complex& z); /*! Returns the complex inverse hyperbolic tangent of a \p complex number. * - * The range of the real part of the result is [-inf, +inf] and + * The range of the real part of the result is [-inf, +inf] and * the range of the imaginary part is [-Pi/2, Pi/2] * * \param z The \p complex argument. @@ -827,7 +843,7 @@ operator<<(std::basic_ostream& os, const complex& z); * - (real) * - (real, imaginary) * - * The values read must be convertible to the \p complex's \c value_type + * The values read must be convertible to the \p complex's \c value_type * * \param is The input stream. * \param z The \p complex number to set. @@ -856,7 +872,7 @@ bool operator==(const complex& x, const complex& y); * \param y The second \p complex. */ template -__host__ +__host__ THRUST_STD_COMPLEX_DEVICE bool operator==(const complex& x, const std::complex& y); /*! Returns true if two \p complex numbers are equal and false otherwise. @@ -865,7 +881,7 @@ bool operator==(const complex& x, const std::complex& y); * \param y The second \p complex. */ template -__host__ +__host__ THRUST_STD_COMPLEX_DEVICE bool operator==(const std::complex& x, const complex& y); /*! Returns true if the imaginary part of the \p complex number is zero and @@ -903,7 +919,7 @@ bool operator!=(const complex& x, const complex& y); * \param y The second \p complex. */ template -__host__ +__host__ THRUST_STD_COMPLEX_DEVICE bool operator!=(const complex& x, const std::complex& y); /*! Returns true if two \p complex numbers are different and false otherwise. @@ -912,7 +928,7 @@ bool operator!=(const complex& x, const std::complex& y); * \param y The second \p complex. */ template -__host__ +__host__ THRUST_STD_COMPLEX_DEVICE bool operator!=(const std::complex& x, const complex& y); /*! Returns true if the imaginary part of the \p complex number is not zero or @@ -939,6 +955,10 @@ bool operator!=(const complex& x, const T1& y); #include +#undef THRUST_STD_COMPLEX_REAL +#undef THRUST_STD_COMPLEX_IMAG +#undef THRUST_STD_COMPLEX_DEVICE + /*! \} // complex_numbers */ diff --git a/thrust/detail/complex/complex.inl b/thrust/detail/complex/complex.inl index f1726f948..632d91b49 100644 --- a/thrust/detail/complex/complex.inl +++ b/thrust/detail/complex/complex.inl @@ -104,35 +104,35 @@ complex::complex(const complex& z) #endif template -__host__ +__host__ THRUST_STD_COMPLEX_DEVICE complex::complex(const std::complex& z) #if THRUST_CPP_DIALECT >= 2011 // Initialize the storage in the member initializer list using C++ unicorn // initialization. This allows `complex` to work. - : data{z.real(), z.imag()} + : data{THRUST_STD_COMPLEX_REAL(z), THRUST_STD_COMPLEX_IMAG(z)} {} #else { - real(z.real()); - imag(z.imag()); -} + real(THRUST_STD_COMPLEX_REAL(z)); + imag(THRUST_STD_COMPLEX_IMAG(z)); +} #endif template -template -__host__ +template +__host__ THRUST_STD_COMPLEX_DEVICE complex::complex(const std::complex& z) #if THRUST_CPP_DIALECT >= 2011 // Initialize the storage in the member initializer list using C++ unicorn // initialization. This allows `complex` to work. // We do a functional-style cast here to suppress conversion warnings. - : data{T(z.real()), T(z.imag())} + : data{T(THRUST_STD_COMPLEX_REAL(z)), T(THRUST_STD_COMPLEX_IMAG(z))} {} #else { - real(T(z.real())); - imag(T(z.imag())); -} + real(T(THRUST_STD_COMPLEX_REAL(z))); + imag(T(THRUST_STD_COMPLEX_IMAG(z))); +} #endif @@ -168,21 +168,21 @@ complex& complex::operator=(const complex& z) } template -__host__ +__host__ THRUST_STD_COMPLEX_DEVICE complex& complex::operator=(const std::complex& z) { - real(z.real()); - imag(z.imag()); + real(THRUST_STD_COMPLEX_REAL(z)); + imag(THRUST_STD_COMPLEX_IMAG(z)); return *this; } template -template -__host__ +template +__host__ THRUST_STD_COMPLEX_DEVICE complex& complex::operator=(const std::complex& z) { - real(T(z.real())); - imag(T(z.imag())); + real(T(THRUST_STD_COMPLEX_REAL(z))); + imag(T(THRUST_STD_COMPLEX_IMAG(z))); return *this; } @@ -191,8 +191,8 @@ complex& complex::operator=(const std::complex& z) /* --- Compound Assignment Operators --- */ template -template -__host__ __device__ +template +__host__ __device__ complex& complex::operator+=(const complex& z) { *this = *this + z; @@ -200,7 +200,7 @@ complex& complex::operator+=(const complex& z) } template -template +template __host__ __device__ complex& complex::operator-=(const complex& z) { @@ -209,7 +209,7 @@ complex& complex::operator-=(const complex& z) } template -template +template __host__ __device__ complex& complex::operator*=(const complex& z) { @@ -218,7 +218,7 @@ complex& complex::operator*=(const complex& z) } template -template +template __host__ __device__ complex& complex::operator/=(const complex& z) { @@ -227,8 +227,8 @@ complex& complex::operator/=(const complex& z) } template -template -__host__ __device__ +template +__host__ __device__ complex& complex::operator+=(const U& z) { *this = *this + z; @@ -236,7 +236,7 @@ complex& complex::operator+=(const U& z) } template -template +template __host__ __device__ complex& complex::operator-=(const U& z) { @@ -245,7 +245,7 @@ complex& complex::operator-=(const U& z) } template -template +template __host__ __device__ complex& complex::operator*=(const U& z) { @@ -254,7 +254,7 @@ complex& complex::operator*=(const U& z) } template -template +template __host__ __device__ complex& complex::operator/=(const U& z) { @@ -266,70 +266,70 @@ complex& complex::operator/=(const U& z) /* --- Equality Operators --- */ -template +template __host__ __device__ bool operator==(const complex& x, const complex& y) { return x.real() == y.real() && x.imag() == y.imag(); } -template -__host__ +template +__host__ THRUST_STD_COMPLEX_DEVICE bool operator==(const complex& x, const std::complex& y) { - return x.real() == y.real() && x.imag() == y.imag(); + return x.real() == THRUST_STD_COMPLEX_REAL(y) && x.imag() == THRUST_STD_COMPLEX_IMAG(y); } -template -__host__ +template +__host__ THRUST_STD_COMPLEX_DEVICE bool operator==(const std::complex& x, const complex& y) { - return x.real() == y.real() && x.imag() == y.imag(); + return THRUST_STD_COMPLEX_REAL(x) == y.real() && THRUST_STD_COMPLEX_IMAG(x) == y.imag(); } -template +template __host__ __device__ bool operator==(const T0& x, const complex& y) { return x == y.real() && y.imag() == T1(); } -template +template __host__ __device__ bool operator==(const complex& x, const T1& y) { return x.real() == y && x.imag() == T1(); } -template +template __host__ __device__ bool operator!=(const complex& x, const complex& y) { return !(x == y); } -template -__host__ +template +__host__ THRUST_STD_COMPLEX_DEVICE bool operator!=(const complex& x, const std::complex& y) { return !(x == y); } -template -__host__ +template +__host__ THRUST_STD_COMPLEX_DEVICE bool operator!=(const std::complex& x, const complex& y) { return !(x == y); } -template +template __host__ __device__ bool operator!=(const T0& x, const complex& y) { return !(x == y); } -template +template __host__ __device__ bool operator!=(const complex& x, const T1& y) {