Skip to content

Commit 4c988bb

Browse files
committed
started work on splits
1 parent bbb493a commit 4c988bb

File tree

2 files changed

+309
-88
lines changed

2 files changed

+309
-88
lines changed

src/shogun/io/OpenMLFlow.cpp

+175-66
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
4141
/* OpenML server format */
/* Base URLs that the request paths defined further down are combined with.
 * Each one is registered in m_format_options under a format key
 * ("xml", "json" and "split" respectively). */
4242
/* base URL registered under the "xml" format key */
const char* OpenMLReader::xml_server = "https://www.openml.org/api/v1/xml";
4343
/* base URL registered under the "json" format key */
const char* OpenMLReader::json_server = "https://www.openml.org/api/v1/json";
44+
/* base URL registered under the "split" format key; used by
 * OpenMLSplit::get_split to request split files */
const char* OpenMLReader::splits_server = "https://www.openml.org/api_splits";
45+
4446
/* DATA API */
4547
const char* OpenMLReader::dataset_description = "/data/{}";
4648
const char* OpenMLReader::list_data_qualities = "/data/qualities/list";
@@ -52,10 +54,13 @@ const char* OpenMLReader::list_dataset_filter = "/data/list/{}";
5254
const char* OpenMLReader::flow_file = "/flow/{}";
5355
/* TASK API */
5456
const char* OpenMLReader::task_file = "/task/{}";
57+
/* SPLIT API */
58+
const char* OpenMLReader::get_split = "/split/{}";
5559

5660
const std::unordered_map<std::string, std::string>
5761
OpenMLReader::m_format_options = {{"xml", xml_server},
58-
{"json", json_server}};
62+
{"json", json_server},
63+
{"split", splits_server}};
5964
const std::unordered_map<std::string, std::string>
6065
OpenMLReader::m_request_options = {
6166
{"dataset_description", dataset_description},
@@ -104,8 +109,6 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
104109
}
105110
}
106111

