Jonah Casebeer1*, Nicholas J. Bryan2, and Paris Smaragdis1
1 Department of Computer Science, University of Illinois at Urbana-Champaign
2 Adobe Research, Lead advisor
*Work performed while an intern at Adobe Research
Adaptive filtering algorithms are pervasive throughout signal processing and have had a material impact on a wide variety of domains including audio processing, telecommunications, biomedical sensing, astrophysics and cosmology, seismology, and many more. Adaptive filters typically operate via specialized online, iterative optimization methods such as least-mean squares or recursive least squares and aim to process signals in unknown or nonstationary environments. Such algorithms, however, can be slow and laborious to develop, require domain expertise to create, and necessitate mathematical insight for improvement. In this work, we seek to improve upon hand-derived adaptive filter algorithms and present a comprehensive framework for learning online, adaptive signal processing algorithms or update rules directly from data. To do so, we frame the development of adaptive filters as a meta-learning problem in the context of deep learning and use a form of self-supervision to learn online iterative update rules for adaptive filters. To demonstrate our approach, we focus on audio applications and systematically develop meta-learned adaptive filters for five canonical audio problems including system identification, acoustic echo cancellation, blind equalization, multi-channel dereverberation, and beamforming. We compare our approach against common baselines and/or recent state-of-the-art methods. We show we can learn high-performing adaptive filters that operate in real-time and, in most cases, significantly outperform each method we compare against -- all using a single general-purpose configuration of our approach.
For more details, please see "Meta-AF: Meta-Learning for Adaptive Filters", Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, arXiv, 2022, or watch our talk.
If you use ideas or code from this work, please cite our paper:
@article{casebeer2022meta,
title={Meta-AF: Meta-Learning for Adaptive Filters},
author={Casebeer, Jonah and Bryan, Nicholas J and Smaragdis, Paris},
journal={arXiv preprint arXiv:2204.11942},
year={2022}
}
For audio demonstrations of the work and the metaaf package in action, please check out our demo website. You'll be able to find demos for the five core adaptive filtering problems.
We open source all code for the work via our metaaf Python pip package. The metaaf package provides functionality for meta-learning optimizers for near-arbitrary adaptive filters and any differentiable objective. It automatically manages online overlap-save and overlap-add for single/multi-channel and single/multi-frame filters. We also include generic implementations of LMS, RMSProp, NLMS, and RLS for benchmarking purposes. Finally, metaaf includes implementations of generic GRU-based optimizers, which are compatible with any filter defined in the metaaf format. Below, you can find example usage, usage for several common adaptive filter tasks (in the adaptive filter zoo), and installation instructions.
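For readers new to the hand-derived baselines named above, here is a minimal NumPy sketch of a time-domain NLMS update on a toy system-identification problem. It is an illustration only and is not the metaaf implementation; the function name `nlms_identify` and its `step` and `eps` parameters are our own assumptions.

```python
import numpy as np


def nlms_identify(u, d, order=32, step=0.5, eps=1e-8):
    """Time-domain NLMS: estimate an FIR system from input u and output d."""
    w_hat = np.zeros(order)
    u_buf = np.zeros(order)  # most recent input sample first
    err = np.zeros(len(u))
    for n in range(len(u)):
        # shift the newest sample into the buffer
        u_buf = np.roll(u_buf, 1)
        u_buf[0] = u[n]
        # a-priori error between the desired signal and the current estimate
        err[n] = d[n] - w_hat @ u_buf
        # gradient step on the instantaneous squared error,
        # normalized by the energy of the input buffer
        w_hat = w_hat + step * err[n] * u_buf / (u_buf @ u_buf + eps)
    return w_hat, err


# Toy problem: a random 32-tap system driven by white noise, no measurement noise
rng = np.random.default_rng(0)
w_true = rng.normal(size=32) / 32
u = rng.normal(size=4096)
d = np.convolve(w_true, u)[: len(u)]

w_hat, err = nlms_identify(u, d)
print("final weight error:", np.linalg.norm(w_hat - w_true))
```

A meta-learned optimizer replaces the fixed `w_hat` update line with a learned rule, while the surrounding filtering loop keeps the same shape.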
The metaaf package is relatively small, being limited to a dozen files which enable much more functionality than we demo here. The core meta-learning code is in core.py, the buffered and online filter implementations are in filter.py, and the RNN-based optimizers are in optimizer_gru.py and optimizer_fgru.py. The remaining files hold utilities and generic implementations of baseline optimizers. meta.py contains a class for managing training.
To install the metaaf Python package, you will need a working JAX install. You can set one up by following the official directions here. Below is an example of the commands we use to set up a new conda environment called metaenv in which we install metaaf and any dependencies.
### GPU
# Install all the cuda and cudnn prerequisites
conda create -yn metaenv python=3.7 &&
conda install -yn metaenv cudatoolkit=11.1.1 -c pytorch -c conda-forge &&
conda install -yn metaenv cudatoolkit-dev=11.1.1 -c pytorch -c conda-forge &&
conda install -yn metaenv cudnn=8.2 -c nvidia -c pytorch -c anaconda -c conda-forge &&
conda install -yn metaenv pytorch cpuonly -c pytorch -y
conda activate metaenv
# Actually install jax
# You may need to change the cuda/cudnn version numbers depending on your machine
pip install "jax[cuda11_cudnn82]==0.3.15" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Install Haiku
pip install git+https://github.com/deepmind/[email protected]
### CPU. x86 only
conda create -yn metaenv python=3.7 &&
conda install -yn metaenv pytorch torchvision torchaudio -c pytorch &&
conda install -yn metaenv pytorch cpuonly -c pytorch -y
conda activate metaenv
# Actually install jax (CPU-only build, so no cuda/cudnn versions to match)
pip install --upgrade pip
pip install --upgrade "jax[cpu]"==0.3.15
# Install Haiku
pip install git+https://github.com/deepmind/[email protected]
Finally, with the prerequisites done, you can install metaaf by cloning the repo, moving into the base directory, and running pip install -e ./. This pip install adds the remaining dependencies. To run the demo notebook, you also need to:
# Add the conda env to your jupyter session
conda install -yn metaenv ipykernel
ipython kernel install --user --name=metaenv
# Install plotting
pip install matplotlib
# Install widgets for a progress bar
pip install ipywidgets
The metaaf package provides several important modules to facilitate training. The first is the MetaAFTrainer, a class which manages training. To use the MetaAFTrainer, we need to define a filter architecture and a dataset. metaaf adopts several conventions to simplify training and automate procedures like buffering. In this notebook, we walk through this process and demonstrate it on a toy system-identification task. In this section, we explain that toy task and the automatic argparse utilities. To see a full-scale example, proceed to the next section, where we describe the Meta-AF Zoo.
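The argparse utilities follow an add_args / grab_args pattern: each component registers its own command-line flags, then picks its own values back out of the parsed dictionary. Here is a minimal sketch of that pattern using only the standard library. The `ToyFilter` class and its `--hop_size` flag are hypothetical and stand in for real metaaf components.

```python
import argparse


class ToyFilter:
    """Hypothetical component following an add_args / grab_args convention."""

    @staticmethod
    def add_args(parent_parser):
        # register this component's flags on the shared parser
        group = parent_parser.add_argument_group("ToyFilter")
        group.add_argument("--window_size", type=int, default=64)
        group.add_argument("--hop_size", type=int, default=32)
        return parent_parser

    @staticmethod
    def grab_args(kwargs):
        # keep only the keys this component owns
        keys = ["window_size", "hop_size"]
        return {k: kwargs[k] for k in keys}


parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser = ToyFilter.add_args(parser)

# parse an explicit list here so the sketch does not depend on sys.argv
kwargs = vars(parser.parse_args(["--window_size", "128"]))
filter_kwargs = ToyFilter.grab_args(kwargs)
print(filter_kwargs)  # {'window_size': 128, 'hop_size': 32}
```

Because every component filters the shared dictionary itself, several components can share one parser without seeing each other's flags.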
First, you need to make a dataset using a regular PyTorch dataset. The dataset must return a dictionary with two keys: "signals" and "metadata". The "signals" are automatically indexed and sliced and should be of size samples by channels.
import numpy as np
from torch.utils.data import Dataset
# import path assumed; adjust if NumpyLoader lives elsewhere in metaaf
from metaaf.data import NumpyLoader

class SystemIDDataset(Dataset):
    def __init__(self, N=4096, sys_order=32):
        self.N = N
        self.sys_order = sys_order

    def __len__(self):
        return 256

    def __getitem__(self, idx):
        # the system
        w = np.random.normal(size=self.sys_order) / self.sys_order
        # the input
        u = np.random.normal(size=self.N)
        # the output
        d = np.convolve(w, u)[: self.N]
        return {
            "signals": {
                "u": u[:, None],  # time X channels
                "d": d[:, None],  # time X channels
            },
            "metadata": {},
        }

train_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
val_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
test_loader = NumpyLoader(SystemIDDataset(), batch_size=32)
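To make the expected shapes concrete, the sketch below rebuilds one dataset item in plain NumPy and stacks a few items into a batch. The helper `make_system_id_item` is hypothetical. The batch-first stacking is our assumption about what a loader produces, not a statement about how NumpyLoader is implemented.

```python
import numpy as np


def make_system_id_item(rng, N=4096, sys_order=32):
    """Build one item in the {"signals": ..., "metadata": ...} layout."""
    w = rng.normal(size=sys_order) / sys_order  # the unknown system
    u = rng.normal(size=N)                      # the input
    d = np.convolve(w, u)[:N]                   # the output
    return {
        "signals": {"u": u[:, None], "d": d[:, None]},  # samples x channels
        "metadata": {},
    }


rng = np.random.default_rng(0)
items = [make_system_id_item(rng) for _ in range(4)]

# assumed collation: stack each signal along a new leading batch axis
batch = {k: np.stack([it["signals"][k] for it in items]) for k in ("u", "d")}

print(items[0]["signals"]["u"].shape)  # (4096, 1)
print(batch["d"].shape)                # (4, 4096, 1)
```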
Then, you define your filter. We're going to inherit from the metaaf overlap-save (OLS) module. When inheriting, you can return either the current result, which will be automatically buffered, or a dictionary. When returning a dictionary, it must have a key "out", which will be buffered. All other keys are stacked and returned.
import haiku as hk
import jax.numpy as jnp
from metaaf.filter import OverlapSave

# the filter inherits from the overlap-save module
class SystemID(OverlapSave, hk.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # select the analysis window
        self.analysis_window = jnp.ones(self.window_size)

    # Since we use the OLS base class, u and d are STFT-domain inputs.
    # The filter must take the same inputs that its _fwd wrapper provides.
    def __ols_call__(self, u, d, metadata):
        # collect a buffer-sized anti-aliased filter
        w = self.get_filter(name="w")
        # this is n_frames x n_freq x channels or 1 x F x 1 here
        return (w * u)[0]

    @staticmethod
    def add_args(parent_parser):
        return super(SystemID, SystemID).add_args(parent_parser)

    @staticmethod
    def grab_args(kwargs):
        return super(SystemID, SystemID).grab_args(kwargs)
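The filter above only multiplies a frequency-domain filter by the current input frame. The buffering around it is what the overlap-save base class is described as handling. As a rough illustration of that bookkeeping, here is a plain NumPy overlap-save sketch for a fixed FIR filter. It is not the metaaf implementation; `overlap_save_filter`, `window_size`, and `hop_size` are names we chose for this example.

```python
import numpy as np


def overlap_save_filter(u, w, window_size=64, hop_size=32):
    """Filter u with the FIR filter w, block by block in the frequency domain."""
    # the filter must fit in the part of the window that gets discarded
    assert len(w) <= window_size - hop_size + 1, "filter too long for this window"
    W = np.fft.rfft(w, n=window_size)  # spectrum of the zero-padded filter
    # prepend zeros so the first frame has a full history
    u_pad = np.concatenate([np.zeros(window_size - hop_size), u])
    n_frames = len(u) // hop_size
    out = np.zeros(n_frames * hop_size)
    for i in range(n_frames):
        frame = u_pad[i * hop_size : i * hop_size + window_size]
        y = np.fft.irfft(W * np.fft.rfft(frame), n=window_size)
        # the leading samples are circularly aliased; save only the last hop
        out[i * hop_size : (i + 1) * hop_size] = y[-hop_size:]
    return out


rng = np.random.default_rng(0)
w = rng.normal(size=32) / 32
u = rng.normal(size=4096)

y_ols = overlap_save_filter(u, w)
y_ref = np.convolve(w, u)[: len(y_ols)]
print("max abs difference:", np.max(np.abs(y_ols - y_ref)))
```

The block output matches direct convolution only because the filter is short enough to keep the saved samples free of circular wrap-around, which is the role an anti-aliased filter plays in the frequency-domain setting.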
Haiku converts objects to functions. We need to provide a wrapper to do this. The wrapper function MUST take as input the same named values from your dataset.
def _SystemID_fwd(u, d, metadata=None, init_data=None, **kwargs):
    f = SystemID(**kwargs)
    return f(u=u, d=d)
Then, we define an adaptive filter loss. Here, it is just the MSE. An adaptive filter loss must be written in this form so that metaaf can automatically take its gradient and pass it around.
def filter_loss(out, data_samples, metadata):
    e = out - data_samples["d"]
    return jnp.vdot(e, e) / (e.size)
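The vdot form is worth a quick check: vdot conjugates its first argument, so vdot(e, e) is the sum of squared magnitudes even when the error is complex. The NumPy sketch below uses `np.vdot`, which has the same conjugation behavior, on made-up values. The function `filter_loss_np` is our own stand-in and drops the metadata argument.

```python
import numpy as np


def filter_loss_np(out, d):
    e = out - d
    # np.vdot flattens its inputs and conjugates the first one,
    # so this is sum(|e|^2) / number of elements
    return np.vdot(e, e) / e.size


# made-up complex values, shaped samples x channels
out = np.array([[1.0 + 1.0j], [0.5 + 0.0j], [0.0 - 2.0j]])
d = np.array([[1.0 + 0.0j], [0.0 + 0.0j], [0.0 + 0.0j]])

loss = filter_loss_np(out, d)
mse = np.mean(np.abs(out - d) ** 2)
print(loss, mse)  # (1.75+0j) 1.75
```

For real-valued signals this reduces to the ordinary mean squared error.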
We can construct the meta-train and meta-val losses in a similar fashion.
def meta_train_loss(losses, outputs, data_samples, metadata, outer_index, outer_learnable):
    out = jnp.concatenate(outputs["out"], 0)
    return jnp.log(jnp.mean(jnp.abs(out - data_samples["d"]) ** 2) + 1e-9)

def meta_val_loss(losses, outputs, data_samples, metadata, outer_learnable):
    out = jnp.reshape(
        outputs["out"],
        (outputs["out"].shape[0], -1, outputs["out"].shape[-1]),
    )
    d = data_samples["d"]
    min_len = min(out.shape[1], d.shape[1])
    return jnp.log(jnp.mean(jnp.abs(out[:, :min_len] - d[:, :min_len]) ** 2) + 1e-9)
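Two details of the validation loss are easy to miss: the reshape that flattens framed output back into a single time axis, and the 1e-9 floor inside the log. The NumPy sketch below walks through both on random data. The batch x frames x hop x channels layout is an assumption we make for illustration, and `log_mse` is our own helper name.

```python
import numpy as np


def log_mse(out, d, eps=1e-9):
    # eps keeps the log finite when the error is exactly zero
    return np.log(np.mean(np.abs(out - d) ** 2) + eps)


batch, n_frames, hop, channels = 2, 5, 4, 1
rng = np.random.default_rng(0)
outputs = rng.normal(size=(batch, n_frames, hop, channels))  # assumed layout
d = rng.normal(size=(batch, 22, channels))  # two samples longer than the output

# merge the frame and hop axes into one time axis: (2, 5, 4, 1) -> (2, 20, 1)
out = outputs.reshape(outputs.shape[0], -1, outputs.shape[-1])

# trim both signals to the shorter length before comparing
min_len = min(out.shape[1], d.shape[1])
loss = log_mse(out[:, :min_len], d[:, :min_len])

# a perfect output bottoms out at log(eps) rather than minus infinity
floor = log_mse(d, d)
print(out.shape, min_len, loss, floor)
```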
With everything defined, we can set up the MetaAFTrainer and start training.
import argparse
import jax
from metaaf.meta import MetaAFTrainer
from metaaf.optimizer_gru import EGRU

# Collect arguments
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser = EGRU.add_args(parser)
parser = SystemID.add_args(parser)
parser = MetaAFTrainer.add_args(parser)
kwargs = vars(parser.parse_args())

# Set up the trainer
system = MetaAFTrainer(
    _filter_fwd=_SystemID_fwd,
    filter_kwargs=SystemID.grab_args(kwargs),
    filter_loss=filter_loss,
    meta_train_loss=meta_train_loss,
    meta_val_loss=meta_val_loss,
    optimizer_kwargs=EGRU.grab_args(kwargs),
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)

# Train
key = jax.random.PRNGKey(0)
outer_learned, losses = system.train(
    **MetaAFTrainer.grab_args(kwargs),
    key=key,
)
That is it! For more advanced options, check out the zoo, where we demonstrate callbacks, customized filters, and more.
The Meta-AF Zoo contains implementations for system identification, acoustic echo cancellation, equalization, weighted prediction error dereverberation, and a generalized sidelobe canceller beamformer, all in the metaaf framework. You can find instructions for how to set up, run, and evaluate those models here. For trained weights and tuned baselines, please see the tagged release zip file here.
All core utility code within the metaaf folder is licensed via the University of Illinois Open Source License. All code within the zoo folder and model weights are licensed via the Adobe Research License. Copyright (c) Adobe Systems Incorporated. All rights reserved.
An extension of this work using metaaf:
"Meta-Learning for Adaptive Filters with Higher-Order Frequency Dependencies", Junkai Wu, Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, IWAENC, 2022.
@inproceedings{wu2022metalearning,
title={Meta-Learning for Adaptive Filters with Higher-Order Frequency Dependencies},
author={Wu, Junkai and Casebeer, Jonah and Bryan, Nicholas J. and Smaragdis, Paris},
booktitle={IEEE International Workshop on Acoustic Signal Enhancement (IWAENC)},
year={2022},
}
An early version of this work:
"Auto-DSP: Learning to Optimize Acoustic Echo Cancellers", Jonah Casebeer, Nicholas J. Bryan, and Paris Smaragdis, WASPAA, 2021.
@inproceedings{casebeer2021auto,
title={Auto-DSP: Learning to Optimize Acoustic Echo Cancellers},
author={Casebeer, Jonah and Bryan, Nicholas J. and Smaragdis, Paris},
booktitle={2021 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)},
pages={291--295},
year={2021},
organization={IEEE}
}