66 * Authors: Gil Hoben
77 */
88
9+ #include < shogun/evaluation/ContingencyTableEvaluation.h>
10+ #include < shogun/evaluation/MeanAbsoluteError.h>
911#include < shogun/util/factory.h>
1012
1113#include < shogun/io/openml/ShogunOpenML.h>
@@ -310,6 +312,30 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
310312{
311313 auto task_type = task->get_task_type ();
312314
315+ CEvaluation* evaluation_criterion = nullptr ;
316+
317+ switch (task_type)
318+ {
319+ case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
320+ evaluation_criterion = new CAccuracyMeasure ();
321+ break ;
322+ case OpenMLTask::TaskType::SUPERVISED_REGRESSION:
323+ evaluation_criterion = new CMeanAbsoluteError ();
324+ break ;
325+ case OpenMLTask::TaskType::LEARNING_CURVE:
326+ SG_SNOTIMPLEMENTED
327+ case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
328+ SG_SNOTIMPLEMENTED
329+ case OpenMLTask::TaskType::CLUSTERING:
330+ SG_SNOTIMPLEMENTED
331+ case OpenMLTask::TaskType::MACHINE_LEARNING_CHALLENGE:
332+ SG_SNOTIMPLEMENTED
333+ case OpenMLTask::TaskType::SURVIVAL_ANALYSIS:
334+ SG_SNOTIMPLEMENTED
335+ case OpenMLTask::TaskType::SUBGROUP_DISCOVERY:
336+ SG_SNOTIMPLEMENTED
337+ }
338+
313339 switch (task_type)
314340 {
315341 case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
@@ -324,8 +350,6 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
324350 // shared
325351 auto * features_clone = features->clone ()->as <CFeatures>();
326352 auto * labels_clone = labels->clone ()->as <CLabels>();
327- // auto* evaluation_criterion =
328- // (CEvaluation*)m_evaluation_criterion->clone();
329353
330354 /* evtl. update xvalidation output class */
331355 fold->set_run_index (repeat_idx);
@@ -371,8 +395,10 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
371395 SG_REF (result_labels);
372396
373397 /* evaluate */
374- // results[i] = evaluation_criterion->evaluate(result_labels, labels);
375- // SG_DEBUG("result on fold %d is %f\n", i, results[i])
398+ auto result =
399+ evaluation_criterion->evaluate (result_labels, labels_clone);
400+ SG_SINFO (
401+ " result on repeat %d fold %d is %f\n " , repeat_idx, fold_idx, result)
376402
377403 /* evtl. update xvalidation output class */
378404 fold->set_test_indices (test_idx);
@@ -381,18 +407,17 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
381407 fold->set_test_true_result (true_labels);
382408 SG_UNREF (true_labels)
383409 fold->post_update_results ();
384- // fold->set_evaluation_result(results[i] );
410+ fold->set_evaluation_result (result );
385411
386412 /* clean up, remove subsets */
387413 labels->remove_subset ();
388414 SG_UNREF (cloned_machine);
389415 SG_UNREF (features_clone);
390416 SG_UNREF (labels_clone);
391- // SG_UNREF(evaluation_criterion);
392417 SG_UNREF (result_labels);
418+ delete evaluation_criterion;
393419 return fold;
394420 }
395- break ;
396421 case OpenMLTask::TaskType::LEARNING_CURVE:
397422 SG_SNOTIMPLEMENTED
398423 case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
@@ -417,6 +442,30 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
417442{
418443 auto task_type = task->get_task_type ();
419444
445+ CEvaluation* evaluation_criterion = nullptr ;
446+
447+ switch (task_type)
448+ {
449+ case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
450+ evaluation_criterion = new CAccuracyMeasure ();
451+ break ;
452+ case OpenMLTask::TaskType::SUPERVISED_REGRESSION:
453+ evaluation_criterion = new CMeanAbsoluteError ();
454+ break ;
455+ case OpenMLTask::TaskType::LEARNING_CURVE:
456+ SG_SNOTIMPLEMENTED
457+ case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
458+ SG_SNOTIMPLEMENTED
459+ case OpenMLTask::TaskType::CLUSTERING:
460+ SG_SNOTIMPLEMENTED
461+ case OpenMLTask::TaskType::MACHINE_LEARNING_CHALLENGE:
462+ SG_SNOTIMPLEMENTED
463+ case OpenMLTask::TaskType::SURVIVAL_ANALYSIS:
464+ SG_SNOTIMPLEMENTED
465+ case OpenMLTask::TaskType::SUBGROUP_DISCOVERY:
466+ SG_SNOTIMPLEMENTED
467+ }
468+
420469 switch (task_type)
421470 {
422471 case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
@@ -446,23 +495,24 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
446495 SG_SDEBUG (" finished evaluation\n " )
447496
448497 /* evaluate */
449- // results[i] = evaluation_criterion->evaluate(result_labels, labels);
450- // SG_DEBUG("result on fold %d is %f\n", i, results[i])
498+ auto result =
499+ evaluation_criterion->evaluate (result_labels, labels_clone);
500+ SG_SINFO (" result is %f\n " , result)
451501
452502 /* evtl. update xvalidation output class */
453503 fold->set_test_result (result_labels);
454504 auto * true_labels = (CLabels*)labels->clone ();
455505 fold->set_test_true_result (true_labels);
456506 SG_UNREF (true_labels)
457507 fold->post_update_results ();
458- // fold->set_evaluation_result(results[i] );
508+ fold->set_evaluation_result (result );
459509
460510 // cleanup
461511 SG_UNREF (cloned_machine);
462512 SG_UNREF (features_clone);
463513 SG_UNREF (labels_clone);
464- // SG_UNREF(evaluation_criterion);
465514 SG_UNREF (result_labels);
515+ delete evaluation_criterion;
466516 return fold;
467517 }
468518 case OpenMLTask::TaskType::LEARNING_CURVE:
0 commit comments