11context(" Models from previous versions of XGBoost can be loaded" )
22
33metadata <- list (
4- kRounds = 2 ,
4+ kRounds = 4 ,
55 kRows = 1000 ,
66 kCols = 4 ,
77 kForests = 2 ,
@@ -10,87 +10,130 @@ metadata <- list(
1010)
1111
1212run_model_param_check <- function (config ) {
13- testthat :: expect_equal(config $ learner $ learner_model_param $ num_feature , ' 4' )
14- testthat :: expect_equal(config $ learner $ learner_train_param $ booster , ' gbtree' )
13+ testthat :: expect_equal(config $ learner $ learner_model_param $ num_feature , " 4" )
14+ testthat :: expect_equal(config $ learner $ learner_train_param $ booster , " gbtree" )
15+ }
16+
17+ get_n_rounds <- function (model_file ) {
18+ is_10 <- grepl(" 1.0.0rc1" , model_file , fixed = TRUE )
19+ if (is_10 ) {
20+ 2
21+ } else {
22+ metadata $ kRounds
23+ }
1524}
1625
1726get_num_tree <- function (booster ) {
1827 dump <- xgb.dump(booster )
19- m <- regexec(' booster\\ [[0-9]+\\ ]' , dump , perl = TRUE )
28+ m <- regexec(" booster\\ [[0-9]+\\ ]" , dump , perl = TRUE )
2029 m <- regmatches(dump , m )
21- num_tree <- Reduce(' + ' , lapply(m , length ))
22- return ( num_tree )
30+ num_tree <- Reduce(" + " , lapply(m , length ))
31+ num_tree
2332}
2433
25- run_booster_check <- function (booster , name ) {
34+ run_booster_check <- function (booster , model_file ) {
2635 config <- xgb.config(booster )
2736 run_model_param_check(config )
28- if (name == ' cls' ) {
29- testthat :: expect_equal(get_num_tree(booster ),
30- metadata $ kForests * metadata $ kRounds * metadata $ kClasses )
31- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5 )
32- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' multi:softmax' )
33- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ num_class ),
34- metadata $ kClasses )
35- } else if (name == ' logitraw' ) {
36- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
37- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ num_class ), 0 )
38- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' binary:logitraw' )
39- } else if (name == ' logit' ) {
40- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
41- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ num_class ), 0 )
42- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' binary:logistic' )
43- } else if (name == ' ltr' ) {
44- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
45- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' rank:ndcg' )
37+ is_model <- function (typ ) {
38+ grepl(typ , model_file , fixed = TRUE )
39+ }
40+ n_rounds <- get_n_rounds(model_file = model_file )
41+ if (is_model(" cls" )) {
42+ testthat :: expect_equal(
43+ get_num_tree(booster ), metadata $ kForests * n_rounds * metadata $ kClasses
44+ )
45+ testthat :: expect_equal(
46+ as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5
47+ )
48+ testthat :: expect_equal(
49+ config $ learner $ learner_train_param $ objective , " multi:softmax"
50+ )
51+ testthat :: expect_equal(
52+ as.numeric(config $ learner $ learner_model_param $ num_class ),
53+ metadata $ kClasses
54+ )
55+ } else if (is_model(" logitraw" )) {
56+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
57+ testthat :: expect_equal(
58+ as.numeric(config $ learner $ learner_model_param $ num_class ), 0
59+ )
60+ testthat :: expect_equal(
61+ config $ learner $ learner_train_param $ objective , " binary:logitraw"
62+ )
63+ } else if (is_model(" logit" )) {
64+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
65+ testthat :: expect_equal(
66+ as.numeric(config $ learner $ learner_model_param $ num_class ), 0
67+ )
68+ testthat :: expect_equal(
69+ config $ learner $ learner_train_param $ objective , " binary:logistic"
70+ )
71+ } else if (is_model(" ltr" )) {
72+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
73+ testthat :: expect_equal(
74+ config $ learner $ learner_train_param $ objective , " rank:ndcg"
75+ )
76+ } else if (is_model(" aft" )) {
77+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
78+ testthat :: expect_equal(
79+ config $ learner $ learner_train_param $ objective , " survival:aft"
80+ )
4681 } else {
47- testthat :: expect_equal(name , ' reg' )
48- testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * metadata $ kRounds )
49- testthat :: expect_equal(as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5 )
50- testthat :: expect_equal(config $ learner $ learner_train_param $ objective , ' reg:squarederror' )
82+ testthat :: expect_true(is_model(" reg" ))
83+ testthat :: expect_equal(get_num_tree(booster ), metadata $ kForests * n_rounds )
84+ testthat :: expect_equal(
85+ as.numeric(config $ learner $ learner_model_param $ base_score ), 0.5
86+ )
87+ testthat :: expect_equal(
88+ config $ learner $ learner_train_param $ objective , " reg:squarederror"
89+ )
5190 }
5291}
5392
5493test_that(" Models from previous versions of XGBoost can be loaded" , {
55- bucket <- ' xgboost-ci-jenkins-artifacts'
56- region <- ' us-west-2'
57- file_name <- ' xgboost_r_model_compatibility_test. zip'
94+ bucket <- " xgboost-ci-jenkins-artifacts"
95+ region <- " us-west-2"
96+ file_name <- " xgboost_model_compatibility_tests-3.0.2. zip"
5897 zipfile <- tempfile(fileext = " .zip" )
5998 extract_dir <- tempdir()
60- download.file(paste(' https://' , bucket , ' .s3-' , region , ' .amazonaws.com/' , file_name , sep = ' ' ),
61- destfile = zipfile , mode = ' wb' , quiet = TRUE )
99+ result <- tryCatch(
100+ {
101+ download.file(
102+ paste(
103+ " https://" , bucket , " .s3-" , region , " .amazonaws.com/" , file_name ,
104+ sep = " "
105+ ),
106+ destfile = zipfile , mode = " wb" , quiet = TRUE
107+ )
108+ zipfile
109+ },
110+ error = function (e ) {
111+ print(e )
112+ NA_character_
113+ }
114+ )
115+ if (is.na(result )) {
116+ print(" Failed to download old models." )
117+ return ()
118+ }
119+
62120 unzip(zipfile , exdir = extract_dir , overwrite = TRUE )
63- model_dir <- file.path(extract_dir , ' models' )
121+ model_dir <- file.path(extract_dir , " models" )
64122
65- pred_data <- xgb.DMatrix(matrix (c(0 , 0 , 0 , 0 ), nrow = 1 , ncol = 4 ), nthread = 2 )
123+ pred_data <- xgb.DMatrix(
124+ matrix (c(0 , 0 , 0 , 0 ), nrow = 1 , ncol = 4 ),
125+ nthread = 2
126+ )
66127
67128 lapply(list.files(model_dir ), function (x ) {
68129 model_file <- file.path(model_dir , x )
69- m <- regexec(" xgboost-([0-9\\ .]+)\\ .([a-z]+)\\ .[a-z]+" , model_file , perl = TRUE )
70- m <- regmatches(model_file , m )[[1 ]]
71- model_xgb_ver <- m [2 ]
72- name <- m [3 ]
73- is_rds <- endsWith(model_file , ' .rds' )
74- is_json <- endsWith(model_file , ' .json' )
75- # TODO: update this test for new RDS format
76- if (is_rds ) {
77- return (NULL )
78- }
79- # Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
80- if (is_rds && compareVersion(model_xgb_ver , ' 1.1.1.1' ) < 0 ) {
81- booster <- readRDS(model_file )
82- expect_warning(predict(booster , newdata = pred_data ))
83- booster <- readRDS(model_file )
84- expect_warning(run_booster_check(booster , name ))
85- } else {
86- if (is_rds ) {
87- booster <- readRDS(model_file )
88- } else {
89- booster <- xgb.load(model_file )
90- xgb.model.parameters(booster ) <- list (nthread = 2 )
91- }
92- predict(booster , newdata = pred_data )
93- run_booster_check(booster , name )
130+ is_skl <- grepl(" scikit" , model_file , fixed = TRUE )
131+ if (is_skl ) {
132+ return ()
94133 }
134+ booster <- xgb.load(model_file )
135+ xgb.model.parameters(booster ) <- list (nthread = 2 )
136+ predict(booster , newdata = pred_data )
137+ run_booster_check(booster , model_file )
95138 })
96139})
0 commit comments