Skip to content

Commit 094dca6

Browse files
authored
NNlib on the GPU (#435)
* nnlib tweaks * Proper cuda tests * Depend on GPUArraysCore * Tweak NNlib implementation * Extra methods in CUDA * Make tuple_map always generated * Make tuple_map always generated * Remove redundant code * Tidy up build_tangent implementation * Add failing test cases * Tweak CUDA test file * Revert stuff added in a different PR * Tweaks to CUDA tests * Fix zero_rdata_from_type * Fix formatting * Make nnlib tests pass * Renable CuArray construction test * Make buildkite actually run nnlib tests * Fix display * Formatting * CI formatting * Revert tuple_map implementation change * Revert changes to tuple_map entirely * Bump patch version
1 parent 35b432c commit 094dca6

File tree

8 files changed

+151
-99
lines changed

8 files changed

+151
-99
lines changed

.buildkite/pipeline.yml

+10-5
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ env:
22
SECRET_CODECOV_TOKEN: "nkcRFVXdaPNAbiI0x3qK/XUG8rWjBc8fU73YEyP35SeS465XORqrIYrHUbHuJTRyeyqNRdsHaBcV1P7TBbKAaTQAjHQ1Q0KYfd0uRMSWpZSCgTBz5AwttAxVfFrX+Ky3PzTi2TfDe0uPFZtFo0Asq6sUEr1on+Oo+j+q6br2NK6CrA5yKKuTX4Q2V/UPOIK4vNXY3+zDTKSNtr+HQOlcVEeRIk/0ZQ78Cjd52flEaVw8GWo/CC4YBzLtcOZgaFdgOTEDNHMr0mw6zLE4Y6nxq4lHVSoraSjxjhkB0pXTZ1c51yHX8Jc+q6HC5s87+2Zq5YtsuQSGao+eMtkTAYwfLw==;U2FsdGVkX18z27J3+gNgxsPNnXA0ad4LvZnXeohTam7/6UPqX5+3BYI0tAiVkCho4vlJyL7dd8JEyNtk9BFXsg=="
33

44
steps:
5-
- label: "Julia v{{matrix}}"
5+
- label: "Julia v{{matrix.version}}, {{matrix.label}}"
66
plugins:
77
- JuliaCI/julia#v1:
8-
version: "{{matrix}}"
8+
version: "{{matrix.version}}"
99
- JuliaCI/julia-coverage#v1:
1010
dirs:
1111
- src
@@ -17,8 +17,13 @@ steps:
1717
if: build.message !~ /\[skip tests\]/
1818
timeout_in_minutes: 60
1919
env:
20-
LABEL: cuda
20+
LABEL: "{{matrix.label}}"
2121
TEST_TYPE: ext
2222
matrix:
23-
- "1"
24-
- "1.10"
23+
setup:
24+
version:
25+
- "1"
26+
- "1.10"
27+
label:
28+
- "cuda"
29+
- "nnlib"

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.4.76"
4+
version = "0.4.77"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -11,6 +11,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1111
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
1212
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
1313
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
14+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1415
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1516
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -49,6 +50,7 @@ DiffRules = "1"
4950
DiffTests = "0.1"
5051
ExprTools = "0.1"
5152
FunctionWrappers = "1.1.3"
53+
GPUArraysCore = "0.1"
5254
Graphs = "1"
5355
InteractiveUtils = "1"
5456
JET = "0.9"

ext/MooncakeCUDAExt.jl

+17-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Mooncake:
1010
rrule!!,
1111
@is_primitive,
1212
tangent_type,
13+
tangent,
1314
zero_tangent,
1415
randn_tangent,
1516
increment!!,
@@ -20,7 +21,9 @@ import Mooncake:
2021
_scale,
2122
TestUtils,
2223
CoDual,
23-
NoPullback
24+
NoPullback,
25+
to_cr_tangent,
26+
increment_and_get_rdata!
2427

2528
import Mooncake.TestUtils: populate_address_map!, AddressMap, __increment_should_allocate
2629

@@ -31,7 +34,11 @@ zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x)
3134
function randn_tangent(rng::AbstractRNG, x::CuArray{Float32})
3235
return cu(randn(rng, Float32, size(x)...))
3336
end
34-
TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y
37+
function TestUtils.has_equal_data_internal(
38+
x::P, y::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool}
39+
) where {P<:CuArray{<:IEEEFloat}}
40+
return isapprox(x, y)
41+
end
3542
increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y
3643
__increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true
3744
_set_to_zero!!(::Mooncake.IncCache, x::CuArray{<:IEEEFloat}) = x .= 0
@@ -52,6 +59,14 @@ function Mooncake._verify_fdata_value(p::CuArray, f::CuArray)
5259
end
5360
return nothing
5461
end
62+
tangent_type(::Type{P}, ::Type{NoRData}) where {P<:CuArray} = P
63+
tangent(p::CuArray, ::NoRData) = p
64+
65+
to_cr_tangent(x::CuArray{<:IEEEFloat}) = x
66+
function increment_and_get_rdata!(f::T, ::NoRData, t::T) where {T<:CuArray{<:IEEEFloat}}
67+
f .+= t
68+
return NoRData()
69+
end
5570