107-
#endif // HAVE_CURL
108-
109112
/**
110113
* Checks the returned response from OpenML in JSON format
111114
* @param doc the parsed OpenML JSON format response
@@ -367,19 +370,25 @@ OpenMLData::get_data(const std::string& id, const std::string& api_key)
367370
"md5_checksum", dataset_description.GetObject());
368371

369372
// features
370-
std::vector<std::unordered_map<std::string, std::string>> param_vector;
373+
std::vector<std::unordered_map<std::string, std::vector<std::string>>>
374+
param_vector;
371375
return_string = reader.get("data_features", "json", id);
372376
document.Parse(return_string.c_str());
373377
check_response(document, "data_features");
374378
const Value& dataset_features = document["data_features"];
375-
for (const auto& param : dataset_features.GetArray())
379+
for (const auto& param : dataset_features["feature"].GetArray())
376380
{
377-
std::unordered_map<std::string, std::string> param_map;
381+
std::unordered_map<std::string, std::vector<std::string>> param_map;
378382
for (const auto& param_descriptors : param.GetObject())
379383
{
380-
param_map.emplace(
381-
param_descriptors.name.GetString(),
382-
param_descriptors.value.GetString());
384+
std::vector<std::string> second;
385+
if (param_descriptors.value.IsArray())
386+
for (const auto& v : param_descriptors.value.GetArray())
387+
second.emplace_back(v.GetString());
388+
else
389+
second.emplace_back(param_descriptors.value.GetString());
390+
391+
param_map.emplace(param_descriptors.name.GetString(), second);
383392
}
384393
param_vector.push_back(param_map);
385394
}
@@ -390,14 +399,17 @@ OpenMLData::get_data(const std::string& id, const std::string& api_key)
390399
document.Parse(return_string.c_str());
391400
check_response(document, "data_qualities");
392401
const Value& data_qualities = document["data_qualities"];
393-
for (const auto& param : data_qualities.GetArray())
402+
for (const auto& param : data_qualities["quality"].GetArray())
394403
{
395404
std::unordered_map<std::string, std::string> param_map;
396405
for (const auto& param_quality : param.GetObject())
397406
{
398-
param_map.emplace(
399-
param_quality.name.GetString(),
400-
param_quality.value.GetString());
407+
if (param_quality.name.IsString() && param_quality.value.IsString())
408+
param_map.emplace(
409+
param_quality.name.GetString(),
410+
param_quality.value.GetString());
411+
else if (param_quality.name.IsString())
412+
param_map.emplace(param_quality.name.GetString(), "");
401413
}
402414
qualities_vector.push_back(param_map);
403415
}
@@ -418,16 +430,28 @@ std::string OpenMLData::get_data_buffer(const std::string& api_key)
418430
return nullptr;
419431
}
420432

433+
std::shared_ptr<OpenMLSplit>
434+
OpenMLSplit::get_split(const std::string& split_url, const std::string& api_key)
435+
{
436+
Document document;
437+
438+
auto reader = OpenMLReader(api_key);
439+
auto return_string = reader.get("get_split", "split", split_url);
440+
auto return_stream = std::istringstream(return_string);
441+
// add ARFF parsing here
442+
SG_SNOTIMPLEMENTED
443+
return nullptr;
444+
}
445+
421446
std::shared_ptr<OpenMLTask>
422447
OpenMLTask::get_task(const std::string& task_id, const std::string& api_key)
423448
{
424449
Document document;
425450
std::string task_name;
426451
std::string task_type_id;
427-
std::shared_ptr<OpenMLData> openml_dataset;
428-
std::shared_ptr<OpenMLSplit> openml_split;
429-
std::pair<std::shared_ptr<OpenMLData>, std::shared_ptr<OpenMLSplit>>
430-
task_descriptor;
452+
std::shared_ptr<OpenMLData> openml_dataset = nullptr;
453+
std::shared_ptr<OpenMLSplit> openml_split = nullptr;
454+
std::unordered_map<std::string, std::string> evaluation_measures;
431455

432456
auto reader = OpenMLReader(api_key);
433457
auto return_string = reader.get("task_file", "json", task_id);
@@ -451,63 +475,62 @@ OpenMLTask::get_task(const std::string& task_id, const std::string& api_key)
451475
// expect two elements in input array: dataset and split
452476
const Value& json_input = root["input"];
453477

454-
REQUIRE(
455-
json_input.IsArray(), "Currently the dataset reader can only handle "
456-
"inputs with a dataset and split field.\n")
457-
458478
auto input_array = json_input.GetArray();
459-
REQUIRE(
460-
input_array.Size() == 2,
461-
"Currently the dataset reader can only handle inputs with a dataset "
462-
"and split fields. Found %d elements.\n",
463-
input_array.Size())
464-
465-
// handle dataset
466-
auto json_dataset = input_array[0].GetObject();
467479

468-
if (strcmp(json_dataset["name"].GetString(), "source_data") == 0)
480+
for (const auto& task_settings : input_array)
469481
{
470-
auto dataset_info = json_dataset["data_set"].GetObject();
471-
std::string dataset_id = dataset_info["data_set_id"].GetString();
472-
std::string target_feature = dataset_info["target_feature"].GetString();
473-
// openml_dataset =
474-
// std::make_shared<OpenMLData>(dataset_id, target_feature);
475-
}
476-
else
477-
SG_SERROR("Error parsing the OpenML dataset, could not find the "
478-
"source_data field.\n")
479-
480-
// handle split
481-
auto json_split = input_array[1].GetObject();
482-
if (strcmp(json_split["name"].GetString(), "estimation_procedure") == 0)
483-
{
484-
auto split_info = json_dataset["estimation_procedure"].GetObject();
485-
std::string split_id = split_info["id"].GetString();
486-
std::string split_type = split_info["type"].GetString();
487-
std::string split_url = split_info["data_splits_url"].GetString();
488-
std::unordered_map<std::string, std::string> split_parameters;
489-
for (const auto& param : split_info["parameter"].GetArray())
482+
if (strcmp(task_settings["name"].GetString(), "source_data") == 0)
490483
{
491-
if (param.Size() == 2)
492-
split_parameters.emplace(
493-
param["name"].GetString(), param["value"].GetString());
494-
else if (param.Size() == 1)
495-
split_parameters.emplace(param["name"].GetString(), "");
496-
else
497-
SG_SERROR("Unexpected number of parameters in parameter array "
498-
"of estimation_procedure.\n")
484+
auto dataset_info = task_settings["data_set"].GetObject();
485+
std::string dataset_id = dataset_info["data_set_id"].GetString();
486+
std::string target_feature =
487+
dataset_info["target_feature"].GetString();
488+
openml_dataset = OpenMLData::get_data(dataset_id, api_key);
489+
}
490+
else if (
491+
strcmp(task_settings["name"].GetString(), "estimation_procedure") ==
492+
0)
493+
{
494+
auto split_info = task_settings["estimation_procedure"].GetObject();
495+
std::string split_id = split_info["id"].GetString();
496+
std::string split_type = split_info["type"].GetString();
497+
std::string split_url = split_info["data_splits_url"].GetString();
498+
std::unordered_map<std::string, std::string> split_parameters;
499+
for (const auto& param : split_info["parameter"].GetArray())
500+
{
501+
if (param.HasMember("name") && param.HasMember("value"))
502+
split_parameters.emplace(
503+
param["name"].GetString(), param["value"].GetString());
504+
else if (param.HasMember("name"))
505+
split_parameters.emplace(param["name"].GetString(), "");
506+
else
507+
SG_SERROR(
508+
"Unexpected number of parameters in parameter array "
509+
"of estimation_procedure.\n")
510+
}
511+
openml_split = std::make_shared<OpenMLSplit>(
512+
split_id, split_type, split_url, split_parameters);
513+
}
514+
else if (
515+
strcmp(task_settings["name"].GetString(), "evaluation_measures") ==
516+
0)
517+
{
518+
auto evaluation_info =
519+
task_settings["evaluation_measures"].GetObject();
520+
for (const auto& param : evaluation_info)
521+
{
522+
evaluation_measures.emplace(
523+
param.name.GetString(), param.value.GetString());
524+
}
499525
}
500-
openml_split = std::make_shared<OpenMLSplit>(
501-
split_id, split_type, split_url, split_parameters);
502526
}
503-
else
504-
SG_SERROR("Error parsing the OpenML dataset, could not find the "
505-
"estimation_procedure field.\n")
506527

507-
task_descriptor = std::make_pair(openml_dataset, openml_split);
528+
if (openml_dataset == nullptr && openml_split == nullptr)
529+
SG_SERROR("Error parsing task.")
508530

509531
auto result = std::make_shared<OpenMLTask>(
510-
task_id, task_name, task_type, task_type_id, task_descriptor);
532+
task_id, task_name, task_type, task_type_id, evaluation_measures,
533+
openml_split, openml_dataset);
511534

512535
return result;
513536
}
@@ -517,7 +540,19 @@ OpenMLTask::get_task_from_string(const std::string& task_type)
517540
{
518541
if (task_type == "Supervised Classification")
519542
return OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION;
520-
SG_SERROR("OpenMLTask does not supported \"%s\"", task_type.c_str())
543+
SG_SERROR("OpenMLTask does not support \"%s\"", task_type.c_str())
544+
}
545+
546+
/**
 * Accessor for the training indices of this task.
 *
 * Not implemented yet: the body only invokes SG_SNOTIMPLEMENTED.
 *
 * @return a default constructed matrix (only reached if
 * SG_SNOTIMPLEMENTED returns control -- TODO confirm macro semantics)
 */
