Skip to content

Commit 8d49e31

Browse files
authored
Add chunksize to 3pt1Dopt solver (#80)
1 parent 22d2a05 commit 8d49e31

File tree

5 files changed

+28
-11
lines changed

5 files changed

+28
-11
lines changed

examples/heat_1d_ap_direct.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use nhls::direct_solver::*;
2-
use nhls::domain::*;
32
use nhls::image_1d_example::*;
43

54
fn main() {
@@ -8,8 +7,12 @@ fn main() {
87
let stencil = nhls::standard_stencils::heat_1d(1.0, 1.0, 0.5);
98

109
// Create BC
11-
let mut solver =
12-
Direct3Pt1DSolver::new(&stencil, args.steps_per_line, args.threads);
10+
let mut solver = Direct3Pt1DSolver::new(
11+
&stencil,
12+
args.steps_per_line,
13+
args.threads,
14+
args.chunk_size,
15+
);
1316

1417
args.run_solver(&mut solver);
1518
}

examples/heat_1d_ap_fft.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ fn main() {
88
let stencil = nhls::standard_stencils::heat_1d(1.0, 1.0, 0.5);
99

1010
// This optimized direct solver implement a uniform boundary condition of 0.0
11-
let direct_solver = DirectSolver3Pt1DOpt::new(&stencil);
11+
let direct_solver = DirectSolver3Pt1DOpt::new(&stencil, args.chunk_size);
1212

1313
// Create AP Solver
1414
let solver_params = args.solver_parameters();

examples/tv_heat_1d_ap_fft.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ fn main() {
77

88
let stencil = nhls::standard_stencils::TVHeat1D::new();
99

10-
let direct_solver = DirectSolver3Pt1DOpt::new(&stencil);
10+
let direct_solver = DirectSolver3Pt1DOpt::new(&stencil, args.chunk_size);
1111
// Create AP Solver
1212
let solver_params = args.solver_parameters();
1313
let mut solver =

src/direct_solver/direct_3pt1d_opt.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,21 @@ use crate::SolverInterface;
88
/// Implements a constant zero boundary condition.
99
pub struct DirectSolver3Pt1DOpt<'a, StencilType: TVStencil<1, 3>> {
1010
stencil: &'a StencilType,
11+
chunk_size: usize,
1112
}
1213

1314
impl<'a, StencilType: TVStencil<1, 3>> DirectSolver3Pt1DOpt<'a, StencilType> {
14-
pub fn new(stencil: &'a StencilType) -> Self {
15+
pub fn new(stencil: &'a StencilType, chunk_size: usize) -> Self {
1516
let expected_offsets = [
1617
vector![1], // 0
1718
vector![-1], // 1
1819
vector![0], // 4
1920
];
2021
assert_eq!(&expected_offsets, stencil.offsets());
21-
DirectSolver3Pt1DOpt { stencil }
22+
DirectSolver3Pt1DOpt {
23+
stencil,
24+
chunk_size,
25+
}
2226
}
2327

2428
fn apply_step<DomainType: DomainView<1> + Send>(
@@ -51,12 +55,12 @@ impl<'a, StencilType: TVStencil<1, 3>> DirectSolver3Pt1DOpt<'a, StencilType> {
5155

5256
let const_output: &DomainType = output;
5357
rayon::scope(|s| {
54-
profiling::scope!("direct_solver: Thread Callback");
55-
let chunk_size = (n_r - 2) / (threads * 2);
58+
let chunk_size = ((n_r - 2) / threads).max(self.chunk_size);
5659
let mut start: usize = 1;
5760
while start < n_r - 1 {
5861
let end = (start + chunk_size).min(n_r - 1);
5962
s.spawn(move |_| {
63+
profiling::scope!("direct_solver: Thread Callback");
6064
let mut o = const_output.unsafe_mut_access();
6165
for i in start..end {
6266
*o.buffer_mut().get_unchecked_mut(i) = w
@@ -104,9 +108,17 @@ pub struct Direct3Pt1DSolver<'a, StencilType: TVStencil<1, 3>> {
104108
}
105109

106110
impl<'a, StencilType: TVStencil<1, 3>> Direct3Pt1DSolver<'a, StencilType> {
107-
pub fn new(stencil: &'a StencilType, steps: usize, threads: usize) -> Self {
111+
pub fn new(
112+
stencil: &'a StencilType,
113+
steps: usize,
114+
threads: usize,
115+
chunk_size: usize,
116+
) -> Self {
108117
Direct3Pt1DSolver {
109-
solver: DirectSolver3Pt1DOpt { stencil },
118+
solver: DirectSolver3Pt1DOpt {
119+
stencil,
120+
chunk_size,
121+
},
110122
steps,
111123
threads,
112124
}

src/domain/view/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pub trait DomainView<const GRID_DIMENSION: usize>: Sync {
8181
other: &DomainType,
8282
chunk_size: usize,
8383
) {
84+
profiling::scope!("domain::par_set_subdomain");
8485
let const_self_ref: &Self = self;
8586
other.buffer()[0..other.aabb().buffer_size()]
8687
.par_chunks(chunk_size)
@@ -109,6 +110,7 @@ pub trait DomainView<const GRID_DIMENSION: usize>: Sync {
109110
other: &DomainType,
110111
chunk_size: usize,
111112
) {
113+
profiling::scope!("domain::par_from_subdomain");
112114
self.par_set_values(|world_coord| other.view(&world_coord), chunk_size);
113115
}
114116

0 commit comments

Comments
 (0)