5671
# Basic rules for operating on CuArrays.
5772

ext/MooncakeNNlibExt.jl

+36-15
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,75 @@
11
module MooncakeNNlibExt
22

3-
using NNlib, Random, Mooncake
3+
using GPUArraysCore, NNlib, Random, Mooncake
44
using Base: IEEEFloat
55
using NNlib: dropout
66

77
using NNlib: conv, depthwiseconv
88
import Mooncake: @from_rrule, DefaultCtx, MinimalCtx
99

10+
# Array types which we test rules against, so are confident work.
11+
const SupportedArray{P,N} = Union{Array{P,N},AbstractGPUArray{P,N}}
12+
1013
@from_rrule(
11-
MinimalCtx, Tuple{typeof(batched_mul),Array{P,3},Array{P,3}} where {P<:IEEEFloat},
14+
MinimalCtx,
15+
Tuple{typeof(batched_mul),SupportedArray{P,3},SupportedArray{P,3}} where {P<:IEEEFloat},
1216
)
1317
@from_rrule(
14-
MinimalCtx, Tuple{typeof(dropout),AbstractRNG,Array{P},P} where {P<:IEEEFloat}, true,
18+
MinimalCtx,
19+
Tuple{typeof(dropout),AbstractRNG,SupportedArray{P},P} where {P<:IEEEFloat},
20+
true,
1521
)
16-
@from_rrule(MinimalCtx, Tuple{typeof(softmax),Array{<:IEEEFloat}}, true)
17-
@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax),Array{<:IEEEFloat}}, true)
18-
@from_rrule(MinimalCtx, Tuple{typeof(logsumexp),Array{<:IEEEFloat}}, true)
22+
@from_rrule(MinimalCtx, Tuple{typeof(softmax),SupportedArray{<:IEEEFloat}}, true)
23+
@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax),SupportedArray{<:IEEEFloat}}, true)
24+
@from_rrule(MinimalCtx, Tuple{typeof(logsumexp),SupportedArray{<:IEEEFloat}}, true)
1925
@from_rrule(
20-
MinimalCtx, Tuple{typeof(upsample_nearest),Array{<:IEEEFloat},NTuple{N,Int} where {N}},
26+
MinimalCtx,
27+
Tuple{typeof(upsample_nearest),SupportedArray{<:IEEEFloat},NTuple{N,Int} where {N}},
2128
)
2229
@from_rrule(
2330
MinimalCtx,
24-
Tuple{typeof(NNlib.fold),Array{<:IEEEFloat},NTuple{N,Int} where {N},DenseConvDims},
31+
Tuple{
32+
typeof(NNlib.fold),SupportedArray{<:IEEEFloat},NTuple{N,Int} where {N},DenseConvDims
33+
},
2534
)
26-
@from_rrule(MinimalCtx, Tuple{typeof(NNlib.unfold),Array{<:IEEEFloat},DenseConvDims})
2735
@from_rrule(
28-
MinimalCtx, Tuple{typeof(NNlib.scatter),Any,Array,Array{<:Union{Integer,Tuple}}}, true,
36+
MinimalCtx, Tuple{typeof(NNlib.unfold),SupportedArray{<:IEEEFloat},DenseConvDims}
37+
)
38+
@from_rrule(
39+
MinimalCtx,
40+
Tuple{typeof(NNlib.scatter),Any,SupportedArray,SupportedArray{<:Union{Integer,Tuple}}},
41+
true,
2942
)
3043
for conv in [:conv, :depthwiseconv]
3144
local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
3245

