Skip to content

Commit

Permalink
Merge pull request #253 from mohamed82008/mt/optional_cuda
Browse files Browse the repository at this point in the history
Make CUDA an optional dependency
  • Loading branch information
ToucheSir authored Oct 14, 2023
2 parents 6e92ad6 + dc4fb0a commit eb3f9a4
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 10 deletions.
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name = "Metalhead"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
version = "0.8.3"
version = "0.8.4"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Expand All @@ -17,6 +17,12 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
MetalheadCUDAExt = "CUDA"

[compat]
BSON = "0.3.2"
CUDA = "4, 5"
Expand Down
18 changes: 18 additions & 0 deletions ext/MetalheadCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module MetalheadCUDAExt

if isdefined(Base, :get_extension)
using Metalhead: Metalhead
else
using ..Metalhead: Metalhead
end
using CUDA: CUDA, CuArray

## bs is `clipped_block_size`
# Dispatch for GPU
Metalhead.Layers.dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = Metalhead.Layers._dropblock_mask(rng, x, gamma, bs)
function Metalhead.Layers.dropblock_mask(rng, x::CuArray, gamma, bs)
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports
CUDA.RNG for CuArrays."))
end

end
4 changes: 4 additions & 0 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,8 @@ for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

if !isdefined(Base, :get_extension)
include("../ext/MetalheadCUDAExt.jl")
end

end # module
1 change: 0 additions & 1 deletion src/layers/Layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module Layers

using Flux
using Flux: default_rng_value
using CUDA
using NNlib
using Functors
using ChainRulesCore
Expand Down
7 changes: 0 additions & 7 deletions src/layers/drop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,6 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob,
return x .* block_mask .* normalize_scale
end

## bs is `clipped_block_size`
# Dispatch for GPU
dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = _dropblock_mask(rng, x, gamma, bs)
function dropblock_mask(rng, x::CuArray, gamma, bs)
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports
CUDA.RNG for CuArrays."))
end
# Dispatch for CPU
dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs)

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
Expand Down

0 comments on commit eb3f9a4

Please sign in to comment.