15
15
using namespace shogun ;
16
16
using namespace rapidjson ;
17
17
18
+ /* *
19
+ * The writer callback function used to write the packets to a C++ string.
20
+ * @param data the data received in CURL request
21
+ * @param size always 1
22
+ * @param nmemb the size of data
23
+ * @param buffer_in the buffer to write to
24
+ * @return the size of buffer that was written
25
+ */
18
26
size_t writer (char * data, size_t size, size_t nmemb, std::string* buffer_in)
19
27
{
20
28
// adapted from https://stackoverflow.com/a/5780603
@@ -30,13 +38,16 @@ size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
30
38
return 0 ;
31
39
}
32
40
41
+ /* OpenML server format */
33
42
const char * OpenMLReader::xml_server = " https://www.openml.org/api/v1/xml" ;
34
43
const char * OpenMLReader::json_server = " https://www.openml.org/api/v1/json" ;
44
+ /* DATA API */
35
45
const char * OpenMLReader::dataset_description = " /data/{}" ;
36
46
const char * OpenMLReader::list_data_qualities = " /data/qualities/list" ;
37
47
const char * OpenMLReader::data_features = " /data/features/{}" ;
38
48
const char * OpenMLReader::list_dataset_qualities = " /data/qualities/{}" ;
39
49
const char * OpenMLReader::list_dataset_filter = " /data/list/{}" ;
50
+ /* FLOW API */
40
51
const char * OpenMLReader::flow_file = " /flow/{}" ;
41
52
42
53
const std::unordered_map<std::string, std::string>
@@ -84,25 +95,16 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
84
95
if (code != CURLE_OK)
85
96
{
86
97
// TODO: call curl_easy_cleanup(curl_handle) ?
87
- SG_SERROR (" Curl error: %s\n " , curl_easy_strerror (code))
98
+ SG_SERROR (" Connection error: %s. \n " , curl_easy_strerror (code))
88
99
}
89
- // else
90
- // {
91
- // long response_code;
92
- // curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE,
93
- // &response_code); if (response_code == 200) return;
94
- // else
95
- // {
96
- // if (response_code == 181)
97
- // SG_SERROR("Unknown flow. The flow with the given ID was not
98
- // found in the database.") else if (response_code == 180)
99
- // SG_SERROR("") SG_SERROR("Server code: %d\n", response_code)
100
- // }
101
- // }
102
100
}
103
101
104
102
#endif // HAVE_CURL
105
103
104
+ /* *
105
+ * Checks the returned flow in JSON format
106
+ * @param doc the parsed flow
107
+ */
106
108
static void check_flow_response (rapidjson::Document& doc)
107
109
{
108
110
if (SG_UNLIKELY (doc.HasMember (" error" )))
@@ -116,24 +118,36 @@ static void check_flow_response(rapidjson::Document& doc)
116
118
REQUIRE (doc.HasMember (" flow" ), " Unexpected format of OpenML flow.\n " );
117
119
}
118
120
121
+ /* *
122
+ * Helper function to add JSON objects as string in map
123
+ * @param v a RapidJSON GenericValue, i.e. string
124
+ * @param param_dict the map to write to
125
+ * @param name the name of the key
126
+ */
119
127
static SG_FORCED_INLINE void emplace_string_to_map (
120
- const rapidjson:: GenericValue<rapidjson:: UTF8<char >>& v,
128
+ const GenericValue<UTF8<char >>& v,
121
129
std::unordered_map<std::string, std::string>& param_dict,
122
130
const std::string& name)
123
131
{
124
- if (v[name.c_str ()].GetType () == rapidjson:: Type::kStringType )
132
+ if (v[name.c_str ()].GetType () == Type::kStringType )
125
133
param_dict.emplace (name, v[name.c_str ()].GetString ());
126
134
else
127
135
param_dict.emplace (name, " " );
128
136
}
129
137
138
+ /* *
139
+ * Helper function to add JSON objects as string in map
140
+ * @param v a RapidJSON GenericObject, i.e. array
141
+ * @param param_dict the map to write to
142
+ * @param name the name of the key
143
+ */
130
144
static SG_FORCED_INLINE void emplace_string_to_map (
131
- const rapidjson:: GenericObject<
132
- true , rapidjson:: GenericValue<rapidjson:: UTF8<char >>>& v,
145
+ const GenericObject<
146
+ true , GenericValue<UTF8<char >>>& v,
133
147
std::unordered_map<std::string, std::string>& param_dict,
134
148
const std::string& name)
135
149
{
136
- if (v[name.c_str ()].GetType () == rapidjson:: Type::kStringType )
150
+ if (v[name.c_str ()].GetType () == Type::kStringType )
137
151
param_dict.emplace (name, v[name.c_str ()].GetString ());
138
152
else
139
153
param_dict.emplace (name, " " );
@@ -234,52 +248,235 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
234
248
return std::shared_ptr<OpenMLFlow>();
235
249
}
236
250
251
+ /* *
252
+ * Class using the Any visitor pattern to convert
253
+ * a string to a C++ type that can be used as a parameter
254
+ * in a Shogun model.
255
+ */
256
+ class StringToShogun : public AnyVisitor
257
+ {
258
+ public:
259
+ explicit StringToShogun (std::shared_ptr<CSGObject> model)
260
+ : m_model(model), m_parameter(" " ), m_string_val(" " ){};
261
+
262
+ StringToShogun (
263
+ std::shared_ptr<CSGObject> model, const std::string& parameter,
264
+ const std::string& string_val)
265
+ : m_model(model), m_parameter(parameter), m_string_val(string_val){};
266
+
267
+ void on (bool * v) final
268
+ {
269
+ if (!is_null ())
270
+ {
271
+ SG_SDEBUG (" bool: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
272
+ bool result = strcmp (m_string_val.c_str (), " true" ) == 0 ;
273
+ m_model->put (m_parameter, result);
274
+ }
275
+ }
276
+ void on (int32_t * v) final
277
+ {
278
+ if (!is_null ())
279
+ {
280
+ SG_SDEBUG (" int32: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
281
+ try
282
+ {
283
+ int32_t result = std::stoi (m_string_val);
284
+ m_model->put (m_parameter, result);
285
+ }
286
+ catch (const std::invalid_argument&)
287
+ {
288
+ // it's an option, i.e. internally represented
289
+ // as an enum but in swig exposed as a string
290
+ m_string_val.erase (
291
+ std::remove_if (
292
+ m_string_val.begin (), m_string_val.end (),
293
+ // remove quotes
294
+ [](const auto & val) { return val == ' \" ' ; }),
295
+ m_string_val.end ());
296
+ m_model->put (m_parameter, m_string_val);
297
+ }
298
+ }
299
+ }
300
+ void on (int64_t * v) final
301
+ {
302
+ if (!is_null ())
303
+ {
304
+ SG_SDEBUG (" int64: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
305
+ int64_t result = std::stol (m_string_val);
306
+ m_model->put (m_parameter, result);
307
+ }
308
+ }
309
+ void on (float * v) final
310
+ {
311
+ if (!is_null ())
312
+ {
313
+ SG_SDEBUG (" float: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
314
+ char * end;
315
+ float32_t result = std::strtof (m_string_val.c_str (), &end);
316
+ m_model->put (m_parameter, result);
317
+ }
318
+ }
319
+ void on (double * v) final
320
+ {
321
+ if (!is_null ())
322
+ {
323
+ SG_SDEBUG (" double: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
324
+ char * end;
325
+ float64_t result = std::strtod (m_string_val.c_str (), &end);
326
+ m_model->put (m_parameter, result);
327
+ }
328
+ }
329
+ void on (long double * v)
330
+ {
331
+ if (!is_null ())
332
+ {
333
+ SG_SDEBUG (" long double: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
334
+ char * end;
335
+ floatmax_t result = std::strtold (m_string_val.c_str (), &end);
336
+ m_model->put (m_parameter, result);
337
+ }
338
+ }
339
+ void on (CSGObject** v) final
340
+ {
341
+ SG_SDEBUG (" CSGObject: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
342
+ }
343
+ void on (SGVector<int >* v) final
344
+ {
345
+ SG_SDEBUG (" SGVector<int>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
346
+ }
347
+ void on (SGVector<float >* v) final
348
+ {
349
+ SG_SDEBUG (" SGVector<float>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
350
+ }
351
+ void on (SGVector<double >* v) final
352
+ {
353
+ SG_SDEBUG (" SGVector<double>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
354
+ }
355
+ void on (SGMatrix<int >* mat) final
356
+ {
357
+ SG_SDEBUG (" SGMatrix<int>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
358
+ }
359
+ void on (SGMatrix<float >* mat) final
360
+ {
361
+ SG_SDEBUG (" SGMatrix<float>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
362
+ }
363
+ void on (SGMatrix<double >* mat) final
364
+ {
365
+ SG_SDEBUG (" SGMatrix<double>: %s=%s\n " , m_parameter.c_str (), m_string_val.c_str ())
366
+ }
367
+
368
+ bool is_null ()
369
+ {
370
+ bool result = strcmp (m_string_val.c_str (), " null" ) == 0 ;
371
+ return result;
372
+ }
373
+
374
+ void set_parameter_name (const std::string& name)
375
+ {
376
+ m_parameter = name;
377
+ }
378
+
379
+ void set_string_value (const std::string& value)
380
+ {
381
+ m_string_val = value;
382
+ }
383
+
384
+ private:
385
+ std::shared_ptr<CSGObject> m_model;
386
+ std::string m_parameter;
387
+ std::string m_string_val;
388
+ };
389
+
390
+ /* *
391
+ * Instantiates a CSGObject using a factory
392
+ * @param factory_name the name of the factory
393
+ * @param algo_name the name of algorithm passed to factory
394
+ * @return the instantiated object using a factory
395
+ */
396
+ std::shared_ptr<CSGObject> instantiate_model_from_factory (
397
+ const std::string& factory_name, const std::string& algo_name)
398
+ {
399
+ std::shared_ptr<CSGObject> obj;
400
+ if (factory_name == " machine" )
401
+ obj = std::shared_ptr<CSGObject>(machine (algo_name));
402
+ else if (factory_name == " kernel" )
403
+ obj = std::shared_ptr<CSGObject>(kernel (algo_name));
404
+ else if (factory_name == " distance" )
405
+ obj = std::shared_ptr<CSGObject>(distance (algo_name));
406
+ else
407
+ SG_SERROR (" Unsupported factory \" %s\" .\n " , factory_name.c_str ())
408
+
409
+ return obj;
410
+ }
411
+
412
+ /* *
413
+ * Downcasts a CSGObject and puts it in the map of obj.
414
+ * @param obj the main object
415
+ * @param nested_obj the object to be casted and put in the obj map.
416
+ * @param parameter_name the name of nested_obj
417
+ */
418
+ void cast_and_put (
419
+ const std::shared_ptr<CSGObject>& obj,
420
+ const std::shared_ptr<CSGObject>& nested_obj,
421
+ const std::string& parameter_name)
422
+ {
423
+ if (auto casted_obj = std::dynamic_pointer_cast<CMachine>(nested_obj))
424
+ {
425
+ // TODO: remove clone
426
+ // temporary fix until shared_ptr PR merged
427
+ auto * tmp_clone = dynamic_cast <CMachine*>(casted_obj->clone ());
428
+ obj->put (parameter_name, tmp_clone);
429
+ }
430
+ else if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
431
+ {
432
+ auto * tmp_clone = dynamic_cast <CKernel*>(casted_obj->clone ());
433
+ obj->put (parameter_name, tmp_clone);
434
+ }
435
+ else if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
436
+ {
437
+ auto * tmp_clone = dynamic_cast <CDistance*>(casted_obj->clone ());
438
+ obj->put (parameter_name, tmp_clone);
439
+ }
440
+ else
441
+ SG_SERROR (" Could not cast SGObject.\n " )
442
+ }
443
+
237
444
std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model (
238
445
std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults)
239
446
{
240
- std::string name;
241
- std::string val_as_string;
242
- std::shared_ptr<CSGObject> obj;
243
447
auto params = flow->get_parameters ();
244
448
auto components = flow->get_components ();
245
449
auto class_name = get_class_info (flow->get_class_name ());
246
450
auto module_name = std::get<0 >(class_name);
247
451
auto algo_name = std::get<1 >(class_name);
248
- if (module_name == " machine" )
249
- obj = std::shared_ptr<CSGObject>(machine (algo_name));
250
- else if (module_name == " kernel" )
251
- obj = std::shared_ptr<CSGObject>(kernel (algo_name));
252
- else if (module_name == " distance" )
253
- obj = std::shared_ptr<CSGObject>(distance (algo_name));
254
- else
255
- SG_SERROR (" Unsupported factory \" %s\"\n " , module_name.c_str ())
452
+
453
+ auto obj = instantiate_model_from_factory (module_name, algo_name);
256
454
auto obj_param = obj->get_params ();
257
455
258
- auto put_lambda = [&obj, &name, &val_as_string](const auto & val) {
259
- // cast value using type from get, i.e. val
260
- auto val_ = char_to_scalar<std::remove_reference_t <decltype (val)>>(
261
- val_as_string.c_str ());
262
- obj->put (name, val_);
263
- };
456
+ std::unique_ptr<StringToShogun> visitor (new StringToShogun (obj));
264
457
265
458
if (initialize_with_defaults)
266
459
{
267
460
for (const auto & param : params)
268
461
{
269
462
Any any_val = obj_param.at (param.first )->get_value ();
270
- name = param.first ;
271
- val_as_string = param.second .at (" default_value" );
272
- sg_any_dispatch (any_val, sg_all_typemap, put_lambda);
463
+ std::string name = param.first ;
464
+ std::string val_as_string = param.second .at (" default_value" );
465
+ visitor->set_parameter_name (name);
466
+ visitor->set_string_value (val_as_string);
467
+ any_val.visit (visitor.get ());
273
468
}
274
469
}
275
470
276
471
for (const auto & component : components)
277
472
{
278
- CSGObject* a =
279
- flow_to_model (component.second , initialize_with_defaults). get () ;
280
- // obj->put( component.first, a );
473
+ std::shared_ptr< CSGObject> nested_obj =
474
+ flow_to_model (component.second , initialize_with_defaults);
475
+ cast_and_put (obj, nested_obj, component.first );
281
476
}
282
477
478
+ SG_SDEBUG (" Final object: %s.\n " , obj->to_string ().c_str ());
479
+
283
480
return obj;
284
481
}
285
482
@@ -306,15 +503,15 @@ ShogunOpenML::get_class_info(const std::string& class_name)
306
503
if (std::next (it) == class_name.end ())
307
504
class_components.emplace_back (std::string (begin, std::next (it)));
308
505
}
309
- if (class_components.size () != 3 )
310
- SG_SERROR (" Invalid class name format %s\n " , class_name.c_str ())
311
506
if (class_components[0 ] == " shogun" )
312
507
result = std::make_tuple (class_components[1 ], class_components[2 ]);
313
508
else
314
509
SG_SERROR (
315
510
" The provided flow is not meant for shogun deserialisation! The "
316
- " required library is \" %s\"\n " ,
511
+ " required library is \" %s\" . \n " ,
317
512
class_components[0 ].c_str ())
513
+ if (class_components.size () != 3 )
514
+ SG_SERROR (" Invalid class name format %s.\n " , class_name.c_str ())
318
515
319
516
return result;
320
517
}
0 commit comments