Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test with GPUArrays #71

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ Zygote = "0.6.40"
julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["CUDA", "JLArrays", "StaticArrays", "Test", "Zygote"]
102 changes: 102 additions & 0 deletions test/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using Optimisers
using ChainRulesCore, Zygote
using Random: shuffle!
using Test

# Select the GPU array backend for this file: real CUDA when a functional GPU
# is present, otherwise JLArrays' CPU reference implementation of GPUArrays,
# so the same tests exercise the generic GPU code path on any machine.
import CUDA
if CUDA.functional()
using CUDA # exports CuArray, cu, etc.
CUDA.allowscalar(false)  # turn accidental scalar indexing into an error
else
@info "CUDA not functional, testing with JLArrays instead"
using JLArrays
JLArrays.allowscalar(false)  # mimic CUDA's scalar-indexing ban

# Alias the CUDA names so every test below runs unchanged on either backend.
cu = jl
CuArray{T,N} = JLArray{T,N}
end
# Smoke test: `cu` produces the chosen device array type, and broadcasting on
# it works without falling back to scalar indexing.
@test cu(rand(3)) .+ 1 isa CuArray

@testset "very basics" begin
# One-leaf model on the device; keep the leaf's objectid to check in-place-ness.
m = (cu([1.0, 2.0]),)
mid = objectid(m[1])
g = (cu([25, 33]),)
o = Descent(0.1f0)
s = Optimisers.setup(o, m)

# Non-mutating update: returns new arrays, leaves `m` untouched.
s2, m2 = Optimisers.update(s, m, g)
@test Array(m[1]) == 1:2 # not mutated
@test m2[1] isa CuArray  # result stays on the device
@test Array(m2[1]) ≈ [1,2] .- 0.1 .* [25, 33] atol=1e-6

# Mutating update!: same numbers, but writes into the original device array.
s3, m3 = Optimisers.update!(s, m, g)
@test objectid(m3[1]) == mid  # identical array object, i.e. updated in place
@test Array(m3[1]) ≈ [1,2] .- 0.1 .* [25, 33] atol=1e-6

# Gradients may also arrive as a ChainRulesCore.Tangent; same result expected.
g4 = Tangent{typeof(m)}(g...)
s4, m4 = Optimisers.update!(s, (cu([1.0, 2.0]),), g4)
@test Array(m4[1]) ≈ [1,2] .- 0.1 .* [25, 33] atol=1e-6
end

@testset "basic mixed" begin
# Mixed trees work trivially: each leaf's optimiser state is allocated
# wherever that leaf's array lives, and non-array leaves get no state.
model = (device = cu([1.0, 2.0]), host = [3.0, 4.0], neither = (5, 6, sin))
state = Optimisers.setup(ADAM(0.1), model)
@test state.device.state[1] isa CuArray
@test state.host.state[1] isa Array

grads = (device = cu([1, 0.1]), host = [1, 10], neither = nothing)
state2, model2 = Optimisers.update(state, model, grads)

# Each updated leaf stays where it started, with the expected values.
@test model2.device isa CuArray
@test isapprox(Array(model2.device), [0.9, 1.9]; atol = 1e-6)

@test model2.host isa Array
@test isapprox(model2.host, [2.9, 3.9])
end

# A representative sample of rules to run on the device: a few plain
# optimisers, plus a few chained combinations.
RULES = [
Descent(),
ADAM(),
RMSProp(),
NADAM(),
OptimiserChain(WeightDecay(), ADAM(0.001)),
OptimiserChain(ClipNorm(), ADAM(0.001)),
OptimiserChain(ClipGrad(0.5), Momentum()),
]

# Short labels for the testset headings below.
name(o) = nameof(typeof(o))
name(o::OptimiserChain) = join(map(name, o.opts), " → ")

@testset "rules: simple sum" begin
@testset "$(name(o))" for o in RULES
# Device matrix with entries 1..64 in random order; `shuffle!` is from the
# Random stdlib. Minimising sum(abs2, x + x') should drive entries toward 0.
m = cu(shuffle!(reshape(1:64, 8, 8) .+ 0.0))
s = Optimisers.setup(o, m)
for _ in 1:10
g = Zygote.gradient(x -> sum(abs2, x + x'), m)[1]
s, m = Optimisers.update!(s, m, g)  # re-bind: state and model are threaded through
end
# Weak but rule-agnostic check: ten steps of any rule should reduce the sum.
@test sum(m) < sum(1:64)
end
end

@testset "destructure GPU" begin
# Flattening a tree of device arrays should yield a device vector, and the
# reconstructor should restore device arrays from one.
model = (x = cu(Float32[1,2,3]), y = (0, 99), z = cu(Float32[4,5]))
flat, rebuild = destructure(model)
@test flat isa CuArray
@test rebuild(2flat).x isa CuArray

# Gradients through destructure, in both directions, stay on the device too.
grad_model = gradient(m -> sum(abs2, destructure(m)[1]), model)[1]
@test grad_model.z isa CuArray
grad_flat = gradient(v -> sum(abs2, rebuild(v).z), cu([10, 20, 30, 40, 50.0]))[1]
@test grad_flat isa CuArray
end

@testset "destructure mixed" begin
# Trees mixing device and host arrays: the desired behaviour is undecided,
# so these record the status quo rather than a contract.
m_c1 = (x = cu(Float32[1,2,3]), y = Float32[4,5])
v, re = destructure(m_c1)
@test re(2v).x isa CuArray
# Known issue: reconstruction currently does not keep `y` on the host.
@test_broken re(2v).y isa Array

# Host-first ordering currently errors outright, so it is skipped entirely.
m_c2 = (x = Float32[1,2,3], y = cu(Float32[4,5]))
@test_skip destructure(m_c2) # ERROR: Scalar indexing
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@testset verbose=true "Optimisation Rules" begin
include("rules.jl")
end
@testset verbose=true "GPU" begin
include("gpuarrays.jl")  # runs on CUDA when functional, otherwise JLArrays
end
end