Skip to content

Commit 24a313e

Browse files
eliascarvjuliohm
andauthored
Add GLM.jl models (#5)
* Add GLM.jl models * Update src/StatsLearnModels.jl --------- Co-authored-by: Júlio Hoffimann <[email protected]>
1 parent c830cf3 commit 24a313e

File tree

5 files changed

+80
-2
lines changed

5 files changed

+80
-2
lines changed

Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ version = "0.2.0"
66
[deps]
77
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
88
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
9+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10+
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
911
TableTransforms = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e"
1012
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1113

@@ -18,6 +20,8 @@ StatsLearnModelsMLJModelInterfaceExt = "MLJModelInterface"
1820
[compat]
1921
ColumnSelectors = "0.1"
2022
DecisionTree = "0.12"
23+
Distributions = "0.25"
24+
GLM = "1.9"
2125
MLJModelInterface = "1.9"
2226
TableTransforms = "1.15"
2327
Tables = "1.11"

src/StatsLearnModels.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,28 @@ using ColumnSelectors: selector
99
using TableTransforms: StatelessFeatureTransform
1010
import TableTransforms: applyfeat, isrevertible
1111

12+
import GLM
1213
import DecisionTree as DT
1314
using DecisionTree: AdaBoostStumpClassifier, DecisionTreeClassifier, RandomForestClassifier
1415
using DecisionTree: DecisionTreeRegressor, RandomForestRegressor
16+
using Distributions: UnivariateDistribution
1517

1618
include("interface.jl")
1719
include("models/decisiontree.jl")
20+
include("models/glm.jl")
1821
include("learn.jl")
1922

2023
export
21-
# models
24+
# DecisionTree.jl
2225
AdaBoostStumpClassifier,
2326
DecisionTreeClassifier,
2427
RandomForestClassifier,
2528
DecisionTreeRegressor,
2629
RandomForestRegressor,
30+
31+
# GLM.jl
32+
LinearRegressor,
33+
GeneralizedLinearRegressor,
2734

2835
# transform
2936
Learn

src/models/glm.jl

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
abstract type GLMModel end
2+
3+
struct LinearRegressor{K} <: GLMModel
4+
kwargs::K
5+
end
6+
7+
LinearRegressor(; kwargs...) = LinearRegressor(values(kwargs))
8+
9+
struct GeneralizedLinearRegressor{D<:UnivariateDistribution,L<:Union{GLM.Link,Nothing},K} <: GLMModel
10+
dist::D
11+
link::L
12+
kwargs::K
13+
end
14+
15+
GeneralizedLinearRegressor(dist::UnivariateDistribution, link=nothing; kwargs...) =
16+
GeneralizedLinearRegressor(dist, link, values(kwargs))
17+
18+
function fit(model::GLMModel, input, output)
19+
cols = Tables.columns(output)
20+
names = Tables.columnnames(cols)
21+
outcol = first(names)
22+
X = Tables.matrix(input)
23+
y = Tables.getcolumn(cols, outcol)
24+
fitted = _fit(model, X, y)
25+
FittedModel(model, (fitted, outcol))
26+
end
27+
28+
function predict(fmodel::FittedModel{<:GLMModel}, table)
29+
model, outcol = fmodel.cache
30+
X = Tables.matrix(table)
31+
= GLM.predict(model, X)
32+
(; outcol => ŷ) |> Tables.materializer(table)
33+
end
34+
35+
_fit(model::LinearRegressor, X, y) = GLM.lm(X, y; model.kwargs...)
36+
37+
function _fit(model::GeneralizedLinearRegressor, X, y)
38+
if isnothing(model.link)
39+
GLM.glm(X, y, model.dist; model.kwargs...)
40+
else
41+
GLM.glm(X, y, model.dist, model.link; model.kwargs...)
42+
end
43+
end

test/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
3+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4+
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
35
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
46
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
57
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/runtests.jl

+23-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ using DataFrames
44
using Random
55
using Test
66

7+
using GLM: ProbitLink
8+
using Distributions: Binomial
9+
710
import MLJ, MLJDecisionTreeInterface
811

912
const SLM = StatsLearnModels
@@ -17,7 +20,7 @@ const SLM = StatsLearnModels
1720
@testset "interface" begin
1821
@testset "MLJ" begin
1922
Random.seed!(123)
20-
Tree = MLJ.@load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
23+
Tree = MLJ.@load(DecisionTreeClassifier, pkg = DecisionTree, verbosity = 0)
2124
fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
2225
pred = SLM.predict(fmodel, input[test, :])
2326
accuracy = count(pred.target .== output.target[test]) / length(test)
@@ -32,6 +35,25 @@ const SLM = StatsLearnModels
3235
accuracy = count(pred.target .== output.target[test]) / length(test)
3336
@test accuracy > 0.9
3437
end
38+
39+
@testset "GLM" begin
40+
x = [1, 2, 3]
41+
y = [2, 4, 7]
42+
input = DataFrame(; ones=ones(length(x)), x)
43+
output = DataFrame(; y)
44+
model = LinearRegressor()
45+
fmodel = SLM.fit(model, input, output)
46+
pred = SLM.predict(fmodel, input)
47+
@test all(isapprox.(pred.y, output.y, atol=0.5))
48+
x = [1, 2, 2]
49+
y = [1, 0, 1]
50+
input = DataFrame(; ones=ones(length(x)), x)
51+
output = DataFrame(; y)
52+
model = GeneralizedLinearRegressor(Binomial(), ProbitLink())
53+
fmodel = SLM.fit(model, input, output)
54+
pred = SLM.predict(fmodel, input)
55+
@test all(isapprox.(pred.y, output.y, atol=0.5))
56+
end
3557
end
3658

3759
@testset "Learn" begin

0 commit comments

Comments
 (0)