Skip to content

Commit

Permalink
use JLArrays, add gradient test
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 24, 2022
1 parent 9e8a0b4 commit f69a180
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ Zygote = "0.6.40"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Adapt", "CUDA", "GPUArrays", "StaticArrays", "Test", "Zygote"]
test = ["CUDA", "JLArrays", "StaticArrays", "Test", "Zygote"]
24 changes: 11 additions & 13 deletions test/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
using Optimisers
using ChainRulesCore #, Functors, StaticArrays, Zygote
using LinearAlgebra, Statistics, Test
using ChainRulesCore, Zygote
using Test

import CUDA
if CUDA.functional()
using CUDA # exports CuArray, etc
@info "starting CUDA tests"
else
@info "CUDA not functional, testing via GPUArrays"
using GPUArrays
GPUArrays.allowscalar(false)
@info "CUDA not functional, testing with JLArrays instead"
using JLArrays
JLArrays.allowscalar(false)

# GPUArrays provides a fake GPU array, for testing
jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
using Random, Adapt # loaded within jl_file
include(jl_file)
using .JLArrays
cu = jl
cu = jl32
CuArray{T,N} = JLArray{T,N}
end

@test cu(rand(3)) .+ 1 isa CuArray

@testset "very basics" begin
Expand Down Expand Up @@ -89,6 +82,11 @@ end
v, re = destructure(m)
@test v isa CuArray
@test re(2v).x isa CuArray

dm = gradient(m -> sum(abs2, destructure(m)[1]), m)[1]
@test dm.z isa CuArray
dv = gradient(v -> sum(abs2, re(v).z), cu([10, 20, 30, 40, 50.0]))[1]
@test dv isa CuArray
end

@testset "destructure mixed" begin
Expand Down

0 comments on commit f69a180

Please sign in to comment.