Skip to content

Commit 753f8ba

Browse files
committed
added metric evaluation
1 parent b045a8a commit 753f8ba

File tree

4 files changed

+80
-34
lines changed

4 files changed

+80
-34
lines changed

src/shogun/io/openml/OpenMLData.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,8 @@ std::shared_ptr<CFeatures> OpenMLData::get_features(const std::string& label)
186186
copy_feat, result->get_feature_matrix().data(),
187187
m_feature_types.size() * m_cached_features.size());
188188

189-
result = std::make_shared<CDenseFeatures<float64_t>>(
189+
return std::make_shared<CDenseFeatures<float64_t>>(
190190
copy_feat, m_feature_types.size(), n_examples);
191-
192-
return result;
193191
}
194192

195193
std::shared_ptr<CLabels> OpenMLData::get_labels()

src/shogun/io/openml/OpenMLRun.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
* Authors: Gil Hoben
55
*/
66

7-
#include <shogun/evaluation/CrossValidationStorage.h>
87
#include <shogun/io/openml/OpenMLFile.h>
98
#include <shogun/io/openml/OpenMLRun.h>
109
#include <shogun/io/openml/ShogunOpenML.h>
@@ -47,7 +46,7 @@ std::shared_ptr<OpenMLRun> OpenMLRun::run_flow_on_task(
4746
SG_SERROR("INTERNAL ERROR: failed to cast model to machine!\n")
4847
}
4948

50-
auto* xval_storage = new CrossValidationStorage();
49+
auto xval_storage = std::make_shared<CrossValidationStorage>();
5150

