Skip to content

Commit 45ac04e

Browse files
committed
initial ShogunOpenML class
1 parent b70398d commit 45ac04e

File tree

4 files changed

+342
-83
lines changed

4 files changed

+342
-83
lines changed

src/interfaces/swig/IO.i

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
%rename(MemoryMappedFile) CMemoryMappedFile;
2727

2828
%shared_ptr(shogun::OpenMLFlow)
29+
%shared_ptr(shogun::ShogunOpenML::flow_to_model)
30+
%shared_ptr(shogun::ShogunOpenML::model_to_flow)
2931

3032
%include <shogun/io/File.h>
3133
%include <shogun/io/streaming/StreamingFile.h>

src/shogun/base/SGObject.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -1112,5 +1112,4 @@ std::string CSGObject::string_enum_reverse_lookup(
11121112
return p.second == enum_value;
11131113
});
11141114
return enum_map_it->first;
1115-
11161115
}

src/shogun/io/OpenMLFlow.cpp

+242-45
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
using namespace shogun;
1616
using namespace rapidjson;
1717

18+
/**
19+
* The writer callback function used to write the packets to a C++ string.
20+
* @param data the data received in CURL request
21+
* @param size always 1
22+
* @param nmemb the size of data
23+
* @param buffer_in the buffer to write to
24+
* @return the size of buffer that was written
25+
*/
1826
size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
1927
{
2028
// adapted from https://stackoverflow.com/a/5780603
@@ -30,13 +38,16 @@ size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
3038
return 0;
3139
}
3240

41+
/* OpenML server format */
3342
const char* OpenMLReader::xml_server = "https://www.openml.org/api/v1/xml";
3443
const char* OpenMLReader::json_server = "https://www.openml.org/api/v1/json";
44+
/* DATA API */
3545
const char* OpenMLReader::dataset_description = "/data/{}";
3646
const char* OpenMLReader::list_data_qualities = "/data/qualities/list";
3747
const char* OpenMLReader::data_features = "/data/features/{}";
3848
const char* OpenMLReader::list_dataset_qualities = "/data/qualities/{}";
3949
const char* OpenMLReader::list_dataset_filter = "/data/list/{}";
50+
/* FLOW API */
4051
const char* OpenMLReader::flow_file = "/flow/{}";
4152

4253
const std::unordered_map<std::string, std::string>
@@ -84,25 +95,16 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
8495
if (code != CURLE_OK)
8596
{
8697
// TODO: call curl_easy_cleanup(curl_handle) ?
87-
SG_SERROR("Curl error: %s\n", curl_easy_strerror(code))
98+
SG_SERROR("Connection error: %s.\n", curl_easy_strerror(code))
8899
}
89-
// else
90-
// {
91-
// long response_code;
92-
// curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE,
93-
//&response_code); if (response_code == 200) return;
94-
// else
95-
// {
96-
// if (response_code == 181)
97-
// SG_SERROR("Unknown flow. The flow with the given ID was not
98-
// found in the database.") else if (response_code == 180)
99-
// SG_SERROR("") SG_SERROR("Server code: %d\n", response_code)
100-
// }
101-
// }
102100
}
103101

104102
#endif // HAVE_CURL
105103

104+
/**
105+
* Checks the returned flow in JSON format
106+
* @param doc the parsed flow
107+
*/
106108
static void check_flow_response(rapidjson::Document& doc)
107109
{
108110
if (SG_UNLIKELY(doc.HasMember("error")))
@@ -116,24 +118,36 @@ static void check_flow_response(rapidjson::Document& doc)
116118
REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n");
117119
}
118120

121+
/**
122+
* Helper function to add JSON objects as string in map
123+
* @param v a RapidJSON GenericValue, i.e. string
124+
* @param param_dict the map to write to
125+
* @param name the name of the key
126+
*/
119127
static SG_FORCED_INLINE void emplace_string_to_map(
120-
const rapidjson::GenericValue<rapidjson::UTF8<char>>& v,
128+
const GenericValue<UTF8<char>>& v,
121129
std::unordered_map<std::string, std::string>& param_dict,
122130
const std::string& name)
123131
{
124-
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
132+
if (v[name.c_str()].GetType() == Type::kStringType)
125133
param_dict.emplace(name, v[name.c_str()].GetString());
126134
else
127135
param_dict.emplace(name, "");
128136
}
129137

138+
/**
139+
* Helper function to add JSON objects as string in map
140+
* @param v a RapidJSON GenericObject, i.e. array
141+
* @param param_dict the map to write to
142+
* @param name the name of the key
143+
*/
130144
static SG_FORCED_INLINE void emplace_string_to_map(
131-
const rapidjson::GenericObject<
132-
true, rapidjson::GenericValue<rapidjson::UTF8<char>>>& v,
145+
const GenericObject<
146+
true, GenericValue<UTF8<char>>>& v,
133147
std::unordered_map<std::string, std::string>& param_dict,
134148
const std::string& name)
135149
{
136-
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
150+
if (v[name.c_str()].GetType() == Type::kStringType)
137151
param_dict.emplace(name, v[name.c_str()].GetString());
138152
else
139153
param_dict.emplace(name, "");
@@ -234,52 +248,235 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
234248
return std::shared_ptr<OpenMLFlow>();
235249
}
236250

