Skip to content

Commit a249a28

Browse files
authored
ref(test): remove torchaudio dependency and update audio processing to just use soundfile (#739)
* refactor: remove torchaudio dependency and update audio processing to use soundfile * test: update skip condition to include Windows platform
1 parent c4dbd00 commit a249a28

File tree

3 files changed

+30
-29
lines changed

3 files changed

+30
-29
lines changed

requirements/test.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ av >=14.0.0
22
coverage ==7.10.*
33
cryptography==45.0.7
44
mosaicml-streaming==0.11.0
5-
torchaudio>=2.7.0,<2.9
65
pytest ==8.4.*
76
pytest-asyncio>=1.0.0
87
pytest-cov ==7.0.0
@@ -16,4 +15,4 @@ polars >1.0.0
1615
lightning
1716
transformers <4.57.0
1817
zstd
19-
soundfile >=0.13.0 # required for torchaudio backend
18+
soundfile >=0.13.0

src/litdata/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
3636
_BOTO3_AVAILABLE = RequirementCache("boto3")
3737
_FSSPEC_AVAILABLE = RequirementCache("fsspec")
38-
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
3938
_ZSTD_AVAILABLE = RequirementCache("zstd")
4039
_CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
4140
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")

tests/processing/test_data_processor.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import random
55
import sys
6+
import tempfile
67
from contextlib import suppress
78
from functools import partial
89
from io import BytesIO
@@ -16,7 +17,7 @@
1617
import torch
1718
from lightning_utilities.core.imports import RequirementCache
1819

19-
from litdata.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE
20+
from litdata.constants import _ZSTD_AVAILABLE
2021
from litdata.processing import data_processor as data_processor_module
2122
from litdata.processing import functions
2223
from litdata.processing.data_processor import (
@@ -1127,29 +1128,28 @@ def test_empty_optimize(tmpdir, inputs):
11271128

11281129

11291130
def create_synthetic_audio_bytes(index) -> dict:
1130-
from io import BytesIO
1131-
1132-
import torchaudio
1131+
import soundfile as sf
11331132

11341133
# load dummy audio as bytes
1135-
data = torch.randn((1, 16000))
1134+
data = torch.randn((1, 16000)).numpy().squeeze() # shape (16000,)
11361135

1137-
# convert tensor to bytes
1138-
with BytesIO() as f:
1139-
torchaudio.save(f, data, 16000, format="wav")
1140-
data = f.getvalue()
1136+
# convert array to bytes
1137+
with tempfile.NamedTemporaryFile(suffix=".wav") as tmp:
1138+
sf.write(tmp.name, data, 16000, format="WAV")
1139+
with open(tmp.name, "rb") as f:
1140+
data = f.read()
11411141

11421142
return {"content": data}
11431143

11441144

1145-
@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']")
1145+
@pytest.mark.skipif(
1146+
condition=not _ZSTD_AVAILABLE or sys.platform == "win32", reason="Requires: ['zstd'] or Windows not supported"
1147+
)
11461148
@pytest.mark.parametrize("compression", [None, "zstd"])
1147-
def test_load_torch_audio(tmpdir, compression):
1149+
def test_load_audio_bytes_optimize_and_stream(tmpdir, compression):
11481150
seed_everything(42)
11491151

1150-
import torchaudio
1151-
1152-
torchaudio.set_audio_backend("soundfile")
1152+
import soundfile as sf
11531153

11541154
optimize(
11551155
fn=create_synthetic_audio_bytes,
@@ -1164,30 +1164,32 @@ def test_load_torch_audio(tmpdir, compression):
11641164
sample = dataset[0]
11651165
buffer = BytesIO(sample["content"])
11661166
buffer.seek(0)
1167-
tensor, sample_rate = torchaudio.load(buffer, format="wav")
1167+
data, sample_rate = sf.read(buffer)
1168+
tensor = torch.from_numpy(data).unsqueeze(0)
11681169
assert tensor.shape == torch.Size([1, 16000])
11691170
assert sample_rate == 16000
11701171

11711172

11721173
def create_synthetic_audio_file(filepath) -> dict:
1173-
import torchaudio
1174+
import soundfile as sf
11741175

11751176
# load dummy audio as bytes
1176-
data = torch.randn((1, 16000))
1177+
data = torch.randn((1, 16000)).numpy().squeeze()
11771178

1178-
# convert tensor to bytes
1179-
with open(filepath, "wb") as f:
1180-
torchaudio.save(f, data, 16000, format="wav")
1179+
# convert array to bytes
1180+
sf.write(filepath, data, 16000, format="WAV")
11811181

11821182
return filepath
11831183

11841184

1185-
@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio']")
1185+
@pytest.mark.skipif(
1186+
condition=not _ZSTD_AVAILABLE or sys.platform == "win32", reason="Requires: ['zstd'] or Windows not supported"
1187+
)
11861188
@pytest.mark.parametrize("compression", [None])
1187-
def test_load_torch_audio_from_wav_file(tmpdir, compression):
1189+
def test_load_audio_file_optimize_and_stream(tmpdir, compression):
11881190
seed_everything(42)
11891191

1190-
import torchaudio
1192+
import soundfile as sf
11911193

11921194
optimize(
11931195
fn=create_synthetic_audio_file,
@@ -1200,9 +1202,10 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):
12001202

12011203
dataset = StreamingDataset(input_dir=str(tmpdir))
12021204
sample = dataset[0]
1203-
tensor = torchaudio.load(sample)
1204-
assert tensor[0].shape == torch.Size([1, 16000])
1205-
assert tensor[1] == 16000
1205+
data, sample_rate = sf.read(sample)
1206+
tensor = torch.from_numpy(data).unsqueeze(0)
1207+
assert tensor.shape == torch.Size([1, 16000])
1208+
assert sample_rate == 16000
12061209

12071210

12081211
def test_is_path_valid_in_studio(monkeypatch, tmpdir):

0 commit comments

Comments
 (0)