Skip to content

Commit 2ec59a9

Browse files
authored
Distribute tests across workers
2 parents 1502341 + ad59c8d commit 2ec59a9

25 files changed

+2515
-2430
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
test:
1414
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
1515
runs-on: ${{ matrix.os }}
16+
continue-on-error: ${{ matrix.version == 'nightly' }} # allow nightly to fail
1617
strategy:
1718
fail-fast: false
1819
matrix:
@@ -41,12 +42,9 @@ jobs:
4142
- uses: julia-actions/julia-buildpkg@v1
4243
env:
4344
JULIA_PKG_SERVER: ""
44-
# `allow-failure` not available yet https://github.com/actions/toolkit/issues/399
45-
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
4645
- uses: julia-actions/julia-runtest@v1
4746
env:
4847
JULIA_PKG_SERVER: ""
49-
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
5048
- uses: julia-actions/julia-processcoverage@v1
5149
if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
5250
- uses: codecov/codecov-action@v3

Project.toml

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ ZygoteTrackerExt = "Tracker"
4040
AbstractFFTs = "1.3.1"
4141
ChainRules = "1.72.2"
4242
ChainRulesCore = "1.25.1"
43-
ChainRulesTestUtils = "1"
4443
Colors = "0.12, 0.13"
4544
DiffRules = "1.4"
4645
Distances = "0.10"
@@ -59,18 +58,3 @@ Statistics = "1"
5958
Tracker = "0.2"
6059
ZygoteRules = "0.2.7"
6160
julia = "1.10"
62-
63-
[extras]
64-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
65-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
66-
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
67-
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
68-
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
69-
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
70-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
71-
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
72-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
73-
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
74-
75-
[targets]
76-
test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PythonCall", "Test"]

test/Project.toml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
[deps]
2+
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
3+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4+
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
5+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
10+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
11+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
12+
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
13+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
14+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
15+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
17+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
18+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
19+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
20+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
21+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
22+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"

test/chainrules.jl renamed to test/chainrules_tests.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
using ChainRulesCore, ChainRulesTestUtils, Zygote
1+
@testitem "chainrules" begin
2+
3+
using ChainRulesCore
4+
using ChainRulesTestUtils
25
using Zygote: ZygoteRuleConfig
6+
using LinearAlgebra
37

48
@testset "ChainRules integration" begin
59
@testset "ChainRules basics" begin
@@ -64,7 +68,7 @@ using Zygote: ZygoteRuleConfig
6468
end
6569
return simo(x), simo_pullback
6670
end
67-
71+
6872
simo_outer(x) = sum(simo(x))
6973

7074
simo_rrule_hitcount[] = 0
@@ -86,7 +90,7 @@ using Zygote: ZygoteRuleConfig
8690
end
8791
return miso(a, b), miso_pullback
8892
end
89-
93+
9094

9195
miso_outer(x) = miso(100x, 10x)
9296

@@ -182,7 +186,7 @@ using Zygote: ZygoteRuleConfig
182186
end
183187
return kwfoo(x; k=k), kwfoo_pullback
184188
end
185-
189+
186190

187191
kwfoo_outer_unused(x) = kwfoo(x)
188192
kwfoo_outer_used(x) = kwfoo(x; k=-15)
@@ -207,7 +211,7 @@ using Zygote: ZygoteRuleConfig
207211
end
208212
return not_diff_kw_eg(x, i; kwargs...), not_diff_kw_eg_pullback
209213
end
210-
214+
211215

212216
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4)
213217
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4)
@@ -218,7 +222,7 @@ using Zygote: ZygoteRuleConfig
218222
x::T
219223
end
220224
StructForTestingTypeOnlyRRules() = StructForTestingTypeOnlyRRules(1.0)
221-
225+
222226
function ChainRulesCore.rrule(P::Type{<:StructForTestingTypeOnlyRRules})
223227
# notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes
224228
# and also because apparently people actually want to do this. Weird, but 🤷
@@ -253,7 +257,7 @@ using Zygote: ZygoteRuleConfig
253257
@test ([1.0],) == Zygote.gradient(oout_id_outer, [π])
254258
@test oout_id_rrule_hitcount[] == 0
255259

256-
# Now try opting out After we have already used it
260+
# Now try opting out After we have already used it
257261
@opt_out ChainRulesCore.rrule(::typeof(oout_id), x::Real)
258262
oout_id_rrule_hitcount[] = 0
259263
@test (1.0,) == Zygote.gradient(oout_id_outer, π)
@@ -399,7 +403,7 @@ end
399403
@test @inferred(Zygote.z2d((nothing,), (1,))) === NoTangent()
400404
@test @inferred(Zygote.z2d((nothing, nothing), (1,2))) === NoTangent()
401405

