Skip to content

Commit 1f387d1

Browse files
authored
Merge pull request #99 from bmcfee/filter-cache
Filter caching
2 parents 5641595 + f559573 commit 1f387d1

File tree

4 files changed

+53
-11
lines changed

4 files changed

+53
-11
lines changed

resampy/core.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def resample(x, sr_orig, sr_new, axis=-1, filter='kaiser_best', parallel=True, *
117117
interp_win, precision, _ = get_filter(filter, **kwargs)
118118

119119
if sample_ratio < 1:
120-
interp_win *= sample_ratio
120+
# Make a copy to prevent modifying the filters in place
121+
interp_win = sample_ratio * interp_win
121122

122123
interp_delta = np.zeros_like(interp_win)
123124
interp_delta[:-1] = np.diff(interp_win)
@@ -205,7 +206,7 @@ def resample_nu(x, sr_orig, t_out, axis=-1, filter='kaiser_best', parallel=True,
205206

206207
t_out = np.asarray(t_out)
207208
if t_out.ndim != 1:
208-
raise ValueError('Invalide t_out shape ({}), 1D array expected'.format(t_out.shape))
209+
raise ValueError('Invalid t_out shape ({}), 1D array expected'.format(t_out.shape))
209210
if np.min(t_out) < 0 or np.max(t_out) > (x.shape[axis] - 1) / sr_orig:
210211
raise ValueError('Output domain [{}, {}] exceeds the data domain [0, {}]'.format(
211212
np.min(t_out), np.max(t_out), (x.shape[axis] - 1) / sr_orig))
@@ -214,10 +215,6 @@ def resample_nu(x, sr_orig, t_out, axis=-1, filter='kaiser_best', parallel=True,
214215
shape = list(x.shape)
215216
shape[axis] = len(t_out)
216217

217-
if shape[axis] < 1:
218-
raise ValueError('Input signal length={} is too small to '
219-
'resample from {}->{}'.format(x.shape[axis], x.shape[axis], len(t_out)))
220-
221218
y = np.zeros_like(x, shape=shape)
222219

223220
interp_win, precision, _ = get_filter(filter, **kwargs)

resampy/filters.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@
5151
import sys
5252

5353
FILTER_FUNCTIONS = ['sinc_window']
54+
FILTER_CACHE = dict()
5455

55-
__all__ = ['get_filter'] + FILTER_FUNCTIONS
56+
__all__ = ['get_filter', 'clear_cache'] + FILTER_FUNCTIONS
5657

5758

5859
def sinc_window(num_zeros=64, precision=9, window=None, rolloff=0.945):
@@ -201,9 +202,21 @@ def load_filter(filter_name):
201202
The roll-off frequency of the filter, as a fraction of Nyquist
202203
'''
203204

204-
fname = os.path.join('data',
205-
os.path.extsep.join([filter_name, 'npz']))
205+
if filter_name not in FILTER_CACHE:
206+
fname = os.path.join('data',
207+
os.path.extsep.join([filter_name, 'npz']))
206208

207-
data = np.load(pkg_resources.resource_filename(__name__, fname))
209+
data = np.load(pkg_resources.resource_filename(__name__, fname))
210+
FILTER_CACHE[filter_name] = data['half_window'], data['precision'], data['rolloff']
208211

209-
return data['half_window'], data['precision'], data['rolloff']
212+
return FILTER_CACHE[filter_name]
213+
214+
215+
def clear_cache():
216+
'''Clear the filter cache.
217+
218+
Calling this function will ensure that packaged filters are reloaded
219+
upon the next usage.
220+
'''
221+
222+
FILTER_CACHE.clear()

tests/test_core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,21 @@ def test_bad_sr(sr_orig, sr_new):
4242
resampy.resample(x, sr_orig, sr_new)
4343

4444

45+
@pytest.mark.xfail(raises=ValueError, strict=True)
46+
@pytest.mark.parametrize('sr', [0, -1])
47+
def test_bad_sr_nu(sr):
48+
x = np.zeros(100)
49+
t = np.arange(3)
50+
resampy.resample_nu(x, sr, t)
51+
52+
53+
@pytest.mark.xfail(raises=ValueError, strict=True)
54+
@pytest.mark.parametrize('t', [np.empty(0), np.eye(3)])
55+
def test_bad_time_nu(t):
56+
x = np.zeros(100)
57+
resampy.resample_nu(x, 1, t)
58+
59+
4560
@pytest.mark.xfail(raises=ValueError, strict=True)
4661
@pytest.mark.parametrize('rolloff', [-1, 1.5])
4762
def test_bad_rolloff(rolloff):

tests/test_filters.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
# -*- encoding: utf-8 -*-
33

4+
import numpy as np
45
import scipy
56
import pytest
67

@@ -39,3 +40,19 @@ def test_filter_load():
3940
@pytest.mark.xfail(raises=NotImplementedError, strict=True)
4041
def test_filter_missing():
4142
resampy.filters.get_filter('bad name')
43+
44+
45+
@pytest.mark.parametrize('sr1, sr2', [(1, 2), (2, 1)])
46+
def test_filter_cache_reset(sr1, sr2):
47+
x = np.random.randn(100)
48+
y1 = resampy.resample(x, sr1, sr2, filter='kaiser_fast')
49+
50+
assert len(resampy.filters.FILTER_CACHE) > 0
51+
52+
resampy.filters.clear_cache()
53+
54+
assert len(resampy.filters.FILTER_CACHE) == 0
55+
56+
y2 = resampy.resample(x, sr1, sr2, filter='kaiser_fast')
57+
58+
assert np.allclose(y1, y2)

0 commit comments

Comments
 (0)