Skip to content

Commit b877000

Browse files
authored
Update iter_arff.py (#1288)
1 parent 48b0a11 commit b877000

File tree

3 files changed

+122
-19
lines changed

3 files changed

+122
-19
lines changed

docs/releases/unreleased.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ Calling `learn_one` in a pipeline will now update each part of the pipeline in t
2727

2828
- Added `preprocessing.OrdinalEncoder` to map string features to integers.
2929

30+
## stream
31+
32+
- `stream.iter_arff` now supports sparse data.
33+
- `stream.iter_arff` now supports multi-output targets.
34+
3035
## utils
3136

3237
- Added `utils.random.exponential` to retrieve random samples following an exponential distribution.

river/linear_model/bayesian_lin_reg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class BayesianLinearRegression(base.Regressor):
4141
>>> metric = metrics.MAE()
4242
4343
>>> evaluate.progressive_val_score(dataset, model, metric)
44-
MAE: 0.586432
44+
MAE: 0.586...
4545
4646
>>> x, _ = next(iter(dataset))
4747
>>> model.predict_one(x)
@@ -73,7 +73,7 @@ class BayesianLinearRegression(base.Regressor):
7373
... )
7474
>>> metric = metrics.MAE()
7575
>>> evaluate.progressive_val_score(dataset, model, metric)
76-
MAE: 1.284016
76+
MAE: 1.284...
7777
7878
And here's how it performs with some smoothing:
7979
@@ -84,7 +84,7 @@ class BayesianLinearRegression(base.Regressor):
8484
... )
8585
>>> metric = metrics.MAE()
8686
>>> evaluate.progressive_val_score(dataset, model, metric)
87-
MAE: 0.15906
87+
MAE: 0.159...
8888
8989
Smoothing allows the model to gradually "forget" the past, and focus on the more recent data.
9090
@@ -99,7 +99,7 @@ class BayesianLinearRegression(base.Regressor):
9999
... )
100100
>>> metric = metrics.MAE()
101101
>>> evaluate.progressive_val_score(dataset, model, metric)
102-
MAE: 0.242248
102+
MAE: 0.242...
103103
104104
References
105105
----------

river/stream/iter_arff.py

Lines changed: 113 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from __future__ import annotations
22

3-
from scipy.io.arff import arffread
3+
import scipy.io.arff
4+
from scipy.io.arff._arffread import read_header
45

56
from river import base
67

78
from . import utils
89

910