3346
@eval @from_rrule(
3447
MinimalCtx,
35-
Tuple{typeof($conv),Array{P},Array{P},ConvDims} where {P<:IEEEFloat},
48+
Tuple{
49+
typeof($conv),SupportedArray{P},SupportedArray{P},ConvDims
50+
} where {P<:IEEEFloat},
3651
true,
3752
)
3853
@eval @from_rrule(
3954
MinimalCtx,
40-
Tuple{typeof($∇conv_data),Array{P},Array{P},ConvDims} where {P<:IEEEFloat},
55+
Tuple{
56+
typeof($∇conv_data),SupportedArray{P},SupportedArray{P},ConvDims
57+
} where {P<:IEEEFloat},
4158
true,
4259
)
4360
end
4461
@from_rrule(
4562
MinimalCtx,
46-
Tuple{typeof(∇conv_filter),Array{P},Array{P},ConvDims} where {P<:IEEEFloat},
63+
Tuple{
64+
typeof(∇conv_filter),SupportedArray{P},SupportedArray{P},ConvDims
65+
} where {P<:IEEEFloat},
4766
true,
4867
)
4968
for pool in [:maxpool, :meanpool]
50-
@eval @from_rrule(MinimalCtx, Tuple{typeof($pool),Array{<:IEEEFloat},PoolDims}, true)
69+
@eval @from_rrule(
70+
MinimalCtx, Tuple{typeof($pool),SupportedArray{<:IEEEFloat},PoolDims}, true
71+
)
5172
end
52-
@from_rrule(MinimalCtx, Tuple{typeof(pad_constant),Array,Any,Any}, true)
73+
@from_rrule(MinimalCtx, Tuple{typeof(pad_constant),SupportedArray,Any,Any}, true)
5374

5475
end

test/ext/cuda/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
23
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
45
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

test/ext/cuda/cuda.jl

+5-7
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@ using Pkg
22
Pkg.activate(@__DIR__)
33
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))
44

5-
using CUDA, JET, Mooncake, StableRNGs, Test
6-
using Mooncake.TestUtils: test_tangent, test_rule
5+
using AllocCheck, CUDA, JET, Mooncake, StableRNGs, Test
6+
using Mooncake.TestUtils: test_tangent_consistency, test_fwds_rvs_data, test_rule
77

88
@testset "cuda" begin
99

1010
# Check we can operate on CuArrays.
11-
test_tangent(
12-
StableRNG(123456),
13-
CuArray{Float32,2,CUDA.DeviceMemory}(undef, 8, 8);
14-
interface_only=false,
15-
)
11+
p = CuArray{Float32,2,CUDA.DeviceMemory}(undef, 8, 8)
12+
test_tangent_consistency(StableRNG(123456), p; interface_only=false)
13+
test_fwds_rvs_data(StableRNG(123456), p)
1614

1715
# Check we can instantiate a CuArray.
1816
test_rule(

test/ext/nnlib/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
23
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
34
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
45
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
56
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

0 commit comments

Comments
 (0)