Fix loading of pooling layers #2598

Merged: 3 commits merged into FluxML:master from ah/load-pooling on Apr 15, 2025

Conversation

@adrhill (Contributor) commented Mar 31, 2025

As discussed on Slack, I tried my hand at fixing #2584 and FluxML/Metalhead.jl#287.
This line of code appears to be enough to make Metalhead's VGG models load.

Without it, the following error is thrown:

julia> using Metalhead

julia> VGG(19; pretrain=true)
ERROR: ArgumentError: Tried to load Base.OneTo(0) into (:pad, :k, :stride) but the structures do not match.
Stacktrace:
  [1] loadmodel!(dst::MaxPool{2, 4}, src::Tuple{}; filter::Function, cache::IdSet{Any})
    @ Flux ~/Developer/Flux.jl/src/loading.jl:104
  [2] loadmodel!(dst::Tuple{…}, src::Tuple{…}; filter::Function, cache::IdSet{…})
    @ Flux ~/Developer/Flux.jl/src/loading.jl:118
  [3] loadmodel!(dst::Chain{Tuple{…}}, src::@NamedTuple{layers::Tuple{…}}; filter::Function, cache::IdSet{Any})
    @ Flux ~/Developer/Flux.jl/src/loading.jl:118

I could use some guidance on whether this fix is adequate and how to test it.
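
Something along these lines might work as a test. This is only a rough sketch of a round-trip check, not the PR's actual test code, and it assumes the fix makes Flux.state return the pooling layer's fields instead of an empty tuple:

# Sketch of a possible round-trip test (hypothetical, for illustration):
using Flux

m1 = Chain(Conv((3, 3), 1 => 2), MaxPool((2, 2)))
m2 = Chain(Conv((3, 3), 1 => 2), MaxPool((2, 2)))

s = Flux.state(m1)        # should now contain (k, pad, stride) for the MaxPool
Flux.loadmodel!(m2, s)    # previously threw the ArgumentError shown above
@assert m2[2].k == (2, 2)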

@mcabbott (Member) commented Apr 2, 2025

To be clear, the before-and-after states are:

julia> Flux.state(MaxPool((2,3)))
()

(jl_gaQBh1) pkg> st Flux
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_gaQBh1/Project.toml`
⌃ [587475ba] Flux v0.14.25

and

julia> Flux.state(MaxPool((2,3)))
(k = (2, 3), pad = (0, 0, 0, 0), stride = (2, 3))

(@v1.11) pkg> st Flux
Status `~/.julia/environments/v1.11/Project.toml`
  [587475ba] Flux v0.16.3

So this looks fine?

@adrhill (Contributor, Author) commented Apr 13, 2025

For the sake of documentation:

Flux v0.14:

julia> Flux.state(MaxPool((5, 5)))
()

julia> Flux.state(MeanPool((5, 5)))
()

julia> Flux.state(AdaptiveMaxPool((5, 5)))
()

julia> Flux.state(AdaptiveMeanPool((5, 5)))
()

julia> Flux.state(GlobalMaxPool())
()

julia> Flux.state(GlobalMeanPool())
()

Flux v0.16:

julia> Flux.state(MaxPool((5, 5)))
(k = (5, 5), pad = (0, 0, 0, 0), stride = (5, 5))

julia> Flux.state(MeanPool((5, 5)))
(k = (5, 5), pad = (0, 0, 0, 0), stride = (5, 5))

julia> Flux.state(AdaptiveMaxPool((5, 5)))
(out = (5, 5),)

julia> Flux.state(AdaptiveMeanPool((5, 5)))
(out = (5, 5),)

julia> Flux.state(GlobalMaxPool())
()

julia> Flux.state(GlobalMeanPool())
()

So I guess the global pooling layers require no modification?
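
Presumably that's because GlobalMaxPool and GlobalMeanPool are field-less structs, so an empty state matches on both versions. A quick check, for illustration:

julia> fieldnames(GlobalMaxPool)
()

julia> fieldnames(GlobalMeanPool)
()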

@adrhill adrhill changed the title Fix loading of MaxPool layers Fix loading of pooling layers Apr 13, 2025
@mcabbott mcabbott merged commit 0e36af9 into FluxML:master Apr 15, 2025
1 of 9 checks passed
@adrhill adrhill deleted the ah/load-pooling branch May 20, 2025 09:08
@adrhill (Contributor, Author) commented Jun 2, 2025

Could we get a patch release with this fix? I've got some students who will be using Metalhead very soon.

@mcabbott (Member) commented Jun 2, 2025

Done! Once JuliaRegistries/General#132163 is merged.

@remi-garcia commented:

A very similar error remains for the ResNet18 network and its Parallel/PartialFunction layers.

@mcabbott (Member) commented Jun 2, 2025

Can you post the exact error? And ideally what Flux.state(the_layer) returns on Flux 0.14 and 0.16, as above?

@remi-garcia commented:

julia> model = ResNet(18; pretrain=true, nclasses=10)
ERROR: ArgumentError: Tried to load Base.OneTo(0) into (:args, :expr_string, :func, :kwargs) but the structures do not match.
Stacktrace:
  [1] loadmodel!(dst::Function, src::Tuple{}; filter::Function, cache::IdSet{Any})
    @ Flux ~/.julia/packages/Flux/6pNhw/src/loading.jl:95
  [2] loadmodel!(dst::Parallel{…}, src::@NamedTuple{}; filter::Function, cache::IdSet{…})
    @ Flux ~/.julia/packages/Flux/6pNhw/src/loading.jl:105

Flux 0.16

julia> using Flux, Metalhead, PartialFunctions

julia> Flux.state(Parallel(
               PartialFunctions.PartialFunction(
                   "",
                   Metalhead.addact,
                   (relu,),
                   NamedTuple(),
               ),
               Dense(5, 2)
           )
       )
(connection = (expr_string = "", func = (), args = ((),), kwargs = NamedTuple()), layers = ((weight = Float32[-0.82734597 0.431884 … -0.47095233 0.34837782; -0.32114923 -0.80584735 … -0.08924356 0.767007], bias = Float32[0.0, 0.0], σ = ()),))

Flux 0.14

julia> using Flux, Metalhead, PartialFunctions

julia> Flux.state(Parallel(
                      PartialFunctions.PartialFunction(
                          "",
                          Metalhead.addact,
                          (relu,),
                          NamedTuple(),
                      ),
                      Dense(5, 2)
                  )
              )
(connection = (), layers = ((weight = Float32[0.2765216 0.6709159 … 0.24647865 -0.076339304; -0.18728012 -0.83941054 … 0.27120978 0.22845758], bias = Float32[0.0, 0.0], σ = ()),))

@adrhill (Contributor, Author) commented Jun 3, 2025

I think this is largely a Metalhead issue, due to the recent introduction of PartialFunctions (see issues FluxML/Metalhead.jl#286, FluxML/Metalhead.jl#287).

As I've mentioned in FluxML/Metalhead.jl#286 (comment), it would be great if all layer types for common pre-trained architectures lived in Flux. Since Julia is heavily dispatch-based, my packages otherwise end up requiring a strict dependency on Metalhead. This would also allow bugs like FluxML/Metalhead.jl#287 to be caught in Flux's tests.

For this specific problem, it would be nice to have a simple AddAct layer type instead of having to dispatch on the awful-looking PartialFunctions.PartialFunction{nothing, nothing, typeof(Metalhead.addact), Tuple{typeof(relu)}, @NamedTuple{}} type.
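
Something like this, say. It's only a sketch of the idea, not code that exists in Flux or Metalhead; the name AddAct and the @layer call are hypothetical:

# Sketch: a small callable struct mirroring Metalhead.addact,
# i.e. it applies the activation to the sum of its inputs.
using Flux

struct AddAct{F}
    activation::F
end

(a::AddAct)(xs...) = a.activation(sum(xs))

Flux.@layer AddAct   # let Flux.state / loadmodel! treat it like any other layer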

@mcabbott (Member) commented Jun 3, 2025

If I understand right, you want this to happen:

julia> using Metalhead: addact, relu

julia> addact(relu, [1,2,3], [4,-3,-2])
3-element Vector{Int64}:
 5
 0
 1

and the PartialFunction is very close to being Base.Fix1(addact, relu), except that this only accepts one argument for now... Julia 1.12 will change this:

julia> _addact(activation, xs...) = activation.(sum(xs));  # my version broadcasts, Metalhead does not

julia> addtanh = Base.Fix1(_addact, tanh)
(::Base.Fix1{typeof(_addact), typeof(tanh)}) (generic function with 2 methods)

julia> addtanh([1,2,3], [4,-3,-2])
3-element Vector{Float64}:
  0.9999092042625951
 -0.7615941559557649
  0.7615941559557649

julia> VERSION
v"1.12.0-beta3"

May I ask why not just use Base's ComposedFunction for this?

julia> addrelu = relu ∘ (+);

julia> addrelu([1,2,3], [4,-3,-2])
3-element Vector{Int64}:
 5
 0
 1

Like Metalhead.addact, this relies on the auto-broadcasting relu(::AbstractArray) method, so it won't work for tanh:

julia> Metalhead.addact(tanh, [1,2,3], [4,-3,-2])
ERROR: MethodError: no method matching tanh(::Vector{Int64})

julia> (tanh ∘ (+))([1,2,3], [4,-3,-2])
ERROR: MethodError: no method matching tanh(::Vector{Int64})

This also appears to have the same state on old & new Flux:

julia> Flux.state(relu ∘ (+))  # Flux v0.14.25, and also Flux v0.16.3
(outer = (), inner = ())

@adrhill (Contributor, Author) commented Jun 4, 2025

May I ask why not just use Base's ComposedFunction for this?

This is code from Metalhead.jl's ResNet implementation: FluxML/Metalhead.jl#287 (comment)
