Skip to content

Commit be5919c

Browse files
yasahi-hpcYuuichi Asahi
andauthored
Apply check functions to fft functions (#130)
* Improve assertions in fft functions * format * use is_complex_v * fix: typo * using string_view and remove maybe_unused from assertion helper --------- Co-authored-by: Yuuichi Asahi <[email protected]>
1 parent 0dfbae8 commit be5919c

11 files changed

+189
-163
lines changed

common/src/KokkosFFT_asserts.hpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
2+
//
3+
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
4+
5+
#ifndef KOKKOSFFT_ASSERTS_HPP
6+
#define KOKKOSFFT_ASSERTS_HPP
7+
8+
#include <stdexcept>
9+
#include <sstream>
10+
#include <string_view>
11+
12+
#if defined(__cpp_lib_source_location) && __cpp_lib_source_location >= 201907L
13+
#include <source_location>
14+
#define KOKKOSFFT_EXPECTS(expression, msg) \
15+
KokkosFFT::Impl::check_precondition( \
16+
(expression), msg, std::source_location::current().file_name(), \
17+
std::source_location::current().line(), \
18+
std::source_location::current().function_name(), \
19+
std::source_location::current().column())
20+
#else
21+
#include <cstdlib>
22+
#define KOKKOSFFT_EXPECTS(expression, msg) \
23+
KokkosFFT::Impl::check_precondition((expression), msg, __FILE__, __LINE__, \
24+
__FUNCTION__)
25+
#endif
26+
27+
namespace KokkosFFT {
28+
namespace Impl {
29+
30+
inline void check_precondition(const bool expression,
31+
const std::string_view& msg,
32+
const char* file_name, int line,
33+
const char* function_name,
34+
const int column = -1) {
35+
// Quick return if possible
36+
if (expression) return;
37+
38+
std::stringstream ss("file: ");
39+
if (column == -1) {
40+
// For C++ 17
41+
ss << file_name << '(' << line << ") `" << function_name << "`: " << msg
42+
<< '\n';
43+
} else {
44+
// For C++ 20 and later
45+
ss << file_name << '(' << line << ':' << column << ") `" << function_name
46+
<< "`: " << msg << '\n';
47+
}
48+
throw std::runtime_error(ss.str());
49+
}
50+
51+
} // namespace Impl
52+
} // namespace KokkosFFT
53+
54+
#endif

common/src/KokkosFFT_utils.hpp

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,48 +10,13 @@
1010
#include <set>
1111
#include <algorithm>
1212
#include <numeric>
13+
#include "KokkosFFT_asserts.hpp"
1314
#include "KokkosFFT_traits.hpp"
1415
#include "KokkosFFT_common_types.hpp"
1516

16-
#if defined(__cpp_lib_source_location) && __cpp_lib_source_location >= 201907L
17-
#include <source_location>
18-
#define KOKKOSFFT_EXPECTS(expression, msg) \
19-
KokkosFFT::Impl::check_precondition( \
20-
(expression), msg, std::source_location::current().file_name(), \
21-
std::source_location::current().line(), \
22-
std::source_location::current().function_name(), \
23-
std::source_location::current().column())
24-
#else
25-
#include <cstdlib>
26-
#define KOKKOSFFT_EXPECTS(expression, msg) \
27-
KokkosFFT::Impl::check_precondition((expression), msg, __FILE__, __LINE__, \
28-
__FUNCTION__)
29-
#endif
30-
3117
namespace KokkosFFT {
3218
namespace Impl {
3319

34-
inline void check_precondition(const bool expression,
35-
[[maybe_unused]] const std::string& msg,
36-
[[maybe_unused]] const char* file_name, int line,
37-
[[maybe_unused]] const char* function_name,
38-
[[maybe_unused]] const int column = -1) {
39-
// Quick return if possible
40-
if (expression) return;
41-
42-
std::stringstream ss("file: ");
43-
if (column == -1) {
44-
// For C++ 17
45-
ss << file_name << '(' << line << ") `" << function_name << "`: " << msg
46-
<< '\n';
47-
} else {
48-
// For C++ 20 and later
49-
ss << file_name << '(' << line << ':' << column << ") `" << function_name
50-
<< "`: " << msg << '\n';
51-
}
52-
throw std::runtime_error(ss.str());
53-
}
54-
5520
template <typename ViewType>
5621
auto convert_negative_axis(ViewType, int _axis = -1) {
5722
static_assert(Kokkos::is_view_v<ViewType>,

fft/src/KokkosFFT_Cuda_plans.hpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <numeric>
99
#include "KokkosFFT_Cuda_types.hpp"
1010
#include "KokkosFFT_layouts.hpp"
11+
#include "KokkosFFT_asserts.hpp"
1112

1213
namespace KokkosFFT {
1314
namespace Impl {
@@ -30,7 +31,7 @@ auto create_plan(const ExecutionSpace& exec_space,
3031

3132
plan = std::make_unique<PlanType>();
3233
cufftResult cufft_rt = cufftCreate(&(*plan));
33-
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");
34+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
3435

3536
cudaStream_t stream = exec_space.cuda_stream();
3637
cufftSetStream((*plan), stream);
@@ -44,7 +45,8 @@ auto create_plan(const ExecutionSpace& exec_space,
4445
std::multiplies<>());
4546

4647
cufft_rt = cufftPlan1d(&(*plan), nx, type, howmany);
47-
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan1d failed");
48+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan1d failed");
49+
4850
return fft_size;
4951
}
5052

@@ -67,7 +69,7 @@ auto create_plan(const ExecutionSpace& exec_space,
6769

6870
plan = std::make_unique<PlanType>();
6971
cufftResult cufft_rt = cufftCreate(&(*plan));
70-
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");
72+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
7173

7274
cudaStream_t stream = exec_space.cuda_stream();
7375
cufftSetStream((*plan), stream);
@@ -81,7 +83,8 @@ auto create_plan(const ExecutionSpace& exec_space,
8183
std::multiplies<>());
8284

8385
cufft_rt = cufftPlan2d(&(*plan), nx, ny, type);
84-
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan2d failed");
86+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan2d failed");
87+
8588
return fft_size;
8689
}
8790

@@ -104,7 +107,7 @@ auto create_plan(const ExecutionSpace& exec_space,
104107

105108
plan = std::make_unique<PlanType>();
106109
cufftResult cufft_rt = cufftCreate(&(*plan));
107-
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");
110+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
108111

109112
cudaStream_t stream = exec_space.cuda_stream();
110113
cufftSetStream((*plan), stream);
@@ -120,7 +123,8 @@ auto create_plan(const ExecutionSpace& exec_space,
120123
std::multiplies<>());
121124

122125
cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type);
123-
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftPlan3d failed");
126+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan3d failed");
127+
124128
return fft_size;
125129
}
126130

@@ -163,16 +167,16 @@ auto create_plan(const ExecutionSpace& exec_space,
163167

164168
plan = std::make_unique<PlanType>();
165169
cufftResult cufft_rt = cufftCreate(&(*plan));
166-
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");
170+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
167171

168172
cudaStream_t stream = exec_space.cuda_stream();
169173
cufftSetStream((*plan), stream);
170174

171175
cufft_rt = cufftPlanMany(&(*plan), rank, fft_extents.data(),
172176
in_extents.data(), istride, idist,
173177
out_extents.data(), ostride, odist, type, howmany);
174-
if (cufft_rt != CUFFT_SUCCESS)
175-
throw std::runtime_error("cufftPlanMany failed");
178+
179+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlanMany failed");
176180

177181
return fft_size;
178182
}
@@ -186,4 +190,4 @@ void destroy_plan_and_info(std::unique_ptr<PlanType>& plan, InfoType&) {
186190
} // namespace Impl
187191
} // namespace KokkosFFT
188192

189-
#endif
193+
#endif

fft/src/KokkosFFT_Cuda_transform.hpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,57 +6,52 @@
66
#define KOKKOSFFT_CUDA_TRANSFORM_HPP
77

88
#include <cufft.h>
9+
#include "KokkosFFT_asserts.hpp"
910

1011
namespace KokkosFFT {
1112
namespace Impl {
1213
template <typename... Args>
1314
inline void exec_plan(cufftHandle& plan, cufftReal* idata, cufftComplex* odata,
1415
int /*direction*/, Args...) {
1516
cufftResult cufft_rt = cufftExecR2C(plan, idata, odata);
16-
if (cufft_rt != CUFFT_SUCCESS)
17-
throw std::runtime_error("cufftExecR2C failed");
17+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecR2C failed");
1818
}
1919

2020
template <typename... Args>
2121
inline void exec_plan(cufftHandle& plan, cufftDoubleReal* idata,
2222
cufftDoubleComplex* odata, int /*direction*/, Args...) {
2323
cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata);
24-
if (cufft_rt != CUFFT_SUCCESS)
25-
throw std::runtime_error("cufftExecD2Z failed");
24+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecD2Z failed");
2625
}
2726

2827
template <typename... Args>
2928
inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftReal* odata,
3029
int /*direction*/, Args...) {
3130
cufftResult cufft_rt = cufftExecC2R(plan, idata, odata);
32-
if (cufft_rt != CUFFT_SUCCESS)
33-
throw std::runtime_error("cufftExecC2R failed");
31+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2R failed");
3432
}
3533

3634
template <typename... Args>
3735
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
3836
cufftDoubleReal* odata, int /*direction*/, Args...) {
3937
cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata);
40-
if (cufft_rt != CUFFT_SUCCESS)
41-
throw std::runtime_error("cufftExecZ2D failed");
38+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2D failed");
4239
}
4340

4441
template <typename... Args>
4542
inline void exec_plan(cufftHandle& plan, cufftComplex* idata,
4643
cufftComplex* odata, int direction, Args...) {
4744
cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction);
48-
if (cufft_rt != CUFFT_SUCCESS)
49-
throw std::runtime_error("cufftExecC2C failed");
45+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2C failed");
5046
}
5147

5248
template <typename... Args>
5349
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
5450
cufftDoubleComplex* odata, int direction, Args...) {
5551
cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction);
56-
if (cufft_rt != CUFFT_SUCCESS)
57-
throw std::runtime_error("cufftExecZ2Z failed");
52+
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2Z failed");
5853
}
5954
} // namespace Impl
6055
} // namespace KokkosFFT
6156

