Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Make thrust::/std::complex interop __device__ qualified for C++11+.
Browse files Browse the repository at this point in the history
The functions to construct, assign and compare thrust::complex values
from and with std::complex values were marked __host__ since forever,
because access to std::complex is performed using member functions.
However, in C++11, an explicit permission has been given to
reinterpret_cast std::complex values as arrays of two elements of its
template parameter, allowing us to implement a __device__-compatible set
of those interop functions, when compiling for C++11.

For C++03, they are still only __host__-qualified.

Bug 2502854
  • Loading branch information
griwes committed Mar 25, 2019
1 parent 4f43a17 commit b14187c
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 82 deletions.
24 changes: 24 additions & 0 deletions testing/complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,27 @@ struct TestComplexStreamOperators
};

SimpleUnitTest<TestComplexStreamOperators, FloatingPointTypes> TestComplexStreamOperatorsInstance;

#if THRUST_CPP_DIALECT >= 2011
template<typename T>
struct TestComplexStdComplexDeviceInterop
{
void operator()()
{
thrust::host_vector<T> data = unittest::random_samples<T>(6);
std::vector<std::complex<T> > vec(10);
vec[0] = std::complex<T>(data[0], data[1]);
vec[1] = std::complex<T>(data[2], data[3]);
vec[2] = std::complex<T>(data[4], data[5]);

thrust::device_vector<thrust::complex<T> > device_vec = vec;
ASSERT_ALMOST_EQUAL(vec[0].real(), thrust::complex<T>(device_vec[0]).real());
ASSERT_ALMOST_EQUAL(vec[0].imag(), thrust::complex<T>(device_vec[0]).imag());
ASSERT_ALMOST_EQUAL(vec[1].real(), thrust::complex<T>(device_vec[1]).real());
ASSERT_ALMOST_EQUAL(vec[1].imag(), thrust::complex<T>(device_vec[1]).imag());
ASSERT_ALMOST_EQUAL(vec[2].real(), thrust::complex<T>(device_vec[2]).real());
ASSERT_ALMOST_EQUAL(vec[2].imag(), thrust::complex<T>(device_vec[2]).imag());
}
};
SimpleUnitTest<TestComplexStdComplexDeviceInterop, FloatingPointTypes> TestComplexStdComplexDeviceInteropInstance;
#endif
96 changes: 58 additions & 38 deletions thrust/complex.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2008-2018 NVIDIA Corporation
* Copyright 2008-2019 NVIDIA Corporation
* Copyright 2013 Filipe RNC Maia
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -28,11 +28,27 @@
#include <sstream>
#include <thrust/detail/type_traits.h>

#if THRUST_CPP_DIALECT >= 2011
# define THRUST_STD_COMPLEX_REAL(z) \
reinterpret_cast< \
const typename thrust::detail::remove_reference<decltype(z)>::type::value_type (&)[2] \
>(z)[0]
# define THRUST_STD_COMPLEX_IMAG(z) \
reinterpret_cast< \
const typename thrust::detail::remove_reference<decltype(z)>::type::value_type (&)[2] \
>(z)[1]
# define THRUST_STD_COMPLEX_DEVICE __device__
#else
# define THRUST_STD_COMPLEX_REAL(z) (z).real()
# define THRUST_STD_COMPLEX_IMAG(z) (z).imag()
# define THRUST_STD_COMPLEX_DEVICE
#endif