402-
# To test the generic case, we need a struct within a struct.
406+
# To test the generic case, we need a struct within a struct.
403407
nested = Tangent{Base.RefValue{ComplexF64}}(; x=Tangent{ComplexF64}(; re=1, im=NoTangent()),)
404408
if VERSION > v"1.7-"
405409
@test @inferred(Zygote.z2d((; x=(; re=1)), Ref(3.0+im))) == nested
@@ -456,3 +460,5 @@ end
456460
@test g[1] isa NamedTuple
457461
@test g[1].w isa Array
458462
end
463+
464+
end

test/compiler.jl renamed to test/compiler_tests.jl

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using Zygote, Test
1+
@testitem "compiler" begin
2+
3+
using LinearAlgebra
24
using Zygote: pullback, @adjoint, Context
35

46
macro test_inferred(ex)
@@ -11,9 +13,10 @@ macro test_inferred(ex)
1113
end) |> esc
1214
end
1315

14-
trace_contains(st, func, file, line) = any(st) do fr
15-
func in (nothing, fr.func) && endswith(String(fr.file), file) &&
16-
fr.line == line
16+
function trace_contains(st, func, file, line)
17+
any(st) do fr
18+
func in (nothing, fr.func) && endswith(String(fr.file), file) && fr.line == line
19+
end
1720
end
1821

1922
bad(x) = x
@@ -32,8 +35,8 @@ y, back = pullback(badly, 2)
3235
@test_throws Exception back(1)
3336
bt = try back(1) catch e stacktrace(catch_backtrace()) end
3437

35-
@test trace_contains(bt, nothing, "compiler.jl", bad_def_line)
36-
@test trace_contains(bt, :badly, "compiler.jl", bad_call_line)
38+
@test trace_contains(bt, nothing, "compiler_tests.jl", bad_def_line)
39+
@test trace_contains(bt, :badly, "compiler_tests.jl", bad_call_line)
3740

3841
# Type inference checks
3942

@@ -277,20 +280,20 @@ function try_catch_finally(cond, x)
277280
x
278281
end
279282

280-
function try_catch_else(cond, x)
281-
x = 2x
283+
function try_catch_else(cond, x)
284+
x = 2x
282285

283-
try
284-
x = 2x
285-
cond && throw(nothing)
286-
catch
287-
x = 3x
288-
else
289-
x = 2x
290-
end
286+
try
287+
x = 2x
288+
cond && throw(nothing)
289+
catch
290+
x = 3x
291+
else
292+
x = 2x
293+
end
291294

292-
x
293-
end
295+
x
296+
end
294297

295298
@testset "try/catch" begin
296299
@testset "happy path (nothrow)" begin
@@ -337,3 +340,5 @@ end
337340
err = try pull(1.) catch ex; ex end
338341
@test occursin("Can't differentiate function execution in catch block", string(err))
339342
end
343+
344+
end

test/complex.jl renamed to test/complex_tests.jl

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
1-
using Zygote, Test, LinearAlgebra
1+
@testitem "complex" begin
22

3-
@testset "basic" begin
4-
5-
@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] 1
6-
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] 0
7-
@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] -1im
8-
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] 0 # projected to zero
9-
@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] 1im
10-
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] 0
3+
using LinearAlgebra
114

12-
@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
13-
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
14-
@test gradient(a -> real(([a].*conj([a])))[], 0.3im)[1] == 0.6im
15-
@test gradient(a -> real(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
16-
@test gradient(a -> real.(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
5+
@testset "basic" begin
6+
@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] 1
7+
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] 0
8+
@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] -1im
9+
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] 0 # projected to zero
10+
@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] 1im
11+
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] 0
1712

18-
@test gradient(x -> norm((im*x) ./ (im)), 2)[1] == 1
19-
@test gradient(x -> norm((im) ./ (im*x)), 2)[1] == -1/4
20-
@test gradient(x -> real(det(x)), [1 2im; 3im 4])[1] [4 3im; 2im 1]
21-
@test gradient(x -> real(logdet(x)), [1 2im; 3im 4])[1] [4 3im; 2im 1]/10
22-
@test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] [4 3im; 2im 1]/10
13+
@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
14+
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
15+
@test gradient(a -> real(([a].*conj([a])))[], 0.3im)[1] == 0.6im
16+
@test gradient(a -> real(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
17+
@test gradient(a -> real.(([a].*conj.([a])))[], 0.3im)[1] == 0.6im
2318

24-
# https://github.com/FluxML/Zygote.jl/issues/705
25-
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] real(im .* exp.(1:3))
26-
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] im .* exp.(1:3)
19+
@test gradient(x -> norm((im*x) ./ (im)), 2)[1] == 1
20+
@test gradient(x -> norm((im) ./ (im*x)), 2)[1] == -1/4
21+
@test gradient(x -> real(det(x)), [1 2im; 3im 4])[1] [4 3im; 2im 1]
22+
@test gradient(x -> real(logdet(x)), [1 2im; 3im 4])[1] [4 3im; 2im 1]/10
23+
@test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] [4 3im; 2im 1]/10
2724

