# Fix loading of pooling layers #2598
## Conversation
To be clear, the before-and-after states are:

```julia
julia> Flux.state(MaxPool((2,3)))
()

(jl_gaQBh1) pkg> st Flux
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_gaQBh1/Project.toml`
⌃ [587475ba] Flux v0.14.25
```

and

```julia
julia> Flux.state(MaxPool((2,3)))
(k = (2, 3), pad = (0, 0, 0, 0), stride = (2, 3))

(@v1.11) pkg> st Flux
Status `~/.julia/environments/v1.11/Project.toml`
  [587475ba] Flux v0.16.3
```

So this looks fine?
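Since the hyperparameters are now part of the state, a save/load round-trip should work. A minimal sketch of what that enables (my own check, assuming the fixed Flux):

```julia
using Flux

m  = MaxPool((2, 3))
st = Flux.state(m)   # (k = (2, 3), pad = (0, 0, 0, 0), stride = (2, 3))

# Loading that state back into a matching layer now succeeds,
# because the structures on both sides agree:
m2 = Flux.loadmodel!(MaxPool((2, 3)), st)
```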
---
For the sake of documentation, on Flux v0.14:

```julia
julia> Flux.state(MaxPool((5, 5)))
()

julia> Flux.state(MeanPool((5, 5)))
()

julia> Flux.state(AdaptiveMaxPool((5, 5)))
()

julia> Flux.state(AdaptiveMeanPool((5, 5)))
()

julia> Flux.state(GlobalMaxPool())
()

julia> Flux.state(GlobalMeanPool())
()
```

and on Flux with this fix:

```julia
julia> Flux.state(MaxPool((5, 5)))
(k = (5, 5), pad = (0, 0, 0, 0), stride = (5, 5))

julia> Flux.state(MeanPool((5, 5)))
(k = (5, 5), pad = (0, 0, 0, 0), stride = (5, 5))

julia> Flux.state(AdaptiveMaxPool((5, 5)))
(out = (5, 5),)

julia> Flux.state(AdaptiveMeanPool((5, 5)))
(out = (5, 5),)

julia> Flux.state(GlobalMaxPool())
()

julia> Flux.state(GlobalMeanPool())
()
```

So I guess the global pooling layers require no modification?
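Indeed, the global pooling layers carry no hyperparameters at all, so an empty state is already complete for them. A quick check (my own, not from the thread):

```julia
julia> using Flux

julia> fieldnames(GlobalMaxPool), fieldnames(GlobalMeanPool)
((), ())
```

With no fields there is nothing for `Flux.state` to record, and nothing for `loadmodel!` to mismatch.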
---

Could we get a patch release with this fix? I've got some students who will be using Metalhead very soon.
---

Done! Once JuliaRegistries/General#132163 is merged.
---

A very similar error remains for the ResNet18 network and Parallel/PartialFunction layers.
---

Can you post the exact error? And ideally what `Flux.state` returns on both versions?
---

```julia
julia> model = ResNet(18; pretrain=true, nclasses=10)
ERROR: ArgumentError: Tried to load Base.OneTo(0) into (:args, :expr_string, :func, :kwargs) but the structures do not match.
Stacktrace:
  [1] loadmodel!(dst::Function, src::Tuple{}; filter::Function, cache::IdSet{Any})
    @ Flux ~/.julia/packages/Flux/6pNhw/src/loading.jl:95
  [2] loadmodel!(dst::Parallel{…}, src::@NamedTuple{…}; filter::Function, cache::IdSet{…})
    @ Flux ~/.julia/packages/Flux/6pNhw/src/loading.jl:105
```

On Flux 0.16:

```julia
julia> using Flux, Metalhead, PartialFunctions

julia> Flux.state(Parallel(
           PartialFunctions.PartialFunction(
               "",
               Metalhead.addact,
               (relu,),
               NamedTuple(),
           ),
           Dense(5, 2)
       ))
(connection = (expr_string = "", func = (), args = ((),), kwargs = NamedTuple()), layers = ((weight = Float32[-0.82734597 0.431884 … -0.47095233 0.34837782; -0.32114923 -0.80584735 … -0.08924356 0.767007], bias = Float32[0.0, 0.0], σ = ()),))
```

On Flux 0.14:

```julia
julia> using Flux, Metalhead, PartialFunctions

julia> Flux.state(Parallel(
           PartialFunctions.PartialFunction(
               "",
               Metalhead.addact,
               (relu,),
               NamedTuple(),
           ),
           Dense(5, 2)
       ))
(connection = (), layers = ((weight = Float32[0.2765216 0.6709159 … 0.24647865 -0.076339304; -0.18728012 -0.83941054 … 0.27120978 0.22845758], bias = Float32[0.0, 0.0], σ = ()),))
```
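The underlying change appears to be in Functors: version 0.5 (used by Flux 0.16) recurses into plain structs by default, so a callable like `PartialFunction`, which was an opaque leaf under Flux 0.14 / Functors 0.4, now exposes its fields in the state. A toy illustration (my own hypothetical struct, not the thread's code):

```julia
using Flux

# Hypothetical stand-in for PartialFunctions.PartialFunction:
struct AddAct{F}
    act::F
end
(a::AddAct)(xs...) = a.act.(sum(xs))

# Flux 0.16 / Functors 0.5 recurses into the struct's fields by default,
# while Flux 0.14 / Functors 0.4 treated an unregistered struct as a leaf:
Flux.state(AddAct(relu))   # (act = (),) on 0.16; () on 0.14
```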
---

I think this is largely a Metalhead issue, due to its use of `PartialFunction` as the `Parallel` connection.

As I've mentioned in FluxML/Metalhead.jl#286 (comment), it would be great if all layer types for common pre-trained architectures lived in Flux. Since Julia is heavily dispatch-based, my packages otherwise end up requiring a strict dependency on Metalhead. This would also allow bugs like FluxML/Metalhead.jl#287 to be caught in Flux's testing.

For this specific problem, it would be nice to have a simple way to mark such a callable as a leaf.
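One such escape hatch already exists: Functors 0.5 provides `Functors.@leaf`, which declares a type opaque to recursion. A hedged sketch of the stopgap (note this is type piracy if done outside PartialFunctions, so the proper home for it would be upstream):

```julia
using Flux, Functors, PartialFunctions

# Make PartialFunction a leaf again, restoring the Flux 0.14 behaviour
# in which its state is () and loadmodel! skips over it:
Functors.@leaf PartialFunctions.PartialFunction
```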
---

If I understand right, you want this to happen:

```julia
julia> using Metalhead: addact, relu

julia> addact(relu, [1,2,3], [4,-3,-2])
3-element Vector{Int64}:
 5
 0
 1
```

and the same from a partially-applied version built with `Base.Fix1`:

```julia
julia> _addact(activation, xs...) = activation.(sum(xs));  # my version broadcasts, Metalhead's does not

julia> addtanh = Base.Fix1(_addact, tanh)
(::Base.Fix1{typeof(_addact), typeof(tanh)}) (generic function with 2 methods)

julia> addtanh([1,2,3], [4,-3,-2])
3-element Vector{Float64}:
  0.9999092042625951
 -0.7615941559557649
  0.7615941559557649

julia> VERSION
v"1.12.0-beta3"
```

May I ask why not just use Base's ComposedFunction for this?

```julia
julia> addrelu = relu ∘ (+);

julia> addrelu([1,2,3], [4,-3,-2])
3-element Vector{Int64}:
 5
 0
 1
```

Like `Metalhead.addact`, this does not broadcast the activation:

```julia
julia> Metalhead.addact(tanh, [1,2,3], [4,-3,-2])
ERROR: MethodError: no method matching tanh(::Vector{Int64})

julia> (tanh ∘ (+))([1,2,3], [4,-3,-2])
ERROR: MethodError: no method matching tanh(::Vector{Int64})
```

This also appears to have the same state on old & new Flux:

```julia
julia> Flux.state(relu ∘ (+))  # Flux v0.14.25, and also Flux v0.16.3
(outer = (), inner = ())
```
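For reference (my own check, not from the thread): `ComposedFunction` is simply a struct holding its two callables, which is why its state looks identical on both Flux versions:

```julia
julia> c = relu ∘ (+);

julia> (c.outer, c.inner)   # the two composed functions, stored as fields
(NNlib.relu, +)

julia> Flux.state(c)        # bare functions are leaves with empty state
(outer = (), inner = ())
```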
---

This is code from Metalhead.jl's ResNet implementation: FluxML/Metalhead.jl#287 (comment)
---

As discussed on Slack, I tried my hand at fixing #2584 and FluxML/Metalhead.jl#287. This line of code appears to be enough to make Metalhead's VGG models load. Without it, the following error is thrown:

I could use some guidance on whether this fix is adequate and how to test it.
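The diff itself isn't captured above. As a hedged guess at its shape (hypothetical, not the actual change): registering the pooling layers with Flux's `@layer` macro is the kind of one-liner that would expose their fields to `state` and `loadmodel!`:

```julia
# Hypothetical sketch only, inside Flux's own source where the pooling
# layers are defined; @layer makes their fields visible to Functors:
@layer MaxPool
@layer MeanPool
@layer AdaptiveMaxPool
@layer AdaptiveMeanPool
```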