Skip to content

Commit 2967b47

Browse files
committed
Add function requires_update
1 parent 550ccd0 commit 2967b47

File tree

5 files changed

+56
-32
lines changed

5 files changed

+56
-32
lines changed

src/PointNeighbors.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ export foreach_point_neighbor, foreach_neighbor
2424
export TrivialNeighborhoodSearch, GridNeighborhoodSearch, PrecomputedNeighborhoodSearch
2525
export DictionaryCellList, FullGridCellList
2626
export ParallelUpdate, SemiParallelUpdate, SerialUpdate
27-
export initialize!, update!, initialize_grid!, update_grid!
27+
export requires_update, initialize!, update!, initialize_grid!, update_grid!
2828
export PolyesterBackend, ThreadsDynamicBackend, ThreadsStaticBackend
2929
export PeriodicBox, copy_neighborhood_search
3030

src/neighborhood_search.jl

+33-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@ abstract type AbstractNeighborhoodSearch end
22

33
# Generic accessor: reads the `search_radius` field of the neighborhood search.
@inline function search_radius(search::AbstractNeighborhoodSearch)
    return search.search_radius
end
44

5+
"""
6+
requires_update(search::AbstractNeighborhoodSearch)
7+
8+
Returns a tuple `(x_changed, y_changed)` indicating if this type of neighborhood search
9+
requires an update when the coordinates of the points in `x` or `y` change, respectively.
10+
"""
11+
function requires_update(::AbstractNeighborhoodSearch)
12+
error("`requires_update` not implemented for this neighborhood search.")
13+
end
14+
515
"""
616
initialize!(search::AbstractNeighborhoodSearch, x, y)
717
@@ -206,13 +216,32 @@ end
206216
return nothing
207217
end
208218

209-
@inline function foreach_neighbor(f, system_coords, neighbor_system_coords,
210-
neighborhood_search, point;
211-
search_radius = search_radius(neighborhood_search))
219+
@propagate_inbounds function foreach_neighbor(f, system_coords, neighbor_system_coords,
220+
neighborhood_search::AbstractNeighborhoodSearch,
221+
point;
222+
search_radius = search_radius(neighborhood_search))
223+
# Due to https://github.com/JuliaLang/julia/issues/30411, we cannot just remove
224+
# a `@boundscheck` by calling this function with `@inbounds` because it has a kwarg.
225+
# We have to use `@propagate_inbounds`, which will also remove boundschecks
226+
# in the neighbor loop, which is not safe (see comment below).
227+
# To avoid this, we have to use a function barrier to disable the `@inbounds` again.
228+
point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)
229+
230+
foreach_neighbor(f, neighbor_system_coords, neighborhood_search,
231+
point, point_coords, search_radius)
232+
end
233+
234+
# This is the generic function that is called for `TrivialNeighborhoodSearch`.
235+
# For `GridNeighborhoodSearch`, a specialized function is used for slightly better
236+
# performance. `PrecomputedNeighborhoodSearch` can skip the distance check altogether.
237+
@inline function foreach_neighbor(f, neighbor_system_coords,
238+
neighborhood_search::AbstractNeighborhoodSearch,
239+
point, point_coords, search_radius)
212240
(; periodic_box) = neighborhood_search
213241

214-
point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)
215242
for neighbor in eachneighbor(point_coords, neighborhood_search)
243+
# Making the following `@inbounds` yields a ~2% speedup on an NVIDIA H100.
244+
# But we don't know if `neighbor` (extracted from the cell list) is in bounds.
216245
neighbor_coords = extract_svector(neighbor_system_coords,
217246
Val(ndims(neighborhood_search)), neighbor)
218247

src/nhs_grid.jl

+10-22
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ function GridNeighborhoodSearch{NDIMS}(; search_radius = 0.0, n_points = 0,
103103
cell_size, update_buffer, update_strategy)
104104
end
105105

106+
# The number of dimensions is the `NDIMS` type parameter of the search.
@inline function Base.ndims(::GridNeighborhoodSearch{NDIMS}) where {NDIMS}
    return NDIMS
end
107+
108+
# Tuple is `(x_changed, y_changed)`: no update is required when the coordinates in `x`
# change, but an update is required when the coordinates in `y` change.
@inline function requires_update(::GridNeighborhoodSearch)
    return (false, true)
end
109+
106110
"""
107111
ParallelUpdate()
108112
@@ -158,8 +162,6 @@ end
158162
push!(update_buffer, index_type(cell_list)[])
159163
end
160164

