Skip to content

Commit 9a89a0d

Browse files
committed
add filter test
1 parent b7fd358 commit 9a89a0d

File tree

3 files changed

+82 -8 lines changed

3 files changed

+82 -8 lines changed

fsspec/caching.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -652,11 +652,28 @@ def __init__(
652652
else:
653653
self.data = {}
654654

655+
@property
656+
def size(self):
657+
return sum(_[1] - _[0] for _ in self.data)
658+
659+
@size.setter
660+
def size(self, value):
661+
pass
662+
663+
@property
664+
def nblocks(self):
665+
return len(self.data)
666+
667+
@nblocks.setter
668+
def nblocks(self, value):
669+
pass
670+
655671
def _fetch(self, start: int | None, stop: int | None) -> bytes:
656672
if start is None:
657673
start = 0
658674
if stop is None:
659675
stop = self.size
676+
self.total_requested_bytes += stop - start
660677

661678
out = b""
662679
started = False
@@ -665,24 +682,34 @@ def _fetch(self, start: int | None, stop: int | None) -> bytes:
665682
if (loc0 <= start < loc1) and (loc0 <= stop <= loc1):
666683
# entirely within the block
667684
off = start - loc0
685+
self.hit_count += 1
668686
return self.data[(loc0, loc1)][off : off + stop - start]
687+
if stop <= loc0:
688+
break
669689
if started and loc0 > loc_old:
670690
# a gap where we need data
691+
self.miss_count += 1
671692
if self.strict:
672693
raise ValueError
673694
out += b"\x00" * (loc0 - loc_old)
674695
if loc0 <= start < loc1:
675696
# found the start
697+
self.hit_count += 1
676698
off = start - loc0
677699
out = self.data[(loc0, loc1)][off : off + stop - start]
678700
started = True
679701
elif start < loc0 and stop > loc1:
680702
# the whole block
703+
self.hit_count += 1
681704
out += self.data[(loc0, loc1)]
682705
elif loc0 <= stop <= loc1:
683706
# end block
707+
self.hit_count += 1
684708
return out + self.data[(loc0, loc1)][: stop - loc0]
685709
loc_old = loc1
710+
self.miss_count += 1
711+
if started and not self.strict:
712+
return out + b"\x00" * (stop - loc_old)
686713
raise ValueError
687714

688715

fsspec/tests/test_caches.py

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -225,19 +225,37 @@ def test_cache_basic(Cache_imp, blocksize, size_requests):
225225
assert result == expected
226226

227227

228+
@pytest.mark.parametrize("strict", [True, False])
228229
@pytest.mark.parametrize("sort", [True, False])
229-
def test_known(sort):
230-
parts = {(10, 20): b"1" * 10, (20, 30): b"2" * 10, (0, 10): b"0" * 10}
230+
def test_known(strict, sort):
231+
parts = {
232+
(10, 20): b"1" * 10,
233+
(20, 30): b"2" * 10,
234+
(0, 10): b"0" * 10,
235+
(40, 50): b"3" * 10,
236+
}
231237
if sort:
232238
parts = dict(sorted(parts.items()))
233-
c = caches["parts"](None, None, 100, parts)
239+
c = caches["parts"](None, None, 100, parts, strict=strict)
240+
assert c.size == 40
234241
assert (0, 30) in c.data # got consolidated
242+
assert c.nblocks == 2
243+
235244
assert c._fetch(5, 15) == b"0" * 5 + b"1" * 5
236245
assert c._fetch(15, 25) == b"1" * 5 + b"2" * 5
237-
# Over-read will raise error
238-
with pytest.raises(ValueError):
239-
# tries to call None fetcher
240-
c._fetch(25, 35)
246+
assert c.hit_count
247+
assert not c.miss_count
248+
249+
if strict:
250+
# Over-read will raise error
251+
with pytest.raises(ValueError):
252+
c._fetch(25, 35)
253+
with pytest.raises(ValueError):
254+
c._fetch(25, 45)
255+
else:
256+
assert c._fetch(25, 35) == b"2" * 5 + b"\x00" * 5
257+
assert c._fetch(25, 45) == b"2" * 5 + b"\x00" * 10 + b"3" * 5
258+
assert c.miss_count
241259

242260

243261
def test_background(server, monkeypatch):

fsspec/tests/test_parquet.py

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121

2222
@pytest.fixture(
2323
params=[
24-
# pytest.param("fastparquet", marks=FASTPARQUET_MARK),
24+
pytest.param("fastparquet", marks=FASTPARQUET_MARK),
2525
pytest.param("pyarrow", marks=PYARROW_MARK),
2626
]
2727
)
@@ -145,3 +145,32 @@ def test_open_parquet_file(
145145
max_block=max_block,
146146
footer_sample_size=footer_sample_size,
147147
)
148+
149+
150+
@FASTPARQUET_MARK
151+
def test_with_filter(tmpdir):
152+
import pandas as pd
153+
154+
df = pd.DataFrame(
155+
{
156+
"a": [10, 1, 2, 3, 7, 8, 9],
157+
"b": ["a", "a", "a", "b", "b", "b", "b"],
158+
}
159+
)
160+
fn = os.path.join(str(tmpdir), "test.parquet")
161+
df.to_parquet(fn, engine="fastparquet", row_group_offsets=[0, 3], stats=True)
162+
163+
expect = pd.read_parquet(fn, engine="fastparquet", filters=[["b", "==", "b"]])
164+
f = open_parquet_file(
165+
fn,
166+
engine="fastparquet",
167+
filters=[["b", "==", "b"]],
168+
max_gap=1,
169+
max_block=1,
170+
footer_sample_size=8,
171+
)
172+
assert (0, 4) in f.cache.data
173+
assert f.cache.size < os.path.getsize(fn)
174+
175+
result = pd.read_parquet(f, engine="fastparquet", filters=[["b", "==", "b"]])
176+
pd.testing.assert_frame_equal(expect, result)

0 commit comments

Comments
 (0)