Skip to content

Commit 583a724

Browse files
committed
more tests
1 parent c3acaa1 commit 583a724

File tree

3 files changed

+254
-50
lines changed

3 files changed

+254
-50
lines changed

src/shogun/io/ARFFFile.cpp

Lines changed: 107 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
* Authors: Gil Hoben
55
*/
66

7+
#include <shogun/features/DenseFeatures.h>
78
#include <shogun/io/ARFFFile.h>
8-
#include <shogun/mathematics/linalg/LinalgNamespace.h>
99

1010
#include <date/date.h>
1111

@@ -18,6 +18,18 @@ const char* ARFFDeserializer::m_attribute_string = "@attribute";
1818
const char* ARFFDeserializer::m_data_string = "@data";
1919
const char* ARFFDeserializer::m_default_date_format = "%Y-%M-%DT%H:%M:%S";
2020

21+
struct VectorSizeVisitor
22+
{
23+
size_t operator()(const std::vector<float64_t>& v) const
24+
{
25+
return v.size();
26+
}
27+
size_t operator()(const std::vector<std::string>& v) const
28+
{
29+
return v.size();
30+
}
31+
};
32+
2133
void ARFFDeserializer::read()
2234
{
2335
m_line_number = 0;
@@ -78,6 +90,7 @@ void ARFFDeserializer::read()
7890
// check if it is nominal
7991
if (type[0] == '{')
8092
{
93+
// @ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica}
8194
std::vector<std::string> attributes;
8295
// split norminal values: "{A, B, C}" to vector{A, B, C}
8396
split(
@@ -86,6 +99,7 @@ void ARFFDeserializer::read()
8699
m_nominal_attributes.emplace_back(
87100
std::make_pair(name, attributes));
88101
m_attributes.push_back(Attribute::Nominal);
102+
m_data_vectors.emplace_back(std::vector<float64_t>{});
89103
return;
90104
}
91105

@@ -120,23 +134,32 @@ void ARFFDeserializer::read()
120134
m_current_line.c_str())
121135
}
122136
m_attributes.push_back(Attribute::Date);
137+
m_data_vectors.emplace_back(std::vector<float64_t>{});
123138
}
124139
else if (is_primitive_type(type))
125140
{
126141
type = string_to_lower(type);
127142
// numeric attributes
128143
if (type == "numeric")
144+
{
129145
m_attributes.push_back(Attribute::Numeric);
146+
m_data_vectors.emplace_back(std::vector<float64_t>{});
147+
}
130148
else if (type == "integer")
149+
{
131150
m_attributes.push_back(Attribute::Integer);
151+
m_data_vectors.emplace_back(std::vector<float64_t>{});
152+
}
132153
else if (type == "real")
154+
{
133155
m_attributes.push_back(Attribute::Real);
156+
m_data_vectors.emplace_back(std::vector<float64_t>{});
157+
}
134158
else if (type == "string")
135159
{
136160
// @ATTRIBUTE LCC string
137-
// m_attributes.emplace(std::make_pair(elems[0],
138-
// "string"));
139161
m_attributes.push_back(Attribute::String);
162+
m_data_vectors.emplace_back(std::vector<std::string>{});
140163
}
141164
else
142165
SG_SERROR(
@@ -180,7 +203,8 @@ void ARFFDeserializer::read()
180203
split(m_current_line, ",", std::back_inserter(elems), "\'\"");
181204
auto nominal_pos = m_nominal_attributes.begin();
182205
auto date_pos = m_date_formats.begin();
183-
for (int i = 0; i < elems.size(); ++i)
206+
int i = 0;
207+
for (; i < elems.size(); ++i)
184208
{
185209
Attribute type = m_attributes[i];
186210
switch (type)
@@ -191,7 +215,8 @@ void ARFFDeserializer::read()
191215
{
192216
try
193217
{
194-
m_data.push_back(std::stod(elems[i]));
218+
shogun::get<std::vector<float64_t>>(m_data_vectors[i])
219+
.push_back(std::stod(elems[i]));
195220
}
196221
catch (const std::invalid_argument&)
197222
{
@@ -216,7 +241,8 @@ void ARFFDeserializer::read()
216241
"Unexpected value \"%s\" on line %d\n",
217242
elems[i].c_str(), m_line_number);
218243
float64_t idx = std::distance(encoding.begin(), pos);
219-
m_data.push_back(idx);
244+
shogun::get<std::vector<float64_t>>(m_data_vectors[i])
245+
.push_back(idx);
220246
nominal_pos = std::next(nominal_pos);
221247
}
222248
break;
@@ -227,49 +253,106 @@ void ARFFDeserializer::read()
227253
if (date_pos == m_date_formats.end())
228254
SG_SERROR(
229255
"Unexpected date value \"%s\" on line %d.\n",
230-
elems[i].c_str(), m_line_number);
256+
elems[i].c_str(), m_line_number);
231257
ss >> date::parse(*date_pos, t);
232258
if (bool(ss))
233259
{
234260
auto value_timestamp = t.time_since_epoch().count();
235-
m_data.emplace_back(value_timestamp);
261+
shogun::get<std::vector<float64_t>>(m_data_vectors[i])
262+
.push_back(value_timestamp);
236263
}
237264
else
238265
SG_SERROR(
239266
"Error parsing date \"%s\" with date format \"%s\" "
240267
"on line %d.\n",
241-
elems[i].c_str(), (*date_pos).c_str(), m_line_number)
268+
elems[i].c_str(), (*date_pos).c_str(), m_line_number)
242269
++date_pos;
243270
}
244271
break;
245272
case (Attribute::String):
246-
SG_SERROR("String parsing not implemented.\n")
273+
shogun::get<std::vector<std::string>>(m_data_vectors[i])
274+
.emplace_back(elems[i]);
247275
}
248276
}
277+
if (i != m_attributes.size())
278+
SG_SERROR(
279+
"Unexpected number of values on line %d, expected %d values, "
280+
"but found %d.\n",
281+
m_line_number, m_attributes.size(), i)
249282
++m_row_count;
250283
};
251284
auto check_data = [this]() {
252285
// check X values
253286
SG_SDEBUG(
254-
"size: %d, cols: %d, rows: %d", m_data.size(),
255-
m_data.size() / m_row_count, m_row_count)
256-
if (!m_data.empty())
287+
"size: %d, cols: %d, rows: %d", m_data_vectors.size(),
288+
m_data_vectors.size() / m_row_count, m_row_count)
289+
if (!m_data_vectors.empty())
257290
{
258-
auto tmp =
259-
SGMatrix<float64_t>(m_data.size() / m_row_count, m_row_count);
260-
m_data_matrix =
261-
SGMatrix<float64_t>(m_row_count, m_data.size() / m_row_count);
262-
memcpy(
263-
tmp.matrix, m_data.data(), m_data.size() * sizeof(float64_t));
264-
typename SGMatrix<float64_t>::EigenMatrixXtMap tmp_eigen = tmp;
265-
typename SGMatrix<float64_t>::EigenMatrixXtMap m_data_matrix_eigen =
266-
m_data_matrix;
267-
268-
m_data_matrix_eigen = tmp_eigen.transpose();
291+
auto feature_count = m_data_vectors.size();
292+
index_t row_count =
293+
shogun::visit(VectorSizeVisitor{}, m_data_vectors[0]);
294+
for (int i = 1; i < feature_count; ++i)
295+
{
296+
REQUIRE(
297+
shogun::visit(VectorSizeVisitor{}, m_data_vectors[i]) ==
298+
row_count,
299+
"All columns must have the same number of features!\n")
300+
}
269301
}
270302
else
271303
return false;
272304
return true;
273305
};
274306
process_chunk(read_data, check_data, true);
275307
}
308+
309+
std::shared_ptr<CCombinedFeatures> ARFFDeserializer::get_features()
310+
{
311+
auto result = std::make_shared<CCombinedFeatures>();
312+
index_t row_count = shogun::visit(VectorSizeVisitor{}, m_data_vectors[0]);
313+
for (int i = 0; i < m_data_vectors.size(); ++i)
314+
{
315+
Attribute att = m_attributes[i];
316+
auto vec = m_data_vectors[i];
317+
switch (att)
318+
{
319+
case Attribute::Numeric:
320+
case Attribute::Integer:
321+
case Attribute::Real:
322+
case Attribute::Date:
323+
case Attribute::Nominal:
324+
{
325+
auto casted_vec = shogun::get<std::vector<float64_t>>(vec);
326+
SGMatrix<float64_t> mat(1, row_count);
327+
memcpy(
328+
mat.matrix, casted_vec.data(),
329+
casted_vec.size() * sizeof(float64_t));
330+
auto* feat = new CDenseFeatures<float64_t>(mat);
331+
result->append_feature_obj(feat);
332+
}
333+
break;
334+
case Attribute::String:
335+
{
336+
auto casted_vec = shogun::get<std::vector<std::string>>(vec);
337+
index_t max_string_length = 0;
338+
for (const auto& el : casted_vec)
339+
{
340+
if (max_string_length < el.size())
341+
max_string_length = el.size();
342+
}
343+
SGStringList<char> strings(row_count, max_string_length);
344+
for (int j = 0; j < row_count; ++j)
345+
{
346+
SGString<char> current(max_string_length);
347+
memcpy(
348+
current.string, casted_vec[j].data(),
349+
(casted_vec.size()+1) * sizeof(char));
350+
strings.strings[j] = current;
351+
}
352+
auto* feat = new CStringFeatures<char>(strings, EAlphabet::RAWBYTE);
353+
result->append_feature_obj(feat);
354+
}
355+
}
356+
}
357+
return result;
358+
}

