5 changes: 4 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ReservoirComputing"
uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
version = "0.12.15"
version = "0.12.16"
authors = ["Francesco Martinuzzi"]

[deps]
@@ -18,12 +18,14 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
[weakdeps]
CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
RCCellularAutomataExt = "CellularAutomata"
RCLIBSVMExt = "LIBSVM"
RCLinearSolveExt = "LinearSolve"
RCMLJLinearModelsExt = "MLJLinearModels"
RCSparseArraysExt = "SparseArrays"

@@ -34,6 +36,7 @@ ConcreteStructs = "0.2.3"
DifferentialEquations = "7.16.1"
LIBSVM = "0.8"
LinearAlgebra = "1.10"
LinearSolve = "3.57.0"
LuxCore = "1.3.0"
MLJLinearModels = "0.9.2, 0.10"
NNlib = "0.9.26"
1 change: 1 addition & 0 deletions README.md
@@ -53,6 +53,7 @@ reservoir computing models. More specifically the software offers:
[SparseArrays.jl](https://docs.julialang.org/en/v1/stdlib/SparseArrays/)
- Multiple training algorithms via [LIBSVM.jl](https://github.com/JuliaML/LIBSVM.jl)
and [MLJLinearModels.jl](https://github.com/JuliaAI/MLJLinearModels.jl)
- Multiple linear solvers via [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl)

## Installation

3 changes: 3 additions & 0 deletions docs/Project.toml
@@ -5,7 +5,10 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2 changes: 1 addition & 1 deletion docs/pages.jl
@@ -5,8 +5,8 @@ pages = [
"Building a model from scratch" => "tutorials/scratch.md",
"Chaos forecasting with an ESN" => "tutorials/lorenz_basic.md",
"Fitting a Next Generation Reservoir Computer" => "tutorials/ngrc.md",
#"Using Different Training Methods" => "esn_tutorials/different_training.md",
"Deep Echo State Networks" => "tutorials/deep_esn.md",
"Training Reservoir Computing Models" => "tutorials/train.md",
#"Hybrid Echo State Networks" => "tutorials/hybrid.md",
"Reservoir Computing with Cellular Automata" => "tutorials/reca.md",
"Saving and loading models" => "tutorials/saveload.md",
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -92,6 +92,7 @@ or `dev` the package.
[SparseArrays.jl](https://docs.julialang.org/en/v1/stdlib/SparseArrays/)
- Multiple training algorithms via [LIBSVM.jl](https://github.com/JuliaML/LIBSVM.jl)
and [MLJLinearModels.jl](https://github.com/JuliaAI/MLJLinearModels.jl)
- Multiple linear solvers via [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl)

## Contributing

121 changes: 121 additions & 0 deletions docs/src/tutorials/train.md
@@ -0,0 +1,121 @@
# Training Reservoir Computing Models

Training reservoir computing (RC) models usually means solving a linear
regression problem. ReservoirComputing.jl offers multiple strategies to
fit a readout; on this page we cover the basics, while also pointing out
the possible extensions.

## Training in ReservoirComputing.jl: Ridge Regression

The simplest way to train RC models is through ridge regression.
Given the widespread adoption of this training mechanism, ridge regression is the
default training algorithm for RC models in the library.
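
In its standard form, ridge regression computes the readout weights in closed
form. Writing ``S`` for the matrix of collected states and ``Y`` for the
targets (the notation here is illustrative, not the library's internal naming):

```math
W_{\text{out}} = Y S^\top \left(S S^\top + \lambda I\right)^{-1}
```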

```@example training
using ReservoirComputing
using Random
Random.seed!(42)
rng = MersenneTwister(42)

input_data = rand(Float32, 3, 100)
target_data = rand(Float32, 5, 100)

model = ESN(3, 100, 5)
ps, st = setup(rng, model)
ps, st = train!(model, input_data, target_data, ps, st,
StandardRidge(); # default
solver = QRSolver()) # default
```

In this call you can see that there are two knobs that can be modified: the
loss function, in this case ridge, and the solver, in this case the built-in QR
factorization. In the rest of this tutorial we will see how to change
either one.

## Changing Ridge Regression Solver

Building on SciML's [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl), it is
possible to leverage multiple solvers for the ridge problem. For instance, building
on the previous example:

```@example training
using LinearSolve

ps, st = train!(model, input_data, target_data, ps, st,
StandardRidge(); # default
solver = SVDFactorization()) # from LinearSolve
```

or

```@example training
ps, st = train!(model, input_data, target_data, ps, st,
StandardRidge(); # default
solver = QRFactorization()) # from LinearSolve
```

For a detailed explanation of the different solvers, as well as a complete list of them,
we suggest visiting the appropriate page in LinearSolve's
[documentation](https://docs.sciml.ai/LinearSolve/stable/solvers/solvers/).
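
Iterative solvers can be passed in the same way. As a sketch, assuming the
`KrylovJL_GMRES` solver from LinearSolve.jl (see the list linked above for the
available options):

```@example training
ps, st = train!(model, input_data, target_data, ps, st,
    StandardRidge(); # default
    solver = KrylovJL_GMRES()) # iterative solver from LinearSolve
```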

## Changing Linear Regression Problem

Linear regression is a general problem, which can be expressed through multiple different
loss functions. While ridge regression is the most common in RC, due to its closed form
solution, there are several others available. ReservoirComputing.jl leverages
[MLJLinearModels.jl](https://github.com/JuliaAI/MLJLinearModels.jl) to access all the methods
available from that library.

!!! warning

    Currently MLJLinearModels.jl only supports `Float64`. If a specific precision is of the
    utmost importance to you, please refrain from using this external package.

The `train!` function can be called as before, only this time you can specify different models
and different solvers for the linear regression problem:

```@example training
using MLJLinearModels

ps, st = train!(model, input_data, target_data, ps, st,
LassoRegression(fit_intercept=false); # from MLJLinearModels
solver = ProxGrad()) # from MLJLinearModels
```

Make sure to check the MLJLinearModels documentation pages for the available
[models](https://juliaai.github.io/MLJLinearModels.jl/stable/models/) and
[solvers](https://juliaai.github.io/MLJLinearModels.jl/stable/solvers/). Please note that
not all solvers can be used with all models.

!!! note

    Currently, support for MLJLinearModels.jl is limited to regressors with
    `fit_intercept=false`. We are working on a solution, but until then you will always
    need to set it explicitly on the regressor.
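
Other losses plug in the same way. As another sketch, assuming the
`HuberRegression` model and the `LBFGS` solver exported by MLJLinearModels.jl:

```@example training
ps, st = train!(model, input_data, target_data, ps, st,
    HuberRegression(fit_intercept=false); # from MLJLinearModels
    solver = LBFGS()) # from MLJLinearModels
```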

## Support Vector Regression

ReservoirComputing.jl also allows users to train RC models with support vector regression
through [LIBSVM.jl](https://github.com/JuliaML/LIBSVM.jl). However, the majority of built-in
models in the library use a [`LinearReadout`](@ref) by default, which can only be trained with
linear regression. In order to use support vector regression, one needs to build a model
with an [`SVMReadout`](@ref):

```@example training
using LIBSVM

model = ReservoirComputer(
StatefulLayer(ESNCell(3=>100)),
SVMReadout(100=>5)
)

ps, st = setup(rng, model)
```

We can now train our new `model` similarly to before:

```@example training
ps, st = train!(model, input_data, target_data, ps, st,
EpsilonSVR() # from LIBSVM
)
```
35 changes: 29 additions & 6 deletions ext/RCLIBSVMExt.jl
@@ -2,12 +2,12 @@ module RCLIBSVMExt

using LIBSVM
using ReservoirComputing:
SVMReadout, addreadout!, ReservoirChain
import ReservoirComputing: train
SVMReadout, ReservoirChain, ReservoirComputer
import ReservoirComputing: train, addreadout!

function train(
svr::LIBSVM.AbstractSVR,
states::AbstractArray, target::AbstractArray
states::AbstractMatrix, target::AbstractMatrix
)
@assert size(states, 2) == size(target, 2) "states and target must share columns."
perm_states = permutedims(states)
@@ -37,7 +37,7 @@ function (svmro::SVMReadout)(inp::AbstractArray, ps, st::NamedTuple)
vec_like = false
if ndims(inp) == 1
reshaped_inp = reshape(inp, 1, :)
num_imp = 1
num_inp = 1
vec_like = true
elseif ndims(inp) == 2
if size(inp, 2) == 1
@@ -46,14 +46,14 @@
vec_like = true
else
reshaped_inp = permutedims(inp)
num_imp = size(reshaped_inp, 1)
num_inp = size(reshaped_inp, 1)
end
else
throw(ArgumentError("SVMReadout expects 1D or 2D input; got size $(size(inp))"))
end

if models isa AbstractVector
out_data = Array{float(eltype(reshaped_inp))}(undef, svmro.out_dims, num_imp)
out_data = Array{float(eltype(reshaped_inp))}(undef, svmro.out_dims, num_inp)
for (idx, model) in enumerate(models)
single_out = LIBSVM.predict(models[idx], reshaped_inp)
out_data[idx, :] = single_out
@@ -70,4 +70,27 @@ function (svmro::SVMReadout)(inp::AbstractArray, ps, st::NamedTuple)
end
end

_set_readout_models(ps_readout::NamedTuple, models) = merge(ps_readout, (; models = models))

function addreadout!(
rc::ReservoirComputer,
models, # model or vector of models
ps::NamedTuple,
st::NamedTuple
)
# Only valid if the model's readout is actually SVMReadout
if rc.readout isa SVMReadout
@assert hasproperty(ps, :readout)
new_readout = _set_readout_models(ps.readout, models)
return merge(ps, (readout = new_readout,)), st
end

throw(
ArgumentError(
"This training method produced a non-matrix readout (e.g. LIBSVM models), " *
"but the model readout is $(typeof(rc.readout)). Use SVMReadout as the readout layer."
)
)
end

end # module
30 changes: 30 additions & 0 deletions ext/RCLinearSolveExt.jl
@@ -0,0 +1,30 @@
module RCLinearSolveExt
using LinearAlgebra: mul!, I
using LinearSolve: LinearProblem, init, solve!, SciMLLinearSolveAlgorithm
using ReservoirComputing: StandardRidge
import ReservoirComputing: _train_ridge

function _train_ridge(
solver::SciMLLinearSolveAlgorithm, sr::StandardRidge,
states::AbstractMatrix, targets::AbstractMatrix; kwargs...
)

nfeat, T = size(states)
nout, T2 = size(targets)
T == T2 || throw(DimensionMismatch("states has T=$T samples, targets has T=$T2"))
λ = convert(eltype(states), sr.reg)
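# normal-equations matrix for ridge regression: A = S Sᵀ + λ I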
A = states * states' + λ * I
b = zeros(eltype(states), nfeat)
prob = LinearProblem(A, b)
linsolve = init(prob, solver; kwargs...)
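# the matrix is factorized once by init; solve! below reuses that factorization
# while only the right-hand side b changes for each output row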
Wt = zeros(eltype(states), nfeat, nout)
for idx in 1:nout
mul!(linsolve.b, states, targets[idx, :])
sol = solve!(linsolve)
Wt[:, idx] .= sol.u
end

return permutedims(Wt) # (n_outputs, n_features)
end

end #module
2 changes: 1 addition & 1 deletion src/ReservoirComputing.jl
@@ -67,7 +67,7 @@ export block_diagonal, chaotic_init, cycle_jumps, delay_line, delayline_backward,
diagonal_init
export add_jumps!, backward_connection!, delay_line!, reverse_simple_cycle!,
scale_radius!, self_loop!, simple_cycle!, permute_matrix!
export train, train!, predict, resetcarry!, polynomial_monomials
export train, train!, predict, resetcarry!, polynomial_monomials, QRSolver
export ES2N, ESN, EuSN, DeepESN, DelayESN, HybridESN, EIESN, AdditiveEIESN, InputDelayESN, StateDelayESN
export NGRC
#ext
16 changes: 14 additions & 2 deletions src/train.jl
@@ -41,6 +41,10 @@

_set_readout(ps, m::ReservoirChain, W) = first(addreadout!(m, W, ps, NamedTuple()))

abstract type AbstractReservoirComputingSolver end

struct QRSolver <: AbstractReservoirComputingSolver end

@doc raw"""
train(train_method, states, target_data; kwargs...)

@@ -75,11 +79,19 @@ additional changes.
value as the forward method only.
"""
function train(
sr::StandardRidge, states::AbstractArray, target_data::AbstractArray; kwargs...
sr::StandardRidge, states::AbstractMatrix, target_data::AbstractMatrix;
solver = QRSolver(), kwargs...
)
return _train_ridge(solver, sr, states, target_data; kwargs...)
end

function _train_ridge(
::QRSolver, sr::StandardRidge,
states::AbstractMatrix, target_data::AbstractMatrix; kwargs...
)
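# ridge as an augmented least-squares problem: minimizing ‖[Sᵀ; √λ I] w − [Yᵀ; 0]‖₂²
# is equivalent to ridge regression on Sᵀ w ≈ Yᵀ with penalty λ‖w‖₂²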
n_states = size(states, 1)
A = [states'; sqrt(sr.reg) * I(n_states)]
b = [target_data'; zeros(n_states, size(target_data, 1))]
b = [target_data'; zeros(eltype(target_data), n_states, size(target_data, 1))]
F = qr(A)
Wt = F \ b
output_layer = Matrix(Wt')