Skip to content

Commit c0bfc26

Browse files
authored
Cleanup & remove some deps (#1560)
1 parent 2ec59a9 commit c0bfc26

File tree

10 files changed

+33
-74
lines changed

10 files changed

+33
-74
lines changed

Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1010
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1212
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
13-
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1413
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1514
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
1615
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -20,23 +19,25 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2019
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
2120
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2221
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
23-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2422
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2523
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2624
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2725
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2826

2927
[weakdeps]
28+
Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1"
3029
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
3130
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
3231
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3332

3433
[extensions]
34+
ZygoteAtomExt = "Atom"
3535
ZygoteColorsExt = "Colors"
3636
ZygoteDistancesExt = "Distances"
3737
ZygoteTrackerExt = "Tracker"
3838

3939
[compat]
40+
Atom = "0.12"
4041
AbstractFFTs = "1.3.1"
4142
ChainRules = "1.72.2"
4243
ChainRulesCore = "1.25.1"
@@ -45,14 +46,12 @@ DiffRules = "1.4"
4546
Distances = "0.10"
4647
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
4748
ForwardDiff = "0.10"
48-
GPUArrays = "8.4.2, 9, 10, 11"
4949
GPUArraysCore = "0.1.1, 0.2"
5050
IRTools = "0.4.12"
5151
LogExpFunctions = "0.3.1"
5252
MacroTools = "0.5"
5353
NaNMath = "0.3, 1"
5454
PrecompileTools = "1"
55-
Requires = "1.1"
5655
SpecialFunctions = "1.6, 2"
5756
Statistics = "1"
5857
Tracker = "0.2"

ext/ZygoteAtomExt.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module ZygoteAtomExt
2+
3+
using Atom
4+
using Zygote
5+
using Zygote.Profile
6+
7+
Zygote.Profile.atom_expandpath(path::String) = Atom.expandpath(path)
8+
Zygote.Profile.juno(node::Zygote.Profile.Node) = Atom.msg("profile", Zygote.Profile.tojson(node))
9+
10+
end

ext/ZygoteColorsExt.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
module ZygoteColorsExt
22

3-
if isdefined(Base, :get_extension)
4-
using Zygote
5-
using Colors
6-
else
7-
using ..Zygote
8-
using ..Colors
9-
end
3+
using Zygote
4+
using Colors
105

116
Zygote.@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
127

ext/ZygoteDistancesExt.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
module ZygoteDistancesExt
22

3-
if isdefined(Base, :get_extension)
4-
using Zygote
5-
using Distances
6-
using LinearAlgebra
7-
else
8-
using ..Zygote
9-
using ..Distances
10-
using ..LinearAlgebra
11-
end
3+
using Zygote
4+
using Distances
5+
using LinearAlgebra
126

137
using Zygote: @adjoint, AContext, _pullback
148

ext/ZygoteTrackerExt.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
module ZygoteTrackerExt
22

3-
if isdefined(Base, :get_extension)
4-
using Zygote
5-
using Tracker: Tracker, TrackedArray, TrackedReal
6-
else
7-
using ..Zygote
8-
using ..Tracker: Tracker, TrackedArray, TrackedReal
9-
end
3+
using Zygote
4+
using Tracker: Tracker, TrackedArray, TrackedReal
105

116
Zygote.unwrap(x::Union{TrackedArray,TrackedReal}) = Tracker.data(x)
127

src/Zygote.jl

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
1010
using ChainRulesCore
1111
using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize
1212
using IRTools
13-
using MacroTools, Requires
13+
using MacroTools
1414
using MacroTools: @forward
1515

1616
import Distributed: pmap, CachingPool, workers
@@ -53,22 +53,7 @@ include("compiler/interface2.jl")
5353

5454
include("profiler/Profile.jl")
5555

56-
57-
if !isdefined(Base, :get_extension)
58-
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("../ext/ZygoteDistancesExt.jl")
59-
@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/ZygoteTrackerExt.jl")
60-
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" include("../ext/ZygoteColorsExt.jl")
61-
end
62-
6356
using InteractiveUtils
64-
precompile() = Requires.@include("precompile.jl")
65-
66-
# helps to work around 265-y issues
67-
function refresh()
68-
Requires.@include("compiler/interface2.jl")
69-
precompile()
70-
return
71-
end
7257

7358
macro profile(ex)
7459
@capture(ex, f_(x__)) || error("@profile f(args...)")
@@ -79,6 +64,8 @@ macro profile(ex)
7964
end
8065

8166
using PrecompileTools
82-
@compile_workload precompile()
67+
@compile_workload begin
68+
include("precompile.jl")
69+
end
8370

84-
end # module
71+
end

src/compiler/interface2.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ function _generate_callable_pullback(j::Type{<:Pullback{T}}, world, Δ) where T
6262
return update!(meta.code, back)
6363
end
6464

65-
if VERSION >= v"1.10.0-DEV.873"
66-
6765
# on Julia 1.10, generated functions need to keep track of the world age
6866

6967
function _pullback_generator(world::UInt, source, self, ctx, f, args)
@@ -103,15 +101,3 @@ end
103101
$(Expr(:meta, :generated, _callable_pullback_generator))
104102
$(Expr(:meta, :generated_only))
105103
end
106-
107-
else
108-
109-
@generated function _pullback(ctx::AContext, f, args...)
110-
_generate_pullback(ctx, nothing, f, args...)
111-
end
112-
113-
@generated function (j::Pullback)(Δ)
114-
_generate_callable_pullback(j, nothing, Δ)
115-
end
116-
117-
end

src/profiler/Profile.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module Profile
22

3-
using Requires
43
using ..Zygote: Pullback, meta, stacklines
54

65
function loc(f)
@@ -81,17 +80,18 @@ function profile(x)
8180
Node(Symbol(""), "", -1, cs)
8281
end
8382

84-
@init @require Atom="c52e3926-4ff0-5f6e-af25-54175e0327b1" begin
85-
function tojson(n::Node)
86-
name, path = Atom.expandpath(string(n.file))
83+
# Defined in the ZygoteAtomExt.
84+
function atom_expandpath end
85+
function juno end
86+
87+
function tojson(n::Node)
88+
name, path = atom_expandpath(string(n.file))
8789
Dict(:path => path,
8890
:location => name,
8991
:func => string(n.func),
9092
:line => n.line,
9193
:count => n.size,
9294
:children => map(tojson, n.children))
93-
end
94-
juno(n::Node) = Atom.msg("profile", tojson(n))
9595
end
9696

9797
end

test/chainrules_tests.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,8 @@ end
405405

406406
# To test the generic case, we need a struct within a struct.
407407
nested = Tangent{Base.RefValue{ComplexF64}}(; x=Tangent{ComplexF64}(; re=1, im=NoTangent()),)
408-
if VERSION > v"1.7-"
409-
@test @inferred(Zygote.z2d((; x=(; re=1)), Ref(3.0+im))) == nested
410-
@test @inferred(Zygote.z2d((; x=(; re=nothing)), Ref(3.0+im))) === NoTangent()
411-
else
412-
@test Zygote.z2d((; x=(; re=1)), Ref(3.0+im)) == nested
413-
@test Zygote.z2d((; x=(; re=nothing)), Ref(3.0+im)) === NoTangent()
414-
end
408+
@test @inferred(Zygote.z2d((; x=(; re=1)), Ref(3.0+im))) == nested
409+
@test @inferred(Zygote.z2d((; x=(; re=nothing)), Ref(3.0+im))) === NoTangent()
415410

416411
x = (c = (a = randn(3,3), b = rand(3)), d = randn(5))
417412
z2d_compiled = Zygote.z2d(x, x)

test/compiler_tests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ bt = try back(1) catch e stacktrace(catch_backtrace()) end
4040

4141
# Type inference checks
4242

43-
Zygote.refresh()
44-
4543
y, back = @test_inferred pullback(*, 2, 3)
4644
@test_inferred(back(1))
4745

0 commit comments

Comments
 (0)