Skip to content

Commit 92acc07

Browse files
authored
Fix A* implementation (JuliaGraphs#125)
* Fix A* with additional argument edgetype_to_return * Concrete key type in A* priority queue * Add back closed_set and type restriction heuristic::Function
1 parent 5608e49 commit 92acc07

File tree

5 files changed

+278
-26
lines changed

5 files changed

+278
-26
lines changed

Diff for: .gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ docs/site/
77
benchmark/.results/*
88
benchmark/.tune.jld
99
*.cov
10-
Manifest.toml
10+
/Manifest.toml

Diff for: docs/Manifest.toml

+238
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
julia_version = "1.7.2"
4+
manifest_format = "2.0"
5+
6+
[[deps.ArgTools]]
7+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
8+
9+
[[deps.ArnoldiMethod]]
10+
deps = ["LinearAlgebra", "Random", "StaticArrays"]
11+
git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae"
12+
uuid = "ec485272-7323-5ecc-a04f-4719b315124d"
13+
version = "0.2.0"
14+
15+
[[deps.Artifacts]]
16+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
17+
18+
[[deps.Base64]]
19+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
20+
21+
[[deps.Compat]]
22+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
23+
git-tree-sha1 = "b153278a25dd42c65abbf4e62344f9d22e59191b"
24+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
25+
version = "3.43.0"
26+
27+
[[deps.CompilerSupportLibraries_jll]]
28+
deps = ["Artifacts", "Libdl"]
29+
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
30+
31+
[[deps.DataStructures]]
32+
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
33+
git-tree-sha1 = "cc1a8e22627f33c789ab60b36a9132ac050bbf75"
34+
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
35+
version = "0.18.12"
36+
37+
[[deps.Dates]]
38+
deps = ["Printf"]
39+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
40+
41+
[[deps.DelimitedFiles]]
42+
deps = ["Mmap"]
43+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
44+
45+
[[deps.Distributed]]
46+
deps = ["Random", "Serialization", "Sockets"]
47+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
48+
49+
[[deps.DocStringExtensions]]
50+
deps = ["LibGit2"]
51+
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
52+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
53+
version = "0.8.6"
54+
55+
[[deps.Documenter]]
56+
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
57+
git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649"
58+
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
59+
version = "0.26.3"
60+
61+
[[deps.Downloads]]
62+
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
63+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
64+
65+
[[deps.Graphs]]
66+
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
67+
path = ".."
68+
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6"
69+
version = "1.6.0"
70+
71+
[[deps.IOCapture]]
72+
deps = ["Logging"]
73+
git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59"
74+
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
75+
version = "0.1.1"
76+
77+
[[deps.Inflate]]
78+
git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c"
79+
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
80+
version = "0.1.2"
81+
82+
[[deps.InteractiveUtils]]
83+
deps = ["Markdown"]
84+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
85+
86+
[[deps.JSON]]
87+
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
88+
git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
89+
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
90+
version = "0.21.3"
91+
92+
[[deps.LibCURL]]
93+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
94+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
95+
96+
[[deps.LibCURL_jll]]
97+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
98+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
99+
100+
[[deps.LibGit2]]
101+
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
102+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
103+
104+
[[deps.LibSSH2_jll]]
105+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
106+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
107+
108+
[[deps.Libdl]]
109+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
110+
111+
[[deps.LinearAlgebra]]
112+
deps = ["Libdl", "libblastrampoline_jll"]
113+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
114+
115+
[[deps.Logging]]
116+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
117+
118+
[[deps.MacroTools]]
119+
deps = ["Markdown", "Random"]
120+
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
121+
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
122+
version = "0.5.9"
123+
124+
[[deps.Markdown]]
125+
deps = ["Base64"]
126+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
127+
128+
[[deps.MbedTLS_jll]]
129+
deps = ["Artifacts", "Libdl"]
130+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
131+
132+
[[deps.Mmap]]
133+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
134+
135+
[[deps.MozillaCACerts_jll]]
136+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
137+
138+
[[deps.NetworkOptions]]
139+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
140+
141+
[[deps.OpenBLAS_jll]]
142+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
143+
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
144+
145+
[[deps.OrderedCollections]]
146+
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
147+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
148+
version = "1.4.1"
149+
150+
[[deps.Parsers]]
151+
deps = ["Dates"]
152+
git-tree-sha1 = "1285416549ccfcdf0c50d4997a94331e88d68413"
153+
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
154+
version = "2.3.1"
155+
156+
[[deps.Pkg]]
157+
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
158+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
159+
160+
[[deps.Printf]]
161+
deps = ["Unicode"]
162+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
163+
164+
[[deps.REPL]]
165+
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
166+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
167+
168+
[[deps.Random]]
169+
deps = ["SHA", "Serialization"]
170+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
171+
172+
[[deps.SHA]]
173+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
174+
175+
[[deps.Serialization]]
176+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
177+
178+
[[deps.SharedArrays]]
179+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
180+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
181+
182+
[[deps.SimpleTraits]]
183+
deps = ["InteractiveUtils", "MacroTools"]
184+
git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231"
185+
uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
186+
version = "0.9.4"
187+
188+
[[deps.Sockets]]
189+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
190+
191+
[[deps.SparseArrays]]
192+
deps = ["LinearAlgebra", "Random"]
193+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
194+
195+
[[deps.StaticArrays]]
196+
deps = ["LinearAlgebra", "Random", "Statistics"]
197+
git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911"
198+
uuid = "90137ffa-7385-5640-81b9-e52037218182"
199+
version = "1.4.4"
200+
201+
[[deps.Statistics]]
202+
deps = ["LinearAlgebra", "SparseArrays"]
203+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
204+
205+
[[deps.TOML]]
206+
deps = ["Dates"]
207+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
208+
209+
[[deps.Tar]]
210+
deps = ["ArgTools", "SHA"]
211+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
212+
213+
[[deps.Test]]
214+
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
215+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
216+
217+
[[deps.UUIDs]]
218+
deps = ["Random", "SHA"]
219+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
220+
221+
[[deps.Unicode]]
222+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
223+
224+
[[deps.Zlib_jll]]
225+
deps = ["Libdl"]
226+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
227+
228+
[[deps.libblastrampoline_jll]]
229+
deps = ["Artifacts", "Libdl", "OpenBLAS_jll"]
230+
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
231+
232+
[[deps.nghttp2_jll]]
233+
deps = ["Artifacts", "Libdl"]
234+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
235+
236+
[[deps.p7zip_jll]]
237+
deps = ["Artifacts", "Libdl"]
238+
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

Diff for: docs/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
34

45
[compat]
56
Documenter = "~0.26.2"

Diff for: src/shortestpaths/astar.jl

+31-25
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
function reconstruct_path!(total_path, # a vector to be filled with the shortest path
77
came_from, # a vector holding the parent of each node in the A* exploration
88
end_idx, # the end vertex
9-
g) # the graph
10-
11-
E = edgetype(g)
9+
g, # the graph
10+
edgetype_to_return::Type{E}=edgetype(g)) where {E<:AbstractEdge}
1211
curr_idx = end_idx
1312
while came_from[curr_idx] != curr_idx
14-
pushfirst!(total_path, E(came_from[curr_idx], curr_idx))
13+
pushfirst!(total_path, edgetype_to_return(came_from[curr_idx], curr_idx))
1514
curr_idx = came_from[curr_idx]
1615
end
1716
end
@@ -21,19 +20,17 @@ function a_star_impl!(g, # the graph
2120
open_set, # an initialized heap containing the active vertices
2221
closed_set, # an (initialized) color-map to indicate status of vertices
2322
g_score, # a vector holding g scores for each node
24-
f_score, # a vector holding f scores for each node
2523
came_from, # a vector holding the parent of each node in the A* exploration
2624
distmx,
27-
heuristic)
28-
29-
E = edgetype(g)
30-
total_path = Vector{E}()
25+
heuristic,
26+
edgetype_to_return::Type{E}) where {E<:AbstractEdge}
27+
total_path = Vector{edgetype_to_return}()
3128

3229
@inbounds while !isempty(open_set)
3330
current = dequeue!(open_set)
3431

3532
if current == goal
36-
reconstruct_path!(total_path, came_from, current, g)
33+
reconstruct_path!(total_path, came_from, current, g, edgetype_to_return)
3734
return total_path
3835
end
3936

@@ -56,38 +53,47 @@ function a_star_impl!(g, # the graph
5653
end
5754

5855
"""
59-
a_star(g, s, t[, distmx][, heuristic])
56+
a_star(g, s, t[, distmx][, heuristic][, edgetype_to_return])
57+
58+
Compute a shortest path using the [A* search algorithm](http://en.wikipedia.org/wiki/A%2A_search_algorithm).
6059
61-
Return a vector of edges comprising the shortest path between vertices `s` and `t`
62-
using the [A* search algorithm](http://en.wikipedia.org/wiki/A%2A_search_algorithm).
63-
An optional heuristic function and edge distance matrix may be supplied. If missing,
64-
the distance matrix is set to [`Graphs.DefaultDistance`](@ref) and the heuristic is set to
65-
`n -> 0`.
60+
# Arguments
61+
- `g::AbstractGraph`: the graph
62+
- `s::Integer`: the source vertex
63+
- `t::Integer`: the target vertex
64+
- `distmx::AbstractMatrix`: an optional (possibly sparse) `n × n` matrix of edge weights. It is set to `weights(g)` by default (which itself falls back on [`Graphs.DefaultDistance`](@ref)).
65+
- `heuristic::Function`: an optional function mapping each vertex to a lower estimate of the remaining distance from `v` to `t`. It is set to `v -> 0` by default (which corresponds to Dijkstra's algorithm)
66+
- `edgetype_to_return::Type{E}`: the eltype `E<:AbstractEdge` of the vector of edges returned. It is set to `edgetype(g)` by default. Note that the two-argument constructor `E(u, v)` must be defined, even for weighted edges: if it isn't, consider using `E = Graphs.SimpleEdge`.
6667
"""
6768
function a_star(g::AbstractGraph{U}, # the g
6869
s::Integer, # the start vertex
6970
t::Integer, # the end vertex
7071
distmx::AbstractMatrix{T}=weights(g),
71-
heuristic::Function=n -> zero(T)) where {T, U}
72-
73-
E = Edge{eltype(g)}
74-
72+
heuristic::Function=n -> zero(T),
73+
edgetype_to_return::Type{E}=edgetype(g)) where {T, U, E<:AbstractEdge}
7574
# if we do checkbounds here, we can use @inbounds in a_star_impl!
7675
checkbounds(distmx, Base.OneTo(nv(g)), Base.OneTo(nv(g)))
7776

78-
open_set = PriorityQueue{Integer, T}()
77+
open_set = PriorityQueue{U, T}()
7978
enqueue!(open_set, s, 0)
8079

8180
closed_set = zeros(Bool, nv(g))
8281

8382
g_score = fill(Inf, nv(g))
8483
g_score[s] = 0
8584

86-
f_score = fill(Inf, nv(g))
87-
f_score[s] = heuristic(s)
88-
8985
came_from = fill(-one(s), nv(g))
9086
came_from[s] = s
9187

92-
a_star_impl!(g, t, open_set, closed_set, g_score, f_score, came_from, distmx, heuristic)
88+
a_star_impl!(
89+
g,
90+
t,
91+
open_set,
92+
closed_set,
93+
g_score,
94+
came_from,
95+
distmx,
96+
heuristic,
97+
edgetype_to_return
98+
)
9399
end

Diff for: test/shortestpaths/astar.jl

+7
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,11 @@
1515
g = complete_graph(4)
1616
w = float([1 1 1 4; 1 1 1 1; 1 1 1 1; 4 1 1 1])
1717
@test length(a_star(g, 1, 4, w)) == 2
18+
19+
# test for #120
20+
struct MyFavoriteEdgeType <: AbstractEdge{Int}
21+
s::Int
22+
d::Int
23+
end
24+
@test eltype(a_star(g, 1, 4, w, n -> 0, MyFavoriteEdgeType)) == MyFavoriteEdgeType
1825
end

0 commit comments

Comments
 (0)