SGMatrix<int32_t> OpenMLTask::get_train_indices()
{
	SG_SNOTIMPLEMENTED
	SGMatrix<int32_t> empty_indices;
	return empty_indices;
}
551+
552+
/**
 * Accessor for the test indices of this task.
 *
 * Not implemented yet: the body only invokes SG_SNOTIMPLEMENTED.
 *
 * @return a default constructed matrix (only reached if
 * SG_SNOTIMPLEMENTED returns control -- TODO confirm macro semantics)
 */
SGMatrix<int32_t> OpenMLTask::get_test_indices()
{
	SG_SNOTIMPLEMENTED
	SGMatrix<int32_t> empty_indices;
	return empty_indices;
}
522557

523558
/**
@@ -802,3 +837,77 @@ ShogunOpenML::get_class_info(const std::string& class_name)
802837

803838
return result;
804839
}
840+
841+
/**
 * Trains a clone of the given model on the training data and applies it to
 * the test data.
 *
 * Only the supervised classification and supervised regression task types
 * are handled; every other task type ends in SG_SNOTIMPLEMENTED.
 *
 * @param model the model to clone; the clone is what gets trained
 * @param task the task, queried for its task type
 * @param X_train the training features
 * @param repeat_number not read by the current implementation
 * @param fold_number not read by the current implementation
 * @param y_train the training labels, stored in the clone under "labels"
 * @param X_test the features the trained clone is applied to
 * @return the result of applying the trained clone to X_test; a null
 * pointer if control leaves the switch without returning
 */