5251
if (task->get_split()->contains_splits())
5352
{
@@ -64,11 +63,12 @@ std::shared_ptr<OpenMLRun> OpenMLRun::run_flow_on_task(
6463
for (auto fold_idx : range(task->get_num_fold()))
6564
{
6665
SGVector<index_t> train_i_idx(
67-
train_idx[repeat_idx][fold_idx].data(),
68-
train_idx[repeat_idx][fold_idx].size());
66+
train_idx[repeat_idx][fold_idx].begin(),
67+
train_idx[repeat_idx][fold_idx].end());
6968
SGVector<index_t> test_i_idx(
70-
train_idx[repeat_idx][fold_idx].data(),
71-
train_idx[repeat_idx][fold_idx].size());
69+
test_idx[repeat_idx][fold_idx].begin(),
70+
test_idx[repeat_idx][fold_idx].end());
71+
7272
xval_storage->append_fold_result(
7373
ShogunOpenML::run_model_on_fold(
7474
machine, task, features, labels, train_i_idx,
@@ -93,9 +93,7 @@ std::shared_ptr<OpenMLRun> OpenMLRun::run_flow_on_task(
9393
std::string{}, // setup_id
9494
std::string{}, // setup_string
9595
std::string{}, // parameter_settings
96-
std::vector<float64_t>{}, // evaluations
97-
std::vector<float64_t>{}, // fold_evaluations
98-
std::vector<float64_t>{}, // sample_evaluations
96+
xval_storage, // xval_storage
9997
std::string{}, // data_content
10098
std::vector<std::string>{}, // output_files
10199
task, // task
@@ -119,7 +117,10 @@ void OpenMLRun::to_filesystem(const std::string& directory) const
119117
SG_SNOTIMPLEMENTED
120118
}
121119

122-
void OpenMLRun::publish() const
120+
void OpenMLRun::publish() const {SG_SNOTIMPLEMENTED}
121+
122+
std::unique_ptr<std::ostream> OpenMLRun::to_xml() const
123123
{
124-
SG_SNOTIMPLEMENTED
124+
125+
return std::unique_ptr<std::ostream>();
125126
}

src/shogun/io/openml/OpenMLRun.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define SHOGUN_OPENMLRUN_H
99

1010
#include <shogun/base/SGObject.h>
11+
#include <shogun/evaluation/CrossValidationStorage.h>
1112

1213
#include <shogun/io/openml/OpenMLFlow.h>
1314
#include <shogun/io/openml/OpenMLTask.h>
@@ -20,9 +21,7 @@ namespace shogun {
2021
const std::string& uploader, const std::string& uploader_name,
2122
const std::string& setup_id, const std::string& setup_string,
2223
const std::string& parameter_settings,
23-
std::vector<float64_t> evaluations,
24-
std::vector<float64_t> fold_evaluations,
25-
std::vector<float64_t> sample_evaluations,
24+
std::shared_ptr<CrossValidationStorage> xval_storage,
2625
const std::string& data_content,
2726
std::vector<std::string> output_files,
2827
std::shared_ptr<OpenMLTask> task, std::shared_ptr<OpenMLFlow> flow,
@@ -31,9 +30,7 @@ namespace shogun {
3130
: m_uploader(uploader), m_uploader_name(uploader_name),
3231
m_setup_id(setup_id), m_setup_string(setup_string),
3332
m_parameter_settings(parameter_settings),
34-
m_evaluations(std::move(evaluations)),
35-
m_fold_evaluations(std::move(fold_evaluations)),
36-
m_sample_evaluations(std::move(sample_evaluations)),
33+
m_xval_storage(xval_storage),
3734
m_data_content(data_content),
3835
m_output_files(std::move(output_files)), m_task(std::move(task)),
3936
m_flow(std::move(flow)), m_run_id(run_id),
@@ -55,6 +52,8 @@ namespace shogun {
5552

5653
void to_filesystem(const std::string& directory) const;
5754

55+
std::unique_ptr<std::ostream> to_xml() const;
56+
5857
void publish() const;
5958

6059
private:
@@ -63,9 +62,7 @@ namespace shogun {
6362
std::string m_setup_id;
6463
std::string m_setup_string;
6564
std::string m_parameter_settings;
66-
std::vector<float64_t> m_evaluations;
67-
std::vector<float64_t> m_fold_evaluations;
68-
std::vector<float64_t> m_sample_evaluations;
65+
std::shared_ptr<CrossValidationStorage> m_xval_storage;
6966
std::string m_data_content;
7067
std::vector<std::string> m_output_files;
7168
std::shared_ptr<OpenMLTask> m_task;

src/shogun/io/openml/ShogunOpenML.cpp

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* Authors: Gil Hoben
77
*/
88

9+
#include <shogun/evaluation/ContingencyTableEvaluation.h>
10+
#include <shogun/evaluation/MeanAbsoluteError.h>
911
#include <shogun/util/factory.h>
1012

1113
#include <shogun/io/openml/ShogunOpenML.h>
@@ -310,6 +312,30 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
310312
{
311313
auto task_type = task->get_task_type();
312314

315+
CEvaluation* evaluation_criterion = nullptr;
316+
317+
switch (task_type)
318+
{
319+
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
320+
evaluation_criterion = new CAccuracyMeasure();
321+
break;
322+
case OpenMLTask::TaskType::SUPERVISED_REGRESSION:
323+
evaluation_criterion = new CMeanAbsoluteError();
324+
break;
325+
case OpenMLTask::TaskType::LEARNING_CURVE:
326+
SG_SNOTIMPLEMENTED
327+
case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
328+
SG_SNOTIMPLEMENTED
329+
case OpenMLTask::TaskType::CLUSTERING:
330+
SG_SNOTIMPLEMENTED
331+
case OpenMLTask::TaskType::MACHINE_LEARNING_CHALLENGE:
332+
SG_SNOTIMPLEMENTED
333+
case OpenMLTask::TaskType::SURVIVAL_ANALYSIS:
334+
SG_SNOTIMPLEMENTED
335+
case OpenMLTask::TaskType::SUBGROUP_DISCOVERY:
336+
SG_SNOTIMPLEMENTED
337+
}
338+
313339
switch (task_type)
314340
{
315341
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
@@ -324,8 +350,6 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
324350
// shared
325351
auto* features_clone = features->clone()->as<CFeatures>();
326352
auto* labels_clone = labels->clone()->as<CLabels>();
327-
// auto* evaluation_criterion =
328-
// (CEvaluation*)m_evaluation_criterion->clone();
329353

330354
/* evtl. update xvalidation output class */
331355
fold->set_run_index(repeat_idx);
@@ -371,8 +395,10 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
371395
SG_REF(result_labels);
372396

373397
/* evaluate */
374-
// results[i] = evaluation_criterion->evaluate(result_labels, labels);
375-
// SG_DEBUG("result on fold %d is %f\n", i, results[i])
398+
auto result =
399+
evaluation_criterion->evaluate(result_labels, labels_clone);
400+
SG_SINFO(
401+
"result on repeat %d fold %d is %f\n", repeat_idx, fold_idx, result)
376402

377403
/* evtl. update xvalidation output class */
378404
fold->set_test_indices(test_idx);
@@ -381,18 +407,17 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
381407
fold->set_test_true_result(true_labels);
382408
SG_UNREF(true_labels)
383409
fold->post_update_results();
384-
// fold->set_evaluation_result(results[i]);
410+
fold->set_evaluation_result(result);
385411

386412
/* clean up, remove subsets */
387413
labels->remove_subset();
388414
SG_UNREF(cloned_machine);
389415
SG_UNREF(features_clone);
390416
SG_UNREF(labels_clone);
391-
// SG_UNREF(evaluation_criterion);
392417
SG_UNREF(result_labels);
418+
delete evaluation_criterion;
393419
return fold;
394420
}
395-
break;
396421
case OpenMLTask::TaskType::LEARNING_CURVE:
397422
SG_SNOTIMPLEMENTED
398423
case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
@@ -417,6 +442,30 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
417442
{
418443
auto task_type = task->get_task_type();
419444

445+
CEvaluation* evaluation_criterion = nullptr;
446+
447+
switch (task_type)
448+
{
449+
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
450+
evaluation_criterion = new CAccuracyMeasure();
451+
break;
452+
case OpenMLTask::TaskType::SUPERVISED_REGRESSION:
453+
evaluation_criterion = new CMeanAbsoluteError();
454+
break;
455+
case OpenMLTask::TaskType::LEARNING_CURVE:
456+
SG_SNOTIMPLEMENTED
457+
case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
458+
SG_SNOTIMPLEMENTED
459+
case OpenMLTask::TaskType::CLUSTERING:
460+
SG_SNOTIMPLEMENTED
461+
case OpenMLTask::TaskType::MACHINE_LEARNING_CHALLENGE:
462+
SG_SNOTIMPLEMENTED
463+
case OpenMLTask::TaskType::SURVIVAL_ANALYSIS:
464+
SG_SNOTIMPLEMENTED
465+
case OpenMLTask::TaskType::SUBGROUP_DISCOVERY:
466+
SG_SNOTIMPLEMENTED
467+
}
468+
420469
switch (task_type)
421470
{
422471
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
@@ -446,23 +495,24 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
446495
SG_SDEBUG("finished evaluation\n")
447496

448497
/* evaluate */
449-
// results[i] = evaluation_criterion->evaluate(result_labels, labels);
450-
// SG_DEBUG("result on fold %d is %f\n", i, results[i])
498+
auto result =
499+
evaluation_criterion->evaluate(result_labels, labels_clone);
500+
SG_SINFO("result is %f\n", result)
451501

452502
/* evtl. update xvalidation output class */
453503
fold->set_test_result(result_labels);
454504
auto* true_labels = (CLabels*)labels->clone();
455505
fold->set_test_true_result(true_labels);
456506
SG_UNREF(true_labels)
457507
fold->post_update_results();
458-
// fold->set_evaluation_result(results[i]);
508+
fold->set_evaluation_result(result);
459509

460510
// cleanup
461511
SG_UNREF(cloned_machine);
462512
SG_UNREF(features_clone);
463513
SG_UNREF(labels_clone);
464-
// SG_UNREF(evaluation_criterion);
465514
SG_UNREF(result_labels);
515+
delete evaluation_criterion;
466516
return fold;
467517
}
468518
case OpenMLTask::TaskType::LEARNING_CURVE:

0 commit comments

Comments
 (0)