28-
end # @testset
25+
# https://github.com/FluxML/Zygote.jl/issues/705
26+
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] real(im .* exp.(1:3))
27+
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] im .* exp.(1:3)
28+
end
2929

3030
fs_C_to_R = (real,
3131
imag,
@@ -120,3 +120,5 @@ end
120120
end
121121
@test Zygote.hessian(fun, collect(1:9)) [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
122122
end
123+
124+
end

test/cuda.jl renamed to test/cuda_tests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
@testitem "cuda" begin
2+
3+
using LinearAlgebra
14
using CUDA
25
using Zygote: Grads
36
using Random: randn!
47
import FiniteDifferences
8+
59
CUDA.allowscalar(false)
610

711
function gradcheck_gpu(f, xs...)
@@ -11,7 +15,6 @@ function gradcheck_gpu(f, xs...)
1115
return all(isapprox.(collect.(grad_zygote), grad_finite_difference))
1216
end
1317

14-
1518
# Test GPU movement inside the call to `gradient`
1619
@testset "GPU movement" begin
1720
r = rand(Float32, 3,3)
@@ -195,3 +198,5 @@ end
195198

196199

197200
end
201+
202+
end

test/deprecated.jl renamed to test/deprecated_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@testitem "deprecated" begin
2+
13
@test_deprecated dropgrad(1)
24
@test_deprecated ignore(1)
35
@test_deprecated Zygote.@ignore x=1
@@ -8,3 +10,5 @@
810
y = Zygote.@ignore x
911
x * y
1012
end == (1,)
13+
14+
end

test/features.jl renamed to test/features_tests.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@testitem "features" begin
2+
13
using Zygote, Test, LinearAlgebra
24
using Zygote: Params, gradient, forwarddiff
35
using FillArrays: Fill
@@ -397,7 +399,7 @@ global_r = 1
397399
@test back(1) == (nothing, 3)
398400
ref = first(keys(Zygote.cache(cx)))
399401
@test ref isa GlobalRef
400-
@test ref.mod == Main
402+
@test_broken ref.mod == Main # TODO module is now "features###"
401403
@test ref.name == :global_param
402404
@test Zygote.cache(cx)[ref] == 2
403405

@@ -409,7 +411,7 @@ global_r = 1
409411
end
410412
return global_r
411413
end
412-
414+
413415
@test gradient(pow_global, 2, 3) == (12, nothing)
414416
end
415417

@@ -694,14 +696,6 @@ end
694696
end == ([8 112; 36 2004],)
695697
end
696698

697-
@testset "PythonCall custom @adjoint" begin
698-
using PythonCall: pyimport, pyconvert
699-
math = pyimport("math")
700-
pysin(x) = math.sin(x)
701-
Zygote.@adjoint pysin(x) = pyconvert(Float64, math.sin(x)), δ -> (pyconvert(Float64, δ * math.cos(x)),)
702-
@test Zygote.gradient(pysin, 1.5) == Zygote.gradient(sin, 1.5)
703-
end
704-
705699
# https://github.com/JuliaDiff/ChainRules.jl/issues/257
706700
@testset "Keyword Argument Passing" begin
707701
struct Type1{VJP}
@@ -717,9 +711,8 @@ end
717711
end
718712

719713
i = 1
720-
global x = Any[nothing,nothing]
721-
722-
Zygote.@nograd g(x,i,sensealg) = Main.x[i] = sensealg
714+
x = Any[nothing,nothing]
715+
Zygote.@nograd g(x,i,sensealg) = x[i] = sensealg
723716
function f(;sensealg=nothing)
724717
g(x,i,sensealg)
725718
return rand(100)
@@ -889,3 +882,4 @@ end
889882
@test g1[1] g2[1] g3[1]
890883
end
891884

885+
end

test/forward/forward.jl renamed to test/forward_tests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using Zygote, Test
1+
@testitem "forward" begin
2+
3+
using LinearAlgebra
24

35
D(f, x) = pushforward(f, x)(1)
46

@@ -46,3 +48,5 @@ end == 0
4648
mul!(B, A, A)
4749
sum(B)
4850
end == 6
51+
52+
end

0 commit comments

Comments
 (0)