CLabels* ShogunOpenML::run_model_on_fold(
    const std::shared_ptr<CSGObject>& model,
    const std::shared_ptr<OpenMLTask>& task, CFeatures* X_train,
    index_t repeat_number, index_t fold_number, CLabels* y_train,
    CFeatures* X_test)
{
	const auto kind_of_task = task->get_task_type();
	// train on a copy rather than on the object passed in by the caller
	std::shared_ptr<CSGObject> fold_model(model->clone());

	switch (kind_of_task)
	{
	case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
	case OpenMLTask::TaskType::SUPERVISED_REGRESSION:
	{
		auto learner = std::dynamic_pointer_cast<CMachine>(fold_model);
		if (!learner)
		{
			// the clone is not a CMachine, so it cannot be trained here
			SG_SERROR("The provided model is not trainable!\n")
		}
		else
		{
			learner->put("labels", y_train);
			learner->train(X_train);
			return learner->apply(X_test);
		}
	}
	break;
	// the remaining task types are placeholders without an implementation
	case OpenMLTask::TaskType::LEARNING_CURVE:
		SG_SNOTIMPLEMENTED
	case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
		SG_SNOTIMPLEMENTED
	case OpenMLTask::TaskType::CLUSTERING:
		SG_SNOTIMPLEMENTED
	case OpenMLTask::TaskType::MACHINE_LEARNING_CHALLENGE:
		SG_SNOTIMPLEMENTED
	case OpenMLTask::TaskType::SURVIVAL_ANALYSIS:
		SG_SNOTIMPLEMENTED
	case OpenMLTask::TaskType::SUBGROUP_DISCOVERY:
		SG_SNOTIMPLEMENTED
	}
	return nullptr;
}
880+
881+
std::shared_ptr<OpenMLRun> OpenMLRun::run_model_on_task(
882+
std::shared_ptr<CSGObject> model, std::shared_ptr<OpenMLTask> task)
883+
{
884+
SG_SNOTIMPLEMENTED
885+
return std::shared_ptr<OpenMLRun>();
886+
}
887+
888+
std::shared_ptr<OpenMLRun> OpenMLRun::run_flow_on_task(
889+
std::shared_ptr<OpenMLFlow> flow, std::shared_ptr<OpenMLTask> task)
890+
{
891+
auto data = task->get_dataset();
892+
SG_SNOTIMPLEMENTED
893+
return std::shared_ptr<OpenMLRun>();
894+
}
895+
896+
std::shared_ptr<OpenMLRun>
897+
OpenMLRun::from_filesystem(const std::string& directory)
898+
{
899+
SG_SNOTIMPLEMENTED
900+
return nullptr;
901+
}
902+
903+
void OpenMLRun::to_filesystem(const std::string& directory) const
904+
{
905+
SG_SNOTIMPLEMENTED
906+
}
907+
908+
void OpenMLRun::publish() const
909+
{
910+
SG_SNOTIMPLEMENTED
911+
}
912+
913+
#endif // HAVE_CURL

0 commit comments

Comments
 (0)