1011
def iter_arff(
11-
filepath_or_buffer, target: str | None = None, compression="infer"
12+
filepath_or_buffer, target: str | list[str] | None = None, compression="infer", sparse=False
1213
) -> base.typing.Stream:
1314
"""Iterates over rows from an ARFF file.
1415
@@ -18,11 +19,96 @@ def iter_arff(
1819
Either a string indicating the location of a file, or a buffer object that has a
1920
`read` method.
2021
target
21-
Name of the target field.
22+
Name(s) of the target field. If `None`, then the target field is ignored. If a list of
23+
names is passed, then a dictionary is returned instead of a single value.
2224
compression
2325
For on-the-fly decompression of on-disk data. If this is set to 'infer' and
2426
`filepath_or_buffer` is a path, then the decompression method is inferred for the
2527
following extensions: '.gz', '.zip'.
28+
sparse
29+
Whether the rows are stored in the sparse ARFF format, i.e. `{index value, ...}` entries.
30+
31+
Examples
32+
--------
33+
34+
>>> cars = '''
35+
... @relation CarData
36+
... @attribute make {Toyota, Honda, Ford, Chevrolet}
37+
... @attribute model string
38+
... @attribute year numeric
39+
... @attribute price numeric
40+
... @attribute mpg numeric
41+
... @data
42+
... Toyota, Corolla, 2018, 15000, 30.5
43+
... Honda, Civic, 2019, 16000, 32.2
44+
... Ford, Mustang, 2020, 25000, 25.0
45+
... Chevrolet, Malibu, 2017, 18000, 28.9
46+
... Toyota, Camry, 2019, 22000, 29.8
47+
... '''
48+
>>> with open('cars.arff', mode='w') as f:
49+
... _ = f.write(cars)
50+
51+
>>> from river import stream
52+
53+
>>> for x, y in stream.iter_arff('cars.arff', target='price'):
54+
... print(x, y)
55+
{'make': 'Toyota', 'model': ' Corolla', 'year': 2018.0, 'mpg': 30.5} 15000.0
56+
{'make': 'Honda', 'model': ' Civic', 'year': 2019.0, 'mpg': 32.2} 16000.0
57+
{'make': 'Ford', 'model': ' Mustang', 'year': 2020.0, 'mpg': 25.0} 25000.0
58+
{'make': 'Chevrolet', 'model': ' Malibu', 'year': 2017.0, 'mpg': 28.9} 18000.0
59+
{'make': 'Toyota', 'model': ' Camry', 'year': 2019.0, 'mpg': 29.8} 22000.0
60+
61+
Let's delete the example file before moving on.
62+
63+
>>> import os; os.remove('cars.arff')
64+
65+
ARFF files support sparse data. Let's create a sparse ARFF file.
66+
67+
>>> sparse = '''
68+
... % traindata
69+
... @RELATION "traindata: -C 6"
70+
... @ATTRIBUTE y0 {0, 1}
71+
... @ATTRIBUTE y1 {0, 1}
72+
... @ATTRIBUTE y2 {0, 1}
73+
... @ATTRIBUTE y3 {0, 1}
74+
... @ATTRIBUTE y4 {0, 1}
75+
... @ATTRIBUTE y5 {0, 1}
76+
... @ATTRIBUTE X0 NUMERIC
77+
... @ATTRIBUTE X1 NUMERIC
78+
... @ATTRIBUTE X2 NUMERIC
79+
... @DATA
80+
... { 3 1,6 0.863382,8 0.820094 }
81+
... { 2 1,6 0.659761 }
82+
... { 0 1,3 1,6 0.437881,8 0.818882 }
83+
... { 2 1,6 0.676477,7 0.724635,8 0.755123 }
84+
... '''
85+
86+
>>> with open('sparse.arff', mode='w') as f:
87+
... _ = f.write(sparse)
88+
89+
In addition, we'll specify that there are several target fields.
90+
91+
>>> arff_stream = stream.iter_arff(
92+
... 'sparse.arff',
93+
... target=['y0', 'y1', 'y2', 'y3', 'y4', 'y5'],
94+
... sparse=True
95+
... )
96+
97+
>>> for x, y in arff_stream:
98+
... print(x)
99+
... print(y)
100+
{'X0': '0.863382', 'X2': '0.820094'}
101+
{'y0': 0, 'y1': 0, 'y2': 0, 'y3': '1', 'y4': 0, 'y5': 0}
102+
{'X0': '0.659761'}
103+
{'y0': 0, 'y1': 0, 'y2': '1', 'y3': 0, 'y4': 0, 'y5': 0}
104+
{'X0': '0.437881', 'X2': '0.818882'}
105+
{'y0': '1', 'y1': 0, 'y2': 0, 'y3': '1', 'y4': 0, 'y5': 0}
106+
{'X0': '0.676477', 'X1': '0.724635', 'X2': '0.755123'}
107+
{'y0': 0, 'y1': 0, 'y2': '1', 'y3': 0, 'y4': 0, 'y5': 0}
108+
109+
References
110+
----------
111+
[^1]: [ARFF format description from Weka](https://waikato.github.io/weka-wiki/formats_and_processing/arff_stable/)
26112
27113
"""
28114

@@ -32,26 +118,38 @@ def iter_arff(
32118
buffer = utils.open_filepath(buffer, compression)
33119

34120
try:
35-
rel, attrs = arffread.read_header(buffer)
121+
rel, attrs = read_header(buffer)
36122
except ValueError as e:
37123
msg = f"Error while parsing header, error was: {e}"
38-
raise arffread.ParseArffError(msg)
124+
raise scipy.io.arff.ParseArffError(msg)
39125

40126
names = [attr.name for attr in attrs]
41-
types = [float if isinstance(attr, arffread.NumericAttribute) else None for attr in attrs]
127+
# HACK: it's a bit hacky to rely on class name to determine what casting to apply
128+
casts = [float if attr.__class__.__name__ == "NumericAttribute" else None for attr in attrs]
42129

43130
for r in buffer:
44131
if len(r) == 0:
45132
continue
46-
x = {
47-
name: typ(val) if typ else val
48-
for name, typ, val in zip(names, types, r.rstrip().split(","))
49-
}
50-
try:
51-
y = x.pop(target) if target else None
52-
except KeyError as e:
53-
print(r)
54-
raise e
133+
134+
# Read row
135+
if sparse:
136+
x = {}
137+
for s in r.rstrip()[1:-1].strip().split(","):
138+
name_index, val = s.split(" ", 1)
139+
x[names[int(name_index)]] = val
140+
else:
141+
x = {
142+
name: cast(val) if cast else val
143+
for name, cast, val in zip(names, casts, r.rstrip().split(","))
144+
}
145+
146+
# Handle target
147+
y = None
148+
if target is not None:
149+
if isinstance(target, list):
150+
y = {name: x.pop(name, 0) for name in target}
151+
else:
152+
y = x.pop(target) if target else None
55153

56154
yield x, y
57155

0 commit comments

Comments
 (0)