62-
#endif
57+
#endif

fft/src/KokkosFFT_HIP_plans.hpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <numeric>
99
#include "KokkosFFT_HIP_types.hpp"
1010
#include "KokkosFFT_layouts.hpp"
11+
#include "KokkosFFT_asserts.hpp"
1112

1213
namespace KokkosFFT {
1314
namespace Impl {
@@ -30,8 +31,7 @@ auto create_plan(const ExecutionSpace& exec_space,
3031

3132
plan = std::make_unique<PlanType>();
3233
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
33-
if (hipfft_rt != HIPFFT_SUCCESS)
34-
throw std::runtime_error("hipfftCreate failed");
34+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed");
3535

3636
hipStream_t stream = exec_space.hip_stream();
3737
hipfftSetStream((*plan), stream);
@@ -45,8 +45,8 @@ auto create_plan(const ExecutionSpace& exec_space,
4545
std::multiplies<>());
4646

4747
hipfft_rt = hipfftPlan1d(&(*plan), nx, type, howmany);
48-
if (hipfft_rt != HIPFFT_SUCCESS)
49-
throw std::runtime_error("hipfftPlan1d failed");
48+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlan1d failed");
49+
5050
return fft_size;
5151
}
5252

@@ -69,8 +69,7 @@ auto create_plan(const ExecutionSpace& exec_space,
6969

7070
plan = std::make_unique<PlanType>();
7171
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
72-
if (hipfft_rt != HIPFFT_SUCCESS)
73-
throw std::runtime_error("hipfftCreate failed");
72+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed");
7473

7574
hipStream_t stream = exec_space.hip_stream();
7675
hipfftSetStream((*plan), stream);
@@ -84,8 +83,8 @@ auto create_plan(const ExecutionSpace& exec_space,
8483
std::multiplies<>());
8584

8685
hipfft_rt = hipfftPlan2d(&(*plan), nx, ny, type);
87-
if (hipfft_rt != HIPFFT_SUCCESS)
88-
throw std::runtime_error("hipfftPlan2d failed");
86+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlan2d failed");
87+
8988
return fft_size;
9089
}
9190

@@ -108,8 +107,7 @@ auto create_plan(const ExecutionSpace& exec_space,
108107

109108
plan = std::make_unique<PlanType>();
110109
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
111-
if (hipfft_rt != HIPFFT_SUCCESS)
112-
throw std::runtime_error("hipfftCreate failed");
110+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed");
113111

114112
hipStream_t stream = exec_space.hip_stream();
115113
hipfftSetStream((*plan), stream);
@@ -125,8 +123,8 @@ auto create_plan(const ExecutionSpace& exec_space,
125123
std::multiplies<>());
126124

127125
hipfft_rt = hipfftPlan3d(&(*plan), nx, ny, nz, type);
128-
if (hipfft_rt != HIPFFT_SUCCESS)
129-
throw std::runtime_error("hipfftPlan3d failed");
126+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlan3d failed");
127+
130128
return fft_size;
131129
}
132130

@@ -169,8 +167,7 @@ auto create_plan(const ExecutionSpace& exec_space,
169167

170168
plan = std::make_unique<PlanType>();
171169
hipfftResult hipfft_rt = hipfftCreate(&(*plan));
172-
if (hipfft_rt != HIPFFT_SUCCESS)
173-
throw std::runtime_error("hipfftCreate failed");
170+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftCreate failed");
174171

175172
hipStream_t stream = exec_space.hip_stream();
176173
hipfftSetStream((*plan), stream);
@@ -179,8 +176,8 @@ auto create_plan(const ExecutionSpace& exec_space,
179176
in_extents.data(), istride, idist,
180177
out_extents.data(), ostride, odist, type, howmany);
181178

182-
if (hipfft_rt != HIPFFT_SUCCESS)
183-
throw std::runtime_error("hipfftPlan failed");
179+
KOKKOSFFT_EXPECTS(hipfft_rt == HIPFFT_SUCCESS, "hipfftPlanMany failed");
180+
184181
return fft_size;
185182
}
186183

@@ -193,4 +190,4 @@ void destroy_plan_and_info(std::unique_ptr<PlanType>& plan, InfoType&) {
193190
} // namespace Impl
194191
} // namespace KokkosFFT
195192

196-
#endif
193+
#endif

0 commit comments

Comments
 (0)