namespace thrust
{

/*
* Calls to the standard math library from inside the thrust namespace
* Calls to the standard math library from inside the thrust namespace
* with real arguments require explicit scope otherwise they will fail
* to resolve as it will find the equivalent complex function but then
* fail to match the template, and give up looking for other scopes.
Expand Down Expand Up @@ -112,7 +128,7 @@ struct complex
*
* \param z The \p complex to copy from.
*/
__host__
__host__ THRUST_STD_COMPLEX_DEVICE
complex(const std::complex<T>& z);

/*! This converting copy constructor copies from a <tt>std::complex</tt> with
Expand All @@ -122,8 +138,8 @@ struct complex
*
* \tparam U is convertible to \c value_type.
*/
template <typename U>
__host__
template <typename U>
__host__ THRUST_STD_COMPLEX_DEVICE
complex(const std::complex<U>& z);


Expand Down Expand Up @@ -162,7 +178,7 @@ struct complex
*
* \param z The \p complex to copy from.
*/
__host__
__host__ THRUST_STD_COMPLEX_DEVICE
complex& operator=(const std::complex<T>& z);

/*! Assign `z.real()` and `z.imag()` to the real and imaginary parts of this
Expand All @@ -172,8 +188,8 @@ struct complex
*
* \tparam U is convertible to \c value_type.
*/
template <typename U>
__host__
template <typename U>
__host__ THRUST_STD_COMPLEX_DEVICE
complex& operator=(const std::complex<U>& z);


Expand All @@ -184,7 +200,7 @@ struct complex
* \p complex.
*
* \param z The \p complex to be added.
*
*
* \tparam U is convertible to \c value_type.
*/
template <typename U>
Expand Down Expand Up @@ -248,7 +264,7 @@ struct complex

/*! Multiplies this \p complex by a scalar and assigns the result
* to this \p complex.
*
*
* \param z The scalar to be multiplied.
*
* \tparam U is convertible to \c value_type.
Expand All @@ -259,7 +275,7 @@ struct complex

/*! Divides this \p complex by a scalar and assigns the result to
* this \p complex.
*
*
* \param z The scalar to be divided.
*
* \tparam U is convertible to \c value_type.
Expand All @@ -270,7 +286,7 @@ struct complex



/* --- Getter functions ---
/* --- Getter functions ---
* The volatile ones are there to help for example
* with certain reductions optimizations
*/
Expand All @@ -297,7 +313,7 @@ struct complex



/* --- Setter functions ---
/* --- Setter functions ---
* The volatile ones are there to help for example
* with certain reductions optimizations
*/
Expand Down Expand Up @@ -409,8 +425,8 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
polar(const T0& m, const T1& theta = T1());

/*! Returns the projection of a \p complex on the Riemann sphere.
* For all finite \p complex it returns the argument. For \p complexs
* with a non finite part returns (INFINITY,+/-0) where the sign of
* For all finite \p complex it returns the argument. For \p complexs
* with a non finite part returns (INFINITY,+/-0) where the sign of
* the zero matches the sign of the imaginary part of the argument.
*
* \param z The \p complex argument.
Expand All @@ -424,7 +440,7 @@ complex<T> proj(const T& z);
/* --- Binary Arithmetic operators --- */

/*! Adds two \p complex numbers.
*
*
* The value types of the two \p complex types should be compatible and the
* type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -437,7 +453,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator+(const complex<T0>& x, const complex<T1>& y);

/*! Adds a scalar to a \p complex number.
*
*
* The value type of the \p complex should be compatible with the scalar and
* the type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -450,7 +466,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator+(const complex<T0>& x, const T1& y);

/*! Adds a \p complex number to a scalar.
*
*
* The value type of the \p complex should be compatible with the scalar and
* the type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -463,7 +479,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator+(const T0& x, const complex<T1>& y);

/*! Subtracts two \p complex numbers.
*
*
* The value types of the two \p complex types should be compatible and the
* type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -476,7 +492,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator-(const complex<T0>& x, const complex<T1>& y);

/*! Subtracts a scalar from a \p complex number.
*
*
* The value type of the \p complex should be compatible with the scalar and
* the type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -489,7 +505,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator-(const complex<T0>& x, const T1& y);

/*! Subtracts a \p complex number from a scalar.
*
*
* The value type of the \p complex should be compatible with the scalar and
* the type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -502,7 +518,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator-(const T0& x, const complex<T1>& y);

/*! Multiplies two \p complex numbers.
*
*
* The value types of the two \p complex types should be compatible and the
* type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -525,7 +541,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator*(const complex<T0>& x, const T1& y);

/*! Multiplies a scalar by a \p complex number.
*
*
* The value type of the \p complex should be compatible with the scalar and
* the type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -538,7 +554,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator*(const T0& x, const complex<T1>& y);

/*! Divides two \p complex numbers.
*
*
* The value types of the two \p complex types should be compatible and the
* type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -551,7 +567,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator/(const complex<T0>& x, const complex<T1>& y);

/*! Divides a \p complex number by a scalar.
*
*
* The value type of the \p complex should be compatible with the scalar and
* the type of the returned \p complex is the promoted type of the two arguments.
*
Expand All @@ -564,7 +580,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator/(const complex<T0>& x, const T1& y);

/*! Divides a scalar by a \p complex number.
*
*
* The value type of the \p complex should be compatible with the scalar and
* the type of the returned \p complex is the promoted type of the two arguments.
*
Expand Down Expand Up @@ -632,7 +648,7 @@ complex<T> log10(const complex<T>& z);
/* --- Power Functions --- */

/*! Returns a \p complex number raised to another.
*
*
* The value types of the two \p complex types should be compatible and the
* type of the returned \p complex is the promoted type of the two arguments.
*
Expand Down Expand Up @@ -739,7 +755,7 @@ complex<T> tanh(const complex<T>& z);

/*! Returns the complex arc cosine of a \p complex number.
*
* The range of the real part of the result is [0, Pi] and
* The range of the real part of the result is [0, Pi] and
* the range of the imaginary part is [-inf, +inf]
*
* \param z The \p complex argument.
Expand All @@ -750,7 +766,7 @@ complex<T> acos(const complex<T>& z);

/*! Returns the complex arc sine of a \p complex number.
*
* The range of the real part of the result is [-Pi/2, Pi/2] and
* The range of the real part of the result is [-Pi/2, Pi/2] and
* the range of the imaginary part is [-inf, +inf]
*
* \param z The \p complex argument.
Expand All @@ -761,7 +777,7 @@ complex<T> asin(const complex<T>& z);

/*! Returns the complex arc tangent of a \p complex number.
*
* The range of the real part of the result is [-Pi/2, Pi/2] and
* The range of the real part of the result is [-Pi/2, Pi/2] and
* the range of the imaginary part is [-inf, +inf]
*
* \param z The \p complex argument.
Expand All @@ -776,7 +792,7 @@ complex<T> atan(const complex<T>& z);

/*! Returns the complex inverse hyperbolic cosine of a \p complex number.
*
* The range of the real part of the result is [0, +inf] and
* The range of the real part of the result is [0, +inf] and
* the range of the imaginary part is [-Pi, Pi]
*
* \param z The \p complex argument.
Expand All @@ -787,7 +803,7 @@ complex<T> acosh(const complex<T>& z);

/*! Returns the complex inverse hyperbolic sine of a \p complex number.
*
* The range of the real part of the result is [-inf, +inf] and
* The range of the real part of the result is [-inf, +inf] and
* the range of the imaginary part is [-Pi/2, Pi/2]
*
* \param z The \p complex argument.
Expand All @@ -798,7 +814,7 @@ complex<T> asinh(const complex<T>& z);

/*! Returns the complex inverse hyperbolic tangent of a \p complex number.
*
* The range of the real part of the result is [-inf, +inf] and
* The range of the real part of the result is [-inf, +inf] and
* the range of the imaginary part is [-Pi/2, Pi/2]
*
* \param z The \p complex argument.
Expand Down Expand Up @@ -827,7 +843,7 @@ operator<<(std::basic_ostream<CharT, Traits>& os, const complex<T>& z);
* - (real)
* - (real, imaginary)
*
* The values read must be convertible to the \p complex's \c value_type
* The values read must be convertible to the \p complex's \c value_type
*
* \param is The input stream.
* \param z The \p complex number to set.
Expand Down Expand Up @@ -856,7 +872,7 @@ bool operator==(const complex<T0>& x, const complex<T1>& y);
* \param y The second \p complex.
*/
template <typename T0, typename T1>
__host__
__host__ THRUST_STD_COMPLEX_DEVICE
bool operator==(const complex<T0>& x, const std::complex<T1>& y);

/*! Returns true if two \p complex numbers are equal and false otherwise.
Expand All @@ -865,7 +881,7 @@ bool operator==(const complex<T0>& x, const std::complex<T1>& y);
* \param y The second \p complex.
*/
template <typename T0, typename T1>
__host__
__host__ THRUST_STD_COMPLEX_DEVICE
bool operator==(const std::complex<T0>& x, const complex<T1>& y);

/*! Returns true if the imaginary part of the \p complex number is zero and
Expand Down Expand Up @@ -903,7 +919,7 @@ bool operator!=(const complex<T0>& x, const complex<T1>& y);
* \param y The second \p complex.
*/
template <typename T0, typename T1>
__host__
__host__ THRUST_STD_COMPLEX_DEVICE
bool operator!=(const complex<T0>& x, const std::complex<T1>& y);

/*! Returns true if two \p complex numbers are different and false otherwise.
Expand All @@ -912,7 +928,7 @@ bool operator!=(const complex<T0>& x, const std::complex<T1>& y);
* \param y The second \p complex.
*/
template <typename T0, typename T1>
__host__
__host__ THRUST_STD_COMPLEX_DEVICE
bool operator!=(const std::complex<T0>& x, const complex<T1>& y);

/*! Returns true if the imaginary part of the \p complex number is not zero or
Expand All @@ -939,6 +955,10 @@ bool operator!=(const complex<T0>& x, const T1& y);

#include <thrust/detail/complex/complex.inl>

#undef THRUST_STD_COMPLEX_REAL
#undef THRUST_STD_COMPLEX_IMAG
#undef THRUST_STD_COMPLEX_DEVICE

/*! \} // complex_numbers
*/

Expand Down
Loading

0 comments on commit b14187c

Please sign in to comment.