Add preliminary Metalhead.jl integration #208

Merged · 9 commits · Aug 22, 2022
7 changes: 4 additions & 3 deletions Project.toml
@@ -9,6 +9,7 @@ ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -18,15 +19,15 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
CategoricalArrays = "0.10"
ColorTypes = "0.10.3, 0.11"
ComputationalResources = "0.3.2"
Flux = "0.10.4, 0.11, 0.12, 0.13"
Flux = "0.13"
Metalhead = "0.7"
MLJModelInterface = "1.1.1"
ProgressMeter = "1.7.1"
Tables = "1.0"
julia = "1.6"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -35,4 +36,4 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "MLDatasets", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
test = ["LinearAlgebra", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
4 changes: 4 additions & 0 deletions src/MLJFlux.jl
@@ -13,12 +13,15 @@ using Statistics
using ColorTypes
using ComputationalResources
using Random
import Metalhead

include("utilities.jl")
const MMI=MLJModelInterface

include("penalizers.jl")
include("core.jl")
include("builders.jl")
include("metalhead.jl")
include("types.jl")
include("regressor.jl")
include("classifier.jl")
@@ -27,6 +30,7 @@ include("mlj_model_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, ImageClassifier
export CUDALibs, CPU1



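The newly exported `CUDALibs` and `CPU1` mean users can set the `acceleration` hyperparameter without importing ComputationalResources themselves. A minimal sketch:

```julia
using MLJFlux

# request GPU training; CPU1() remains the default acceleration
clf = ImageClassifier(acceleration=CUDALibs())
```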
8 changes: 5 additions & 3 deletions src/builders.jl
@@ -1,4 +1,4 @@
## BUILDING CHAINS A FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE
# # BUILDING CHAINS FROM HYPERPARAMETERS + INPUT/OUTPUT SHAPE

# We introduce chain builders as a way of exposing neural network
# hyperparameters (describing architecture, dropout rates, etc.) to
@@ -9,7 +9,7 @@
# input/output dimensions/shape.

# Below n or (n1, n2) etc refers to network inputs, while m or (m1,
# m2) etc refers to outputs.

abstract type Builder <: MLJModelInterface.MLJType end

@@ -38,7 +38,7 @@ using `n_hidden` nodes in the hidden layer and the specified `dropout`
(defaulting to 0.5). An activation function `σ` is applied between the
hidden and final layers. If `n_hidden=0` (the default) then the
number of hidden nodes is the geometric mean of the number of input
and output nodes, as determined from the data.

Each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
@@ -96,6 +96,8 @@ function MLJFlux.build(mlp::MLP, rng, n_in, n_out)
end


# # BUILDER MACRO

struct GenericBuilder{F} <: Builder
apply::F
end
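For context, the builder contract exercised throughout this file is small: subtype `MLJFlux.Builder` and implement `MLJFlux.build(builder, rng, n_in, n_out)`, returning a Flux model. A hedged sketch, with a hypothetical builder name:

```julia
using MLJFlux, Flux

# hypothetical custom builder exposing one hyperparameter
struct MyBuilder <: MLJFlux.Builder
    n_hidden::Int
end

# n_in and n_out are inferred from the data at fit time
MLJFlux.build(b::MyBuilder, rng, n_in, n_out) = Flux.Chain(
    Flux.Dense(n_in, b.n_hidden, Flux.relu),
    Flux.Dense(b.n_hidden, n_out),
)
```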
28 changes: 3 additions & 25 deletions src/core.jl
@@ -1,30 +1,8 @@
## EXPOSE OPTIMISERS TO MLJ (for eg, tuning)

# Here we make the optimiser structs "transparent" so that their
# field values are exposed by calls to MLJ.params

for opt in (:Descent,
:Momentum,
:Nesterov,
:RMSProp,
:ADAM,
:RADAM,
:AdaMax,
:OADAM,
:ADAGrad,
:ADADelta,
:AMSGrad,
:NADAM,
:AdaBelief,
:Optimiser,
:InvDecay, :ExpDecay, :WeightDecay,
:ClipValue,
:ClipNorm) # last updated: Flux.jl 0.12.3

@eval begin
MLJModelInterface.istransparent(m::Flux.$opt) = true
end
end
# make the optimiser structs "transparent" so that their field values
# are exposed by calls to MLJ.params:
MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true


## GENERAL METHOD TO OPTIMIZE A CHAIN
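The one-line replacement above behaves like the deleted loop but covers every optimiser subtyping `Flux.Optimise.AbstractOptimiser`, including any added to Flux later. A sketch of what transparency buys, assuming MLJBase is available (note the qualified call, since Flux also exports a `params`):

```julia
using MLJFlux, Flux
import MLJBase

clf = NeuralNetworkClassifier(optimiser=Flux.Optimise.Adam(0.001))

# Adam subtypes AbstractOptimiser, so it is "transparent": its fields
# (eta, beta, ...) surface in MLJ's nested hyperparameter view and
# hence become accessible to tuning strategies
MLJBase.params(clf).optimiser
```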
146 changes: 146 additions & 0 deletions src/metalhead.jl
@@ -0,0 +1,146 @@
#=

TODO: After https://github.com/FluxML/Metalhead.jl/issues/176:

- Export and externally document `image_builder` method

- Delete definition of `ResNetHack` below

- Change default builder in ImageClassifier (see /src/types.jl) from
`image_builder(ResNetHack)` to `image_builder(Metalhead.ResNet)`.

=#

const DISALLOWED_KWARGS = [:imsize, :inchannels, :nclasses]
const human_disallowed_kwargs = join(map(s->"`$s`", DISALLOWED_KWARGS), ", ", " and ")
const ERR_METALHEAD_DISALLOWED_KWARGS = ArgumentError(
"Keyword arguments $human_disallowed_kwargs are disallowed "*
"as their values are inferred from data. "
)

# # WRAPPING

struct MetalheadBuilder{F} <: MLJFlux.Builder
metalhead_constructor::F
args
kwargs
end

function Base.show(io::IO, ::MIME"text/plain", w::MetalheadBuilder)
println(io, "builder wrapping $(w.metalhead_constructor)")
if !isempty(w.args)
println(io, " args:")
for (i, arg) in enumerate(w.args)
println(io, " 1: $arg")
end
end
if !isempty(w.kwargs)
println(io, " kwargs:")
for kwarg in w.kwargs
println(io, " $(first(kwarg)) = $(last(kwarg))")
end
end
end

Base.show(io::IO, w::MetalheadBuilder) =
print(io, "image_builder($(repr(w.metalhead_constructor)), …)")


"""
image_builder(metalhead_constructor, args...; kwargs...)

Return an MLJFlux builder object based on the Metalhead.jl constructor/type
`metalhead_constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are
passed to `metalhead_constructor` at "build time", along with
the extra keyword specifiers `imsize=...`, `inchannels=...` and
`nclasses=...`, with values inferred from the data.

# Example

If in Metalhead.jl you would do

```julia
using Metalhead
model = ResNet(50, pretrain=true, inchannels=1, nclasses=10)
```

then in MLJFlux, it suffices to do

```julia
using MLJFlux, Metalhead
builder = image_builder(ResNet, 50, pretrain=true)
```

which can be used in `ImageClassifier` as in

```julia
clf = ImageClassifier(
builder=builder,
epochs=500,
optimiser=Flux.Adam(0.001),
loss=Flux.crossentropy,
batch_size=5,
)
```

The keyword arguments `imsize`, `inchannels` and `nclasses` are
disallowed in `kwargs` (see above).

"""
function image_builder(
metalhead_constructor,
args...;
kwargs...
)
kw_names = keys(kwargs)
isempty(intersect(kw_names, DISALLOWED_KWARGS)) ||
throw(ERR_METALHEAD_DISALLOWED_KWARGS)
return MetalheadBuilder(metalhead_constructor, args, kwargs)
end

MLJFlux.build(
b::MetalheadBuilder,
rng,
n_in,
n_out,
n_channels
) = b.metalhead_constructor(
b.args...;
b.kwargs...,
imsize=n_in,
inchannels=n_channels,
nclasses=n_out
)

# See above "TODO" list.
function VGGHack(
depth::Integer=16;
imsize=(242,242),
inchannels=3,
nclasses=1000,
batchnorm=false,
pretrain=false,
)

# Adapted from
# https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165
# But we do not ignore `imsize`.

@assert(
depth in keys(Metalhead.vgg_config),
"depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))"
)
model = Metalhead.VGG(imsize;
config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
inchannels,
batchnorm,
nclasses,
fcsize = 4096,
dropout = 0.5)
if pretrain && !batchnorm
Metalhead.loadpretrain!(model, string("VGG", depth))
elseif pretrain
Metalhead.loadpretrain!(model, "VGG$(depth)-BN)")
end
return model
end
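To make the `MLJFlux.build` dispatch above concrete, here is a hedged sketch of the call MLJFlux issues at fit time, using the `VGGHack` fallback defined in this file (sizes illustrative):

```julia
using MLJFlux, Random

# wrap the fallback constructor; 16 is the VGG depth
builder = MLJFlux.image_builder(MLJFlux.VGGHack, 16)

# at fit time MLJFlux effectively calls:
chain = MLJFlux.build(builder, Random.default_rng(), (32, 32), 10, 3)
# ... equivalent to VGGHack(16; imsize=(32, 32), inchannels=3, nclasses=10)
```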
19 changes: 18 additions & 1 deletion src/mlj_model_interface.jl
@@ -40,6 +40,9 @@ end

# # FIT AND UPDATE

const ERR_BUILDER =
"Builder does not appear to build an architecture compatible with supplied data. "

true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng

function MLJModelInterface.fit(model::MLJFluxModel,
@@ -51,10 +54,24 @@ function MLJModelInterface.fit(model::MLJFluxModel,

rng = true_rng(model)
shape = MLJFlux.shape(model, X, y)
chain = build(model, rng, shape) |> move

chain = try
    build(model, rng, shape) |> move
catch ex
    @error ERR_BUILDER
    throw(ex)
end

penalty = Penalty(model)
data = move.(collate(model, X, y))

x = data |> first |> first
try
chain(x)
catch ex
@error ERR_BUILDER
throw(ex)
end

optimiser = deepcopy(model.optimiser)

chain, history = fit!(model.loss,
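The two `try`/`catch` blocks added above surface builder/data mismatches early, with a pointed error message. A hypothetical failure case, for illustration:

```julia
using MLJFlux, Flux

# hypothetical builder whose chain ignores the inferred input width:
bad = MLJFlux.@builder Flux.Chain(Flux.Dense(n_in + 1, n_out))

clf = NeuralNetworkClassifier(builder=bad)
# fit!(machine(clf, X, y)) now fails fast: the smoke test `chain(x)`
# errors, ERR_BUILDER is logged, and the error is rethrown, instead of
# failing deep inside the training loop.
```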
8 changes: 4 additions & 4 deletions src/types.jl
@@ -5,6 +5,9 @@ const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic}

for Model in [:NeuralNetworkClassifier, :ImageClassifier]

default_builder_ex =
Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short()

ex = quote
mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
builder::B
@@ -20,7 +23,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier]
acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
end

function $Model(; builder::B = Short()
function $Model(; builder::B = $default_builder_ex
, finaliser::F = Flux.softmax
, optimiser::O = Flux.Optimise.Adam()
, loss::L = Flux.crossentropy
@@ -108,12 +111,9 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor]

end



const Regressor =
Union{NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor}


MMI.metadata_pkg.(
(
NeuralNetworkRegressor,
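With the default-builder switch above, `ImageClassifier()` is usable on image data out of the box. A quick sketch:

```julia
using MLJFlux

clf = ImageClassifier()  # builder now defaults to image_builder(VGGHack)
clf.builder              # displays as: image_builder(VGGHack, …)
```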
43 changes: 43 additions & 0 deletions src/utilities.jl
@@ -0,0 +1,43 @@
# # IMAGE COERCION

# Taken from ScientificTypes.jl to avoid as dependency.

const _4Dcollection = AbstractArray{<:Real, 4}

function coerce(y::_4Dcollection, T2::Type{GrayImage})
size(y, 3) == 1 || error("Multiple color channels encountered. "*
"Perhaps you want to use `coerce(image_collection, ColorImage)`.")
y = dropdims(y, dims=3)
return [ColorTypes.Gray.(y[:,:,idx]) for idx=1:size(y,3)]
end

function coerce(y::_4Dcollection, T2::Type{ColorImage})
return [broadcast(ColorTypes.RGB, y[:,:,1, idx], y[:,:,2,idx], y[:,:,3, idx]) for idx=1:size(y,4)]
end


# # SYNTHETIC IMAGES

"""
make_images(rng; image_size=(6, 6), n_classes=33, n_images=50, color=false, noise=0.05)

Return synthetic data of the form `(images, labels)` suitable for use
with MLJ's `ImageClassifier` model. All `images` are distortions of
`n_classes` fixed images. Two images with the same label correspond to
the same undistorted image.

"""
function make_images(rng; image_size=(6, 6), n_classes=33, n_images=50, color=false, noise=0.05)
n_channels = color ? 3 : 1
image_bag = map(1:n_classes) do _
rand(rng, Float32, image_size..., n_channels)
end
labels = rand(rng, 1:n_classes, n_images)
images = map(labels) do j
image_bag[j] + noise*rand(rng, Float32, image_size..., n_channels)
end
T = color ? ColorImage : GrayImage
X = coerce(cat(images...; dims=4), T)
y = categorical(labels)
return X, y
end
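A usage sketch of the synthetic-data helper above (unexported, hence qualified):

```julia
using MLJFlux, StableRNGs

rng = StableRNG(123)
X, y = MLJFlux.make_images(rng; image_size=(8, 8), n_classes=5, n_images=20)
# X: a vector of 20 gray 8×8 images; y: categorical labels drawn from 1:5
```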
8 changes: 5 additions & 3 deletions test/builders.jl
@@ -52,9 +52,11 @@ end
end

@testset_accelerated "@builder" accel begin
builder = MLJFlux.@builder(Flux.Chain(Flux.Dense(n_in, 4,
init = (out, in) -> randn(rng, out, in)),
Flux.Dense(4, n_out)))
builder = MLJFlux.@builder(Flux.Chain(Flux.Dense(
n_in,
4,
init = (out, in) -> randn(rng, out, in)
), Flux.Dense(4, n_out)))
rng = StableRNGs.StableRNG(123)
chain = MLJFlux.build(builder, rng, 5, 3)
ps = Flux.params(chain)