src/shogun/io/ARFFFile.h

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#define SHOGUN_ARFFFILE_H
99

1010
#include <shogun/base/init.h>
11+
#include <shogun/base/variant.h>
12+
#include <shogun/features/CombinedFeatures.h>
1113
#include <shogun/lib/SGMatrix.h>
1214
#include <shogun/lib/SGVector.h>
1315

@@ -41,15 +43,18 @@ namespace shogun
4143
SG_FORCED_INLINE void left_trim(std::string& s)
4244
{
4345
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](char val) {
44-
return !std::isspace(val);
45-
}));
46+
return !std::isspace(val);
47+
}));
4648
}
4749

4850
SG_FORCED_INLINE void right_trim(std::string& s)
4951
{
50-
s.erase(std::find_if(s.rbegin(), s.rend(), [](char val) {
51-
return !std::isspace(val);
52-
}).base(), s.end());
52+
s.erase(
53+
std::find_if(
54+
s.rbegin(), s.rend(),
55+
[](char val) { return !std::isspace(val); })
56+
.base(),
57+
s.end());
5358
}
5459

5560
SG_FORCED_INLINE std::string trim(std::string line)
@@ -170,7 +175,8 @@ namespace shogun
170175
* @param java_token
171176
* @return
172177
*/
173-
SG_FORCED_INLINE const char* process_javatoken(const std::string& java_token)
178+
SG_FORCED_INLINE const char*
179+
process_javatoken(const std::string& java_token)
174180
{
175181
if (java_token == "yy")
176182
return "%y";
@@ -191,7 +197,7 @@ namespace shogun
191197
if (java_token == "Z")
192198
return "%z";
193199
if (java_token == "z")
194-
return "%Z";
200+
SG_SERROR("Timezone abbreviations are currently not supported.\n")
195201
if (java_token.empty())
196202
return "";
197203
if (java_token == "SSS")
@@ -237,7 +243,8 @@ namespace shogun
237243
return nullptr;
238244
}
239245