161-
@inline Base.ndims(::GridNeighborhoodSearch{NDIMS}) where {NDIMS} = NDIMS
162-
163165
function initialize!(neighborhood_search::GridNeighborhoodSearch,
164166
x::AbstractMatrix, y::AbstractMatrix)
165167
initialize_grid!(neighborhood_search, y)
@@ -355,24 +357,11 @@ function update_grid!(neighborhood_search::GridNeighborhoodSearch{<:Any, Paralle
355357
return neighborhood_search
356358
end
357359

358-
@propagate_inbounds function foreach_neighbor(f, system_coords, neighbor_system_coords,
359-
neighborhood_search::GridNeighborhoodSearch,
360-
point;
361-
search_radius = search_radius(neighborhood_search))
362-
# Due to https://github.com/JuliaLang/julia/issues/30411, we cannot just remove
363-
# a `@boundscheck` by calling this function with `@inbounds` because it has a kwarg.
364-
# We have to use `@propagate_inbounds`, which will also remove boundschecks
365-
# in the neighbor loop, which is not safe (see comment below).
366-
# To avoid this, we have to use a function barrier to disable the `@inbounds` again.
367-
point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)
368-
369-
__foreach_neighbor(f, system_coords, neighbor_system_coords, neighborhood_search,
370-
point, point_coords, search_radius)
371-
end
372-
373-
@inline function __foreach_neighbor(f, system_coords, neighbor_system_coords,
374-
neighborhood_search::GridNeighborhoodSearch,
375-
point, point_coords, search_radius)
360+
# Specialized version of the function in `neighborhood_search.jl`, which is faster
361+
# than looping over `eachneighbor`.
362+
@inline function foreach_neighbor(f, neighbor_system_coords,
363+
neighborhood_search::GridNeighborhoodSearch,
364+
point, point_coords, search_radius)
376365
(; periodic_box) = neighborhood_search
377366

378367
cell = cell_coords(point_coords, neighborhood_search)
@@ -393,8 +382,7 @@ end
393382
distance2 = dot(pos_diff, pos_diff)
394383

395384
pos_diff, distance2 = compute_periodic_distance(pos_diff, distance2,
396-
search_radius,
397-
periodic_box)
385+
search_radius, periodic_box)
398386

399387
if distance2 <= search_radius^2
400388
distance = sqrt(distance2)

src/nhs_precomputed.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ end
4545

4646
# The number of dimensions is the `NDIMS` type parameter of the search.
@inline function Base.ndims(::PrecomputedNeighborhoodSearch{NDIMS}) where {NDIMS}
    return NDIMS
end
4747

48+
# Tuple is `(x_changed, y_changed)`: an update is required when the coordinates in `x`
# change and also when the coordinates in `y` change.
@inline function requires_update(::PrecomputedNeighborhoodSearch)
    return (true, true)
end
49+
4850
@inline function search_radius(search::PrecomputedNeighborhoodSearch)
4951
return search_radius(search.neighborhood_search)
5052
end
@@ -92,14 +94,17 @@ function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y)
9294
end
9395
end
9496

95-
@inline function foreach_neighbor(f, system_coords, neighbor_system_coords,
97+
@inline function foreach_neighbor(f, neighbor_system_coords,
9698
neighborhood_search::PrecomputedNeighborhoodSearch,
97-
point; search_radius = nothing)
99+
point, point_coords, search_radius)
98100
(; periodic_box, neighbor_lists) = neighborhood_search
99-
(; search_radius) = neighborhood_search.neighborhood_search
100101

101-
point_coords = extract_svector(system_coords, Val(ndims(neighborhood_search)), point)
102-
for neighbor in neighbor_lists[point]
102+
neighbors = @inbounds neighbor_lists[point]
103+
for neighbor_ in eachindex(neighbors)
104+
neighbor = @inbounds neighbors[neighbor_]
105+
106+
# Making the following `@inbounds` yields a ~2% speedup on an NVIDIA H100.
107+
# But we don't know if `neighbor` (extracted from the neighbor list) is in bounds.
103108
neighbor_coords = extract_svector(neighbor_system_coords,
104109
Val(ndims(neighborhood_search)), neighbor)
105110

src/nhs_trivial.jl

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ end
3030

3131
# The number of dimensions is the `NDIMS` type parameter of the search.
@inline function Base.ndims(::TrivialNeighborhoodSearch{NDIMS}) where {NDIMS}
    return NDIMS
end
3232

33+
# Tuple is `(x_changed, y_changed)`: no update is required when the coordinates in `x`
# or in `y` change.
@inline function requires_update(::TrivialNeighborhoodSearch)
    return (false, false)
end
34+
3335
# Initialization is a no-op here: `x` and `y` are not used and `search` is returned as is.
@inline function initialize!(search::TrivialNeighborhoodSearch, x, y)
    return search
end
3436

3537
@inline function update!(search::TrivialNeighborhoodSearch, x, y;

0 commit comments

Comments
 (0)