Skip to content

Commit

Permalink
Introducing more MLJ-compliant API for KNNDTW and MiniRocket models
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninkriz committed Oct 4, 2023
1 parent a6422ab commit d0354c5
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 53 deletions.
6 changes: 3 additions & 3 deletions src/KNNDTW/dtw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function DTWItakura{T}(;
end

"Function to calucate vanilla DTW distance between `x` and `y`."
function dtw!(model::DTW{T}, x::Tarr, y::Tarr)::T where {T <: AbstractFloat, Tarr <: AbstractVector{T}}
function dtw!(model::DTW{T}, x::AbstractVector{T}, y::AbstractVector{T})::T where {T <: AbstractFloat}
row_count, col_count = length(x), length(y)

# Julia is column major, to make things faster let longer timeseries = columns and shorter timeseries = rows
Expand Down Expand Up @@ -95,7 +95,7 @@ function dtw!(model::DTW{T}, x::Tarr, y::Tarr)::T where {T <: AbstractFloat, Tar
end

"Function to calucate Sakoe-Chiba band limited DTW distance between `x` and `y`."
function dtw!(model::DTWSakoeChiba{T}, x::Tarr, y::Tarr)::T where {T <: AbstractFloat, Tarr <: AbstractVector{T}}
function dtw!(model::DTWSakoeChiba{T}, x::AbstractVector{T}, y::AbstractVector{T})::T where {T <: AbstractFloat}
row_count, col_count = length(x), length(y)

# Julia is column major, to make things faster let longer timeseries = columns and shorter timeseries = rows
Expand Down Expand Up @@ -134,7 +134,7 @@ function dtw!(model::DTWSakoeChiba{T}, x::Tarr, y::Tarr)::T where {T <: Abstract
end

"Function to calucate Itakura parallelogram limited DTW distance between `x` and `y`."
function dtw!(model::DTWItakura{T}, x::Tarr, y::Tarr)::T where {T <: AbstractFloat, Tarr <: AbstractVector{T}}
function dtw!(model::DTWItakura{T}, x::AbstractVector{T}, y::AbstractVector{T})::T where {T <: AbstractFloat}
row_count, col_count = length(x), length(y)

# Julia is column major, to make things faster let longer timeseries = columns and shorter timeseries = rows
Expand Down
16 changes: 8 additions & 8 deletions src/KNNDTW/knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@ MLJModelInterface.@mlj_model mutable struct KNNDTWModel <: MLJModelInterface.Pro
bounding::LBType = LBNone()
end

function MLJModelInterface.reformat(::KNNDTWModel, X::AbstractVector{<:AbstractVector{<:AbstractFloat}})
return (X,)
end

function MLJModelInterface.reformat(::KNNDTWModel, (X, type)::Tuple{<:AbstractMatrix{<:AbstractFloat}, Symbol})
@assert type in (:row_based, :column_based) "Unsupported matrix format"
matrix = MLJModelInterface.matrix(X, transpose = type == :row_based)

return if type == :row_based
@info "Copying data from the row based matrix"
([matrix[row, :] for row in axes(matrix, 1)],)
([X[row, :] for row in axes(X, 1)],)
else
([view(matrix, :, col) for col in axes(X, 2)],)
([view(X, :, col) for col in axes(X, 2)],)
end
end

MLJModelInterface.reformat(::KNNDTWModel, X::AbstractVector{<:AbstractVector{<:AbstractFloat}}) = (X,)
MLJModelInterface.reformat(::KNNDTWModel, X::AbstractVector{<:AbstractVector{<:AbstractFloat}}, y) = (X, MLJModelInterface.categorical(y))

MLJModelInterface.reformat(m::KNNDTWModel, X::AbstractMatrix{<:AbstractFloat}) = MLJModelInterface.reformat(m, (X, :row_based))
MLJModelInterface.reformat(m::KNNDTWModel, X::AbstractMatrix{<:AbstractFloat}, y) = (MLJModelInterface.reformat(m, (X, :row_based))..., MLJModelInterface.categorical(y))

MLJModelInterface.reformat(m::KNNDTWModel, (X, type)::Tuple{<:AbstractMatrix{<:AbstractFloat}, Symbol}, y) = (MLJModelInterface.reformat(m, (X, type))..., MLJModelInterface.categorical(y))

MLJModelInterface.selectrows(::KNNDTWModel, I, Xvec) = (view(Xvec, I),)
Expand Down
4 changes: 2 additions & 2 deletions src/KNNDTW/lb.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function LBKeogh{T}(;
end

"Lower bound method that always returns zero."
function lower_bound!(::LBNone, ::Tarr, ::Tarr; update::Bool = true)::T where {T <: AbstractFloat, Tarr <: AbstractVector{T}}
function lower_bound!(::LBNone, ::AbstractVector{T}, ::AbstractVector{T}; update::Bool = true)::T where {T <: AbstractFloat}
zero(T)
end

Expand All @@ -40,7 +40,7 @@ Function implementing the lower bound LB_Keogh method.
Set update=true to update the envelope forcefully.
"
function lower_bound!(lb::LBKeogh{T}, enveloped::Tarr, query::Tarr; update::Bool = true)::T where {T <: AbstractFloat, Tarr <: AbstractVector{T}}
function lower_bound!(lb::LBKeogh{T}, enveloped::AbstractVector{T}, query::AbstractVector{T}; update::Bool = true)::T where {T <: AbstractFloat}
@assert length(enveloped) === length(query) "Enveloped serires and query series must be of the same length"
@assert length(enveloped) >= lb.radius + 1 "Window raidus can not be larger than the series itself"

Expand Down
9 changes: 7 additions & 2 deletions src/MiniRocket/mr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,16 @@ MLJModelInterface.@mlj_model mutable struct MiniRocketModel <: MLJModelInterface
shuffled::Bool = false
end

function MLJModelInterface.reformat(::MiniRocketModel, (X, type))
function MLJModelInterface.reformat(::MiniRocketModel, (X, type)::Tuple{<:AbstractMatrix{<:AbstractFloat}, Symbol})
@assert type in (:row_based, :column_based)

(MLJModelInterface.matrix(X, transpose = type == :row_based),)
end

function MLJModelInterface.reformat(::MiniRocketModel, X)
(MLJModelInterface.matrix(X, transpose = true),)
end

MLJModelInterface.selectrows(::MiniRocketModel, I, Xmatrix) = (view(Xmatrix, :, I),)

"Function to train MiniRocket transformer."
Expand Down Expand Up @@ -286,7 +291,7 @@ function MLJModelInterface.transform(
dilations = fitresult[1],
num_features_per_dilation = fitresult[2],
biases = fitresult[3],
)
) |> transpose
end

"Loads fit paramters of the MiniRocket transformer."
Expand Down
38 changes: 32 additions & 6 deletions src/TimeSeriesClassification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,22 @@ MLJModelInterface.metadata_pkg.(

MLJModelInterface.metadata_model(
MiniRocketModel,
input_scitype = Tuple{AbstractMatrix{<:MLJModelInterface.Continuous}, MLJModelInterface.Unknown},
input_scitype = Union{
AbstractMatrix{<:MLJModelInterface.Continuous},
Tuple{AbstractMatrix{<:MLJModelInterface.Continuous}, MLJModelInterface.Unknown}
},
output_scitype = AbstractMatrix{<:MLJModelInterface.Continuous},
load_path = "TimeSeriesClassification.MiniRocketModel",
)

MLJModelInterface.metadata_model(
KNNDTWModel,
input_scitype = AbstractVector{<:AbstractVector{<:MLJModelInterface.Continuous}},
input_scitype = Union{
AbstractVector{<:AbstractVector{<:MLJModelInterface.Continuous}},
Tuple{AbstractVector{<:AbstractVector{<:MLJModelInterface.Continuous}}, MLJModelInterface.Unknown},
AbstractMatrix{<:MLJModelInterface.Continuous},
Tuple{AbstractMatrix{<:MLJModelInterface.Continuous}, MLJModelInterface.Unknown}
},
output_scitype = AbstractMatrix{<:MLJModelInterface.Finite},
load_path = "TimeSeriesClassification.KNNDTWModel",
)
Expand Down Expand Up @@ -73,6 +81,8 @@ A model parameters are built using [`MiniRocket._MiniRocket.fit`](@ref):
model_params = fit(X_train; num_features = model.num_features, max_dilations_per_kernel = model.max_dilations_per_kernel, shuffled = model.shuffled, rng = model.rng)
```
where `X_train` is a column based matrix of training data.
#### Transforming data
The gathered model parameters can be used for transforming other data using [`MiniRocket._MiniRocket.transform`](@ref):
Expand All @@ -82,24 +92,34 @@ dilations, num_features_per_dilation, biases = model_params
X_transformed = transform(X_new; dilations = dilations, num_features_per_dilation = num_features_per_dilation, biases = biases)
```
where `X_train` is a column based matrix of training data, `X_new` is a column based matrix of data to be transformed and `X_transformed` is a column based matrix of transformed data.
### MLJ model API
Crate an instance with default hyperparameters or override them with your own using [`MiniRocket._MiniRocket.MiniRocketModel`](@ref) and build a MLJ machine:
```julia
minirocket_model = MiniRocketModel()
mach = machine(minirocket_model, X_train)
mach = machine(minirocket_model, (X_train, :row_based))
# or when X is column based
mach = machine(minirocket_model, (X_train, :column_based))
```
You must specify if the data provided are row or column based using the `:column_based` and `:row_based` parameter.
`X_train` is a matrix of training data.
You can specify if the data provided are row or column based using the `:column_based` and `:row_based` parameter.
Column major format is preferred for performance since Julia is a column major language.
#### Training model
Train the machine using `fit!(mach)`.
#### Transforming data
Transform the data using `transform(mach, (X_new, :column_based))` or using `transform(mach, (X_new, :row_based))` in case of row based data.
Transform the data using `transform(mach, X_new)` or `transform(mach, (X_new, :row_based))` or using `transform(mach, (X_new, :column_based))` in case of row based data.
The result is always row major (column major, but with `tranpose` applied). To convert the result to column major format use `transpose`, which should be without any extra computational cost.
"""
MiniRocketModel
Expand Down Expand Up @@ -137,18 +157,24 @@ Crate an instance with default hyperparameters or override them with your own us
```julia
knndtw_model = KNNDTWModel()
mach = machine(knndtw_model, X_train, Y_train)
mach = machine(knndtw_model, (X_train, :row_based), Y_train)
# or when X is column based
mach = machine(knndtw_model, (X_train, :column_based), Y_train)
```
You must specify if the data provided are row or column based using the `:column_based` and `:row_based` parameter.
`X_train` is either a matrix or a vector of vectors of training data.
You can specify if the matrix provided os row or column based using the `:column_based` and `:row_based` parameter.
Column major format is preferred for performance since Julia is a column major language.
#### Training model
Train the machine using `fit!(mach)`.
#### Predicting
To predict "probability" of a class you can use `predict` like `predict(mach, (X_new, :column_based))` for columns based data or using `predict(mach, (X_new, :row_based))` in case of row based data.
To predict "probability" of a class you can use `predict` like `predict(mach, X_new)` or `predict(mach, (X_new, :row_based))` for row major data or using `predict(mach, (X_new, :column_based))` in case of column major data.
To classify the data (to get the most probable class) you can use `predict_mode` in a similar fashion.
"""
Expand Down
24 changes: 24 additions & 0 deletions test/KNNDTW/consts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
const TS1::Vector{Float64} = [0.57173714, 0.03585991, 0.16263380, 0.63153396, 0.00599358, 0.63256182, 0.85341386, 0.87538411, 0.35243848, 0.27466851]
const TS2::Vector{Float64} = [0.17281271, 0.54244937, 0.35081248, 0.83846642, 0.74942411]
const TS3::Vector{Float64} = [1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 0.16263380, 0.63153396, 0.00599358, 0.63256182]
const TS4::Vector{Float64} = [0.57173714, 0.03585991, 0.16263380, 0.00599358, 0.00599358, 0.00599358, 0.85341386, 0.87538411, 0.35243848, 0.27466851]

const TRAIN_X::Matrix{Float64} = [
0 1 3
0 2 0
0 1 3
1 0 0
2 0 3
1 0 0
0 0 3
]
const TRAIN_Y::CategoricalArray = categorical(["a", "a", "b"])
const TEST_X::Matrix{Float64} = [
0 8
0 1
1 8
2 1
1 8
0 1
0 8
]
87 changes: 57 additions & 30 deletions test/KNNDTW/tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,10 @@ using Test: @testset, @test, @test_throws
using CategoricalArrays: CategoricalArray, categorical
using CategoricalDistributions: pdf
using MLJBase: machine, fit!, fitted_params, predict, predict_mode
import MLJModelInterface

include("consts.jl")

const TS1::Vector{Float64} = [0.57173714, 0.03585991, 0.16263380, 0.63153396, 0.00599358, 0.63256182, 0.85341386, 0.87538411, 0.35243848, 0.27466851]
const TS2::Vector{Float64} = [0.17281271, 0.54244937, 0.35081248, 0.83846642, 0.74942411]
const TS3::Vector{Float64} = [1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 0.16263380, 0.63153396, 0.00599358, 0.63256182]
const TS4::Vector{Float64} = [0.57173714, 0.03585991, 0.16263380, 0.00599358, 0.00599358, 0.00599358, 0.85341386, 0.87538411, 0.35243848, 0.27466851]

const TRAIN_X::Matrix{Float64} = [
0 1 3
0 2 0
0 1 3
1 0 0
2 0 3
1 0 0
0 0 3
]
const TRAIN_Y::CategoricalArray = categorical(["a", "a", "b"])
const TEST_X::Matrix{Float64} = [
0 8
0 1
1 8
2 1
1 8
0 1
0 8
]

@testset "KNNDTW.jl - dtw!() - Full" begin
model = KNNDTW.DTW{eltype(TS1)}()
Expand Down Expand Up @@ -92,7 +70,7 @@ end
end

@testset "KNNDTW.jl - KNN - K=1" begin
nn = KNNDTW.KNNDTWModel(K=1, distance=KNNDTW.DTW{eltype(TS1)}())
nn = KNNDTW.KNNDTWModel(K=1, distance=KNNDTW.DTW{eltype(TRAIN_X)}())

mach = machine(nn, (TRAIN_X, :column_based), TRAIN_Y)
fit!(mach, verbosity=0)
Expand All @@ -103,7 +81,7 @@ end
end

@testset "KNNDTW.jl - KNN - K=3 - Xnew smaller than X" begin
nn = KNNDTW.KNNDTWModel(K=3, weights=:distance, distance=KNNDTW.DTW{eltype(TS1)}())
nn = KNNDTW.KNNDTWModel(K=3, weights=:distance, distance=KNNDTW.DTW{eltype(TRAIN_X)}())

mach = machine(nn, (TRAIN_X, :column_based), TRAIN_Y)
fit!(mach, verbosity=0)
Expand All @@ -114,7 +92,7 @@ end
end

@testset "KNNDTW.jl - KNN - K=3 - Xnew larger than X" begin
nn = KNNDTW.KNNDTWModel(K=3, weights=:distance, distance=KNNDTW.DTW{eltype(TS1)}())
nn = KNNDTW.KNNDTWModel(K=3, weights=:distance, distance=KNNDTW.DTW{eltype(TRAIN_X)}())

mach = machine(nn, (TRAIN_X, :column_based), TRAIN_Y)
fit!(mach, verbosity=0)
Expand All @@ -125,7 +103,7 @@ end
end

@testset "KNNDTW.jl - KNN - K=3 - predict_mode" begin
nn = KNNDTW.KNNDTWModel(K=3, weights=:distance, distance=KNNDTW.DTW{eltype(TS1)}())
nn = KNNDTW.KNNDTWModel(K=3, weights=:distance, distance=KNNDTW.DTW{eltype(TRAIN_X)}())

mach = machine(nn, (TRAIN_X, :column_based), TRAIN_Y)
fit!(mach, verbosity=0)
Expand All @@ -143,8 +121,8 @@ end
end

@testset "KNNDTW.jl - KNN - K=1 - repeated is same" begin
nn1 = KNNDTW.KNNDTWModel(K=2, distance=KNNDTW.DTW{eltype(TS1)}())
nn2 = KNNDTW.KNNDTWModel(K=2, distance=KNNDTW.DTW{eltype(TS1)}())
nn1 = KNNDTW.KNNDTWModel(K=2, distance=KNNDTW.DTW{eltype(TRAIN_X)}())
nn2 = KNNDTW.KNNDTWModel(K=2, distance=KNNDTW.DTW{eltype(TRAIN_X)}())

mach1 = machine(nn1, (TRAIN_X, :column_based), TRAIN_Y)
mach2 = machine(nn2, (TRAIN_X, :column_based), TRAIN_Y)
Expand All @@ -157,3 +135,52 @@ end

@test pred1 == pred2
end

@testset "KNNDTW.jl - row major and column major inputs" begin
data_col = TRAIN_X
data_row = permutedims(TRAIN_X)
data_vec = [view(TRAIN_X, :, col) for col in axes(TRAIN_X, 2)]

m_machine_col = KNNDTW.KNNDTWModel(K=1, distance=KNNDTW.DTW{eltype(TRAIN_X)}())
m_machine_row1 = KNNDTW.KNNDTWModel(K=1, distance=KNNDTW.DTW{eltype(TRAIN_X)}())
m_machine_row2 = KNNDTW.KNNDTWModel(K=1, distance=KNNDTW.DTW{eltype(TRAIN_X)}())
m_machine_vec = KNNDTW.KNNDTWModel(K=1, distance=KNNDTW.DTW{eltype(TRAIN_X)}())
m_model = KNNDTW.KNNDTWModel(K=1, distance=KNNDTW.DTW{eltype(TRAIN_X)}())

mach_col = machine(m_machine_col, (data_col, :column_based), TRAIN_Y)
mach_row1 = machine(m_machine_row1, (data_row, :row_based), TRAIN_Y)
mach_row2 = machine(m_machine_row2, data_row, TRAIN_Y)
mach_vec = machine(m_machine_vec, data_vec, TRAIN_Y)

fit!(mach_col, verbosity=0)
fit!(mach_row1, verbosity=0)
fit!(mach_row2, verbosity=0)
fit!(mach_vec, verbosity=0)
fp_model = MLJModelInterface.fit(m_model, false, data_vec, TRAIN_Y)[1]

pred_col_c = predict_mode(mach_col, (data_col, :column_based))
pred_col_r1 = predict_mode(mach_col, (data_row, :row_based))
pred_col_r2 = predict_mode(mach_col, data_row)
pred_col_v = predict_mode(mach_col, data_vec)
@test pred_col_c == pred_col_r1 == pred_col_r2 == pred_col_v

pred_row1_c = predict_mode(mach_row1, (data_col, :column_based))
pred_row1_r1 = predict_mode(mach_row1, (data_row, :row_based))
pred_row1_r2 = predict_mode(mach_row1, data_row)
pred_row1_v = predict_mode(mach_row1, data_vec)
@test pred_row1_c == pred_row1_r1 == pred_row1_r2 == pred_row1_v

pred_row2_c = predict_mode(mach_row2, (data_col, :column_based))
pred_row2_r1 = predict_mode(mach_row2, (data_row, :row_based))
pred_row2_r2 = predict_mode(mach_row2, data_row)
pred_row2_v = predict_mode(mach_row2, data_vec)
@test pred_row2_c == pred_row2_r1 == pred_row2_r2 == pred_row2_v

pred_vec_c = predict_mode(mach_vec, (data_col, :column_based))
pred_vec_r1 = predict_mode(mach_vec, (data_row, :row_based))
pred_vec_r2 = predict_mode(mach_vec, data_row)
pred_vec_v = predict_mode(mach_vec, data_vec)
@test pred_vec_c == pred_vec_r1 == pred_vec_r2 == pred_vec_v

@test pred_col_c == pred_col_r1 == pred_col_r2 == pred_col_v == pred_row1_c == pred_row1_r1 == pred_row1_r2 == pred_row1_v == pred_row2_c == pred_row2_r1 == pred_row2_r2 == pred_row2_v == pred_vec_c == pred_vec_r1 == pred_vec_r2 == pred_vec_v
end
Loading

0 comments on commit d0354c5

Please sign in to comment.