1
1
context(" Models from previous versions of XGBoost can be loaded" )
2
2
3
3
metadata <- list (
4
- kRounds = 2 ,
4
+ kRounds = 4 ,
5
5
kRows = 1000 ,
6
6
kCols = 4 ,
7
7
kForests = 2 ,
@@ -10,87 +10,130 @@ metadata <- list(
10
10
)
11
11
12
12
run_model_param_check <- function (config ) {
13
- testthat :: expect_equal(config $ learner $ learner_model_param $ num_feature , ' 4' )
14
- testthat :: expect_equal(config $ learner $ learner_train_param $ booster , ' gbtree' )
13
+ testthat :: expect_equal(config $ learner $ learner_model_param $ num_feature , " 4" )
14
+ testthat :: expect_equal(config $ learner $ learner_train_param $ booster , " gbtree" )
15
+ }
16
+
17
+ get_n_rounds <- function (model_file ) {
18
+ is_10 <- grepl(" 1.0.0rc1" , model_file , fixed = TRUE )
19
+ if (is_10 ) {
20
+ 2
21
+ } else {
22
+ metadata $ kRounds
23
+ }
15
24
}
16
25
17
26
get_num_tree <- function (booster ) {
18
27
dump <- xgb.dump(booster )
19
- m <- regexec(' booster\\ [[0-9]+\\ ]' , dump , perl = TRUE )
28
+ m <- regexec(" booster\\ [[0-9]+\\ ]" , dump , perl = TRUE )
20
29
m <- regmatches(dump , m )
21
- num_tree <- Reduce(' + ' , lapply(m , length ))
22
- return ( num_tree )
30
+ num_tree <- Reduce(" + " , lapply(m , length ))
31
+ num_tree
23
32
}
24
33
25
- run_booster_check <- function (booster , name ) {
34
+ run_booster_check <- function (booster , model_file ) {
26
35
config <- xgb.config(booster )
27
36
run_model_param_check(config )
28
- if (name == ' cls' ) {
29
- testthat :: expect_equal(get_num_tree(booster ),
30
- metadata $ kForests * metadata $ kRounds * metadata $ kClasses )
31
- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5 )
32
- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' multi:softmax' )
33
- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ num_class ),
34
- metadata $ kClasses )
35
- } else if (name == ' logitraw' ) {
36
- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
37
- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ num_class ), 0 )
38
- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' binary:logitraw' )
39
- } else if (name == ' logit' ) {
40
- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
41
- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ num_class ), 0 )
42
- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' binary:logistic' )
43
- } else if (name == ' ltr' ) {
44
- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
45
- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' rank:ndcg' )
37
+ is_model <- function (typ ) {
38
+ grepl(typ , model_file , fixed = TRUE )
39
+ }
40
+ n_rounds <- get_n_rounds(model_file = model_file )
41
+ if (is_model(" cls" )) {
42
+ testthat :: expect_equal(
43
+ get_num_tree(booster ), metadata $ kForests * n_rounds * metadata $ kClasses
44
+ )
45
+ testthat :: expect_equal(
46
+ as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5
47
+ )
48
+ testthat :: expect_equal(
49
+ config $ learner $ learner_train_param $ objective , " multi:softmax"
50
+ )
51
+ testthat :: expect_equal(
52
+ as.numeric(config $ learner $ learner_model_param $ num_class ),
53
+ metadata $ kClasses
54
+ )
55
+ } else if (is_model(" logitraw" )) {
56
+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
57
+ testthat :: expect_equal(
58
+ as.numeric(config $ learner $ learner_model_param $ num_class ), 0
59
+ )
60
+ testthat :: expect_equal(
61
+ config $ learner $ learner_train_param $ objective , " binary:logitraw"
62
+ )
63
+ } else if (is_model(" logit" )) {
64
+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
65
+ testthat :: expect_equal(
66
+ as.numeric(config $ learner $ learner_model_param $ num_class ), 0
67
+ )
68
+ testthat :: expect_equal(
69
+ config $ learner $ learner_train_param $ objective , " binary:logistic"
70
+ )
71
+ } else if (is_model(" ltr" )) {
72
+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
73
+ testthat :: expect_equal(
74
+ config $ learner $ learner_train_param $ objective , " rank:ndcg"
75
+ )
76
+ } else if (is_model(" aft" )) {
77
+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
78
+ testthat :: expect_equal(
79
+ config $ learner $ learner_train_param $ objective , " survival:aft"
80
+ )
46
81
} else {
47
- testthat :: expect_equal(name , ' reg' )
48
- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
49
- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5 )
50
- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' reg:squarederror' )
82
+ testthat :: expect_true(is_model(" reg" ))
83
+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
84
+ testthat :: expect_equal(
85
+ as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5
86
+ )
87
+ testthat :: expect_equal(
88
+ config $ learner $ learner_train_param $ objective , " reg:squarederror"
89
+ )
51
90
}
52
91
}
53
92
54
93
test_that(" Models from previous versions of XGBoost can be loaded" , {
55
- bucket <- ' xgboost-ci-jenkins-artifacts'
56
- region <- ' us-west-2'
57
- file_name <- ' xgboost_r_model_compatibility_test. zip'
94
+ bucket <- " xgboost-ci-jenkins-artifacts"
95
+ region <- " us-west-2"
96
+ file_name <- " xgboost_model_compatibility_tests-3.0.2. zip"
58
97
zipfile <- tempfile(fileext = " .zip" )
59
98
extract_dir <- tempdir()
60
- download.file(paste(' https://' , bucket , ' .s3-' , region , ' .amazonaws.com/' , file_name , sep = ' ' ),
61
- destfile = zipfile , mode = ' wb' , quiet = TRUE )
99
+ result <- tryCatch(
100
+ {
101
+ download.file(
102
+ paste(
103
+ " https://" , bucket , " .s3-" , region , " .amazonaws.com/" , file_name ,
104
+ sep = " "
105
+ ),
106
+ destfile = zipfile , mode = " wb" , quiet = TRUE
107
+ )
108
+ zipfile
109
+ },
110
+ error = function (e ) {
111
+ print(e )
112
+ NA_character_
113
+ }
114
+ )
115
+ if (is.na(result )) {
116
+ print(" Failed to download old models." )
117
+ return ()
118
+ }
119
+
62
120
unzip(zipfile , exdir = extract_dir , overwrite = TRUE )
63
- model_dir <- file.path(extract_dir , ' models' )
121
+ model_dir <- file.path(extract_dir , " models" )
64
122
65
- pred_data <- xgb.DMatrix(matrix (c(0 , 0 , 0 , 0 ), nrow = 1 , ncol = 4 ), nthread = 2 )
123
+ pred_data <- xgb.DMatrix(
124
+ matrix (c(0 , 0 , 0 , 0 ), nrow = 1 , ncol = 4 ),
125
+ nthread = 2
126
+ )
66
127
67
128
lapply(list.files(model_dir ), function (x ) {
68
129
model_file <- file.path(model_dir , x )
69
- m <- regexec(" xgboost-([0-9\\ .]+)\\ .([a-z]+)\\ .[a-z]+" , model_file , perl = TRUE )
70
- m <- regmatches(model_file , m )[[1 ]]
71
- model_xgb_ver <- m [2 ]
72
- name <- m [3 ]
73
- is_rds <- endsWith(model_file , ' .rds' )
74
- is_json <- endsWith(model_file , ' .json' )
75
- # TODO: update this test for new RDS format
76
- if (is_rds ) {
77
- return (NULL )
78
- }
79
- # Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
80
- if (is_rds && compareVersion(model_xgb_ver , ' 1.1.1.1' ) < 0 ) {
81
- booster <- readRDS(model_file )
82
- expect_warning(predict(booster , newdata = pred_data ))
83
- booster <- readRDS(model_file )
84
- expect_warning(run_booster_check(booster , name ))
85
- } else {
86
- if (is_rds ) {
87
- booster <- readRDS(model_file )
88
- } else {
89
- booster <- xgb.load(model_file )
90
- xgb.model.parameters(booster ) <- list (nthread = 2 )
91
- }
92
- predict(booster , newdata = pred_data )
93
- run_booster_check(booster , name )
130
+ is_skl <- grepl(" scikit" , model_file , fixed = TRUE )
131
+ if (is_skl ) {
132
+ return ()
94
133
}
134
+ booster <- xgb.load(model_file )
135
+ xgb.model.parameters(booster ) <- list (nthread = 2 )
136
+ predict(booster , newdata = pred_data )
137
+ run_booster_check(booster , model_file )
95
138
})
96
139
})
0 commit comments