diff --git a/Project.toml b/Project.toml index 6c4671f9..a24fe91c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,12 @@ name = "Metalhead" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.8.3" +version = "0.8.4" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" @@ -17,6 +17,12 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[weakdeps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + +[extensions] +MetalheadCUDAExt = "CUDA" + [compat] BSON = "0.3.2" CUDA = "4, 5" diff --git a/ext/MetalheadCUDAExt.jl b/ext/MetalheadCUDAExt.jl new file mode 100644 index 00000000..72fa126d --- /dev/null +++ b/ext/MetalheadCUDAExt.jl @@ -0,0 +1,18 @@ +module MetalheadCUDAExt + +if isdefined(Base, :get_extension) + using Metalhead: Metalhead +else + using ..Metalhead: Metalhead +end +using CUDA: CUDA, CuArray + +## bs is `clipped_block_size` +# Dispatch for GPU +Metalhead.Layers.dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = Metalhead.Layers._dropblock_mask(rng, x, gamma, bs) +function Metalhead.Layers.dropblock_mask(rng, x::CuArray, gamma, bs) + throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports + CUDA.RNG for CuArrays.")) +end + +end diff --git a/src/Metalhead.jl b/src/Metalhead.jl index d7db8d4c..ced4ca44 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -93,4 +93,8 @@ for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end +if !isdefined(Base, :get_extension) + include("../ext/MetalheadCUDAExt.jl") +end + end # module diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 7c5dc085..aaaa2022 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -2,7 +2,6 @@ module Layers using Flux using Flux: default_rng_value -using CUDA using NNlib using Functors using ChainRulesCore diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 15f8e753..8593f930 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -43,13 +43,6 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, return x .* block_mask .* normalize_scale end -## bs is `clipped_block_size` -# Dispatch for GPU -dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) -function dropblock_mask(rng, x::CuArray, gamma, bs) - throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports - CUDA.RNG for CuArrays.")) -end # Dispatch for CPU dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) diff --git a/test/Project.toml b/test/Project.toml index e65c612a..b121720e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"