240-
SG_FORCED_INLINE std::string javatime_to_cpptime(const std::string& java_time)
246+
SG_FORCED_INLINE std::string
247+
javatime_to_cpptime(const std::string& java_time)
241248
{
242249
std::string cpp_time;
243250
std::string token;
@@ -326,7 +333,7 @@ namespace shogun
326333
"have the right permissions to open it.\n",
327334
filename.c_str())
328335
}
329-
m_stream = std::unique_ptr<std::istream>(static_cast<std::istream*>(file_stream));
336+
m_stream = std::unique_ptr<std::istream>(file_stream);
330337
}
331338

332339
/**
@@ -348,14 +355,20 @@ namespace shogun
348355
void read();
349356

350357
/**
351-
* Returns the data processed after parsing.
352-
* @return matrix with parsed data
358+
* Returns string parsed in @relation line
359+
* @return the relation string
353360
*/
354-
SGMatrix<float64_t> get_data()
361+
SG_FORCED_INLINE std::string get_relation()
355362
{
356-
return m_data_matrix;
363+
return m_relation;
357364
}
358365

366+
/**
367+
* Get combined features from parsed data
368+
* @return
369+
*/
370+
std::shared_ptr<CCombinedFeatures> get_features();
371+
359372
private:
360373
/**
361374
* Processes a chunk. A chunk is defined as a set of lines that
@@ -455,7 +468,8 @@ namespace shogun
455468
m_nominal_attributes;
456469

457470
/** dynamic continuous vector with the parsed data */
458-
std::vector<float64_t> m_data;
471+
std::vector<variant<std::vector<float64_t>, std::vector<std::string>>>
472+
m_data_vectors;
459473
/** sgmatrix with the properly formatted data from m_data */
460474
SGMatrix<float64_t> m_data_matrix;
461475
};

0 commit comments

Comments
 (0)