251+
/**
252+
* Class using the Any visitor pattern to convert
253+
* a string to a C++ type that can be used as a parameter
254+
* in a Shogun model.
255+
*/
256+
class StringToShogun : public AnyVisitor
257+
{
258+
public:
259+
explicit StringToShogun(std::shared_ptr<CSGObject> model)
260+
: m_model(model), m_parameter(""), m_string_val(""){};
261+
262+
StringToShogun(
263+
std::shared_ptr<CSGObject> model, const std::string& parameter,
264+
const std::string& string_val)
265+
: m_model(model), m_parameter(parameter), m_string_val(string_val){};
266+
267+
void on(bool* v) final
268+
{
269+
if (!is_null())
270+
{
271+
SG_SDEBUG("bool: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
272+
bool result = strcmp(m_string_val.c_str(), "true") == 0;
273+
m_model->put(m_parameter, result);
274+
}
275+
}
276+
void on(int32_t* v) final
277+
{
278+
if (!is_null())
279+
{
280+
SG_SDEBUG("int32: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
281+
try
282+
{
283+
int32_t result = std::stoi(m_string_val);
284+
m_model->put(m_parameter, result);
285+
}
286+
catch (const std::invalid_argument&)
287+
{
288+
// it's an option, i.e. internally represented
289+
// as an enum but in swig exposed as a string
290+
m_string_val.erase(
291+
std::remove_if(
292+
m_string_val.begin(), m_string_val.end(),
293+
// remove quotes
294+
[](const auto& val) { return val == '\"'; }),
295+
m_string_val.end());
296+
m_model->put(m_parameter, m_string_val);
297+
}
298+
}
299+
}
300+
void on(int64_t* v) final
301+
{
302+
if (!is_null())
303+
{
304+
SG_SDEBUG("int64: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
305+
int64_t result = std::stol(m_string_val);
306+
m_model->put(m_parameter, result);
307+
}
308+
}
309+
void on(float* v) final
310+
{
311+
if (!is_null())
312+
{
313+
SG_SDEBUG("float: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
314+
char* end;
315+
float32_t result = std::strtof(m_string_val.c_str(), &end);
316+
m_model->put(m_parameter, result);
317+
}
318+
}
319+
void on(double* v) final
320+
{
321+
if (!is_null())
322+
{
323+
SG_SDEBUG("double: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
324+
char* end;
325+
float64_t result = std::strtod(m_string_val.c_str(), &end);
326+
m_model->put(m_parameter, result);
327+
}
328+
}
329+
void on(long double* v)
330+
{
331+
if (!is_null())
332+
{
333+
SG_SDEBUG("long double: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
334+
char* end;
335+
floatmax_t result = std::strtold(m_string_val.c_str(), &end);
336+
m_model->put(m_parameter, result);
337+
}
338+
}
339+
void on(CSGObject** v) final
340+
{
341+
SG_SDEBUG("CSGObject: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
342+
}
343+
void on(SGVector<int>* v) final
344+
{
345+
SG_SDEBUG("SGVector<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
346+
}
347+
void on(SGVector<float>* v) final
348+
{
349+
SG_SDEBUG("SGVector<float>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
350+
}
351+
void on(SGVector<double>* v) final
352+
{
353+
SG_SDEBUG("SGVector<double>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
354+
}
355+
void on(SGMatrix<int>* mat) final
356+
{
357+
SG_SDEBUG("SGMatrix<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
358+
}
359+
void on(SGMatrix<float>* mat) final
360+
{
361+
SG_SDEBUG("SGMatrix<float>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
362+
}
363+
void on(SGMatrix<double>* mat) final
364+
{
365+
SG_SDEBUG("SGMatrix<double>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
366+
}
367+
368+
bool is_null()
369+
{
370+
bool result = strcmp(m_string_val.c_str(), "null") == 0;
371+
return result;
372+
}
373+
374+
void set_parameter_name(const std::string& name)
375+
{
376+
m_parameter = name;
377+
}
378+
379+
void set_string_value(const std::string& value)
380+
{
381+
m_string_val = value;
382+
}
383+
384+
private:
385+
std::shared_ptr<CSGObject> m_model;
386+
std::string m_parameter;
387+
std::string m_string_val;
388+
};
389+
390+
/**
391+
* Instantiates a CSGObject using a factory
392+
* @param factory_name the name of the factory
393+
* @param algo_name the name of algorithm passed to factory
394+
* @return the instantiated object using a factory
395+
*/
396+
std::shared_ptr<CSGObject> instantiate_model_from_factory(
397+
const std::string& factory_name, const std::string& algo_name)
398+
{
399+
std::shared_ptr<CSGObject> obj;
400+
if (factory_name == "machine")
401+
obj = std::shared_ptr<CSGObject>(machine(algo_name));
402+
else if (factory_name == "kernel")
403+
obj = std::shared_ptr<CSGObject>(kernel(algo_name));
404+
else if (factory_name == "distance")
405+
obj = std::shared_ptr<CSGObject>(distance(algo_name));
406+
else
407+
SG_SERROR("Unsupported factory \"%s\".\n", factory_name.c_str())
408+
409+
return obj;
410+
}
411+
412+
/**
413+
* Downcasts a CSGObject and puts it in the map of obj.
414+
* @param obj the main object
415+
* @param nested_obj the object to be casted and put in the obj map.
416+
* @param parameter_name the name of nested_obj
417+
*/
418+
void cast_and_put(
419+
const std::shared_ptr<CSGObject>& obj,
420+
const std::shared_ptr<CSGObject>& nested_obj,
421+
const std::string& parameter_name)
422+
{
423+
if (auto casted_obj = std::dynamic_pointer_cast<CMachine>(nested_obj))
424+
{
425+
// TODO: remove clone
426+
// temporary fix until shared_ptr PR merged
427+
auto* tmp_clone = dynamic_cast<CMachine*>(casted_obj->clone());
428+
obj->put(parameter_name, tmp_clone);
429+
}
430+
else if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
431+
{
432+
auto* tmp_clone = dynamic_cast<CKernel*>(casted_obj->clone());
433+
obj->put(parameter_name, tmp_clone);
434+
}
435+
else if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
436+
{
437+
auto* tmp_clone = dynamic_cast<CDistance*>(casted_obj->clone());
438+
obj->put(parameter_name, tmp_clone);
439+
}
440+
else
441+
SG_SERROR("Could not cast SGObject.\n")
442+
}
443+
237444
std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model(
238445
std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults)
239446
{
240-
std::string name;
241-
std::string val_as_string;
242-
std::shared_ptr<CSGObject> obj;
243447
auto params = flow->get_parameters();
244448
auto components = flow->get_components();
245449
auto class_name = get_class_info(flow->get_class_name());
246450
auto module_name = std::get<0>(class_name);
247451
auto algo_name = std::get<1>(class_name);
248-
if (module_name == "machine")
249-
obj = std::shared_ptr<CSGObject>(machine(algo_name));
250-
else if (module_name == "kernel")
251-
obj = std::shared_ptr<CSGObject>(kernel(algo_name));
252-
else if (module_name == "distance")
253-
obj = std::shared_ptr<CSGObject>(distance(algo_name));
254-
else
255-
SG_SERROR("Unsupported factory \"%s\"\n", module_name.c_str())
452+
453+
auto obj = instantiate_model_from_factory(module_name, algo_name);
256454
auto obj_param = obj->get_params();
257455

258-
auto put_lambda = [&obj, &name, &val_as_string](const auto& val) {
259-
// cast value using type from get, i.e. val
260-
auto val_ = char_to_scalar<std::remove_reference_t<decltype(val)>>(
261-
val_as_string.c_str());
262-
obj->put(name, val_);
263-
};
456+
std::unique_ptr<StringToShogun> visitor(new StringToShogun(obj));
264457

265458
if (initialize_with_defaults)
266459
{
267460
for (const auto& param : params)
268461
{
269462
Any any_val = obj_param.at(param.first)->get_value();
270-
name = param.first;
271-
val_as_string = param.second.at("default_value");
272-
sg_any_dispatch(any_val, sg_all_typemap, put_lambda);
463+
std::string name = param.first;
464+
std::string val_as_string = param.second.at("default_value");
465+
visitor->set_parameter_name(name);
466+
visitor->set_string_value(val_as_string);
467+
any_val.visit(visitor.get());
273468
}
274469
}
275470

276471
for (const auto& component : components)
277472
{
278-
CSGObject* a =
279-
flow_to_model(component.second, initialize_with_defaults).get();
280-
// obj->put(component.first, a);
473+
std::shared_ptr<CSGObject> nested_obj =
474+
flow_to_model(component.second, initialize_with_defaults);
475+
cast_and_put(obj, nested_obj, component.first);
281476
}
282477

478+
SG_SDEBUG("Final object: %s.\n", obj->to_string().c_str());
479+
283480
return obj;
284481
}
285482

@@ -306,15 +503,15 @@ ShogunOpenML::get_class_info(const std::string& class_name)
306503
if (std::next(it) == class_name.end())
307504
class_components.emplace_back(std::string(begin, std::next(it)));
308505
}
309-
if (class_components.size() != 3)
310-
SG_SERROR("Invalid class name format %s\n", class_name.c_str())
311506
if (class_components[0] == "shogun")
312507
result = std::make_tuple(class_components[1], class_components[2]);
313508
else
314509
SG_SERROR(
315510
"The provided flow is not meant for shogun deserialisation! The "
316-
"required library is \"%s\"\n",
511+
"required library is \"%s\".\n",
317512
class_components[0].c_str())
513+
if (class_components.size() != 3)
514+
SG_SERROR("Invalid class name format %s.\n", class_name.c_str())
318515

319516
return result;
320517
}

0 commit comments

Comments
 (0)