33import os
44import random
55import sys
6+ import tempfile
67from contextlib import suppress
78from functools import partial
89from io import BytesIO
1617import torch
1718from lightning_utilities .core .imports import RequirementCache
1819
19- from litdata .constants import _TORCH_AUDIO_AVAILABLE , _ZSTD_AVAILABLE
20+ from litdata .constants import _ZSTD_AVAILABLE
2021from litdata .processing import data_processor as data_processor_module
2122from litdata .processing import functions
2223from litdata .processing .data_processor import (
@@ -1127,29 +1128,28 @@ def test_empty_optimize(tmpdir, inputs):
11271128
11281129
11291130def create_synthetic_audio_bytes (index ) -> dict :
1130- from io import BytesIO
1131-
1132- import torchaudio
1131+ import soundfile as sf
11331132
11341133 # load dummy audio as bytes
1135- data = torch .randn ((1 , 16000 ))
1134+ data = torch .randn ((1 , 16000 )). numpy (). squeeze () # shape (16000,)
11361135
1137- # convert tensor to bytes
1138- with BytesIO () as f :
1139- torchaudio .save (f , data , 16000 , format = "wav" )
1140- data = f .getvalue ()
1136+ # convert array to bytes
1137+ with tempfile .NamedTemporaryFile (suffix = ".wav" ) as tmp :
1138+ sf .write (tmp .name , data , 16000 , format = "WAV" )
1139+ with open (tmp .name , "rb" ) as f :
1140+ data = f .read ()
11411141
11421142 return {"content" : data }
11431143
11441144
1145- @pytest .mark .skipif (condition = not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE , reason = "Requires: ['torchaudio']" )
1145+ @pytest .mark .skipif (
1146+ condition = not _ZSTD_AVAILABLE or sys .platform == "win32" , reason = "Requires: ['zstd'] or Windows not supported"
1147+ )
11461148@pytest .mark .parametrize ("compression" , [None , "zstd" ])
1147- def test_load_torch_audio (tmpdir , compression ):
1149+ def test_load_audio_bytes_optimize_and_stream (tmpdir , compression ):
11481150 seed_everything (42 )
11491151
1150- import torchaudio
1151-
1152- torchaudio .set_audio_backend ("soundfile" )
1152+ import soundfile as sf
11531153
11541154 optimize (
11551155 fn = create_synthetic_audio_bytes ,
@@ -1164,30 +1164,32 @@ def test_load_torch_audio(tmpdir, compression):
11641164 sample = dataset [0 ]
11651165 buffer = BytesIO (sample ["content" ])
11661166 buffer .seek (0 )
1167- tensor , sample_rate = torchaudio .load (buffer , format = "wav" )
1167+ data , sample_rate = sf .read (buffer )
1168+ tensor = torch .from_numpy (data ).unsqueeze (0 )
11681169 assert tensor .shape == torch .Size ([1 , 16000 ])
11691170 assert sample_rate == 16000
11701171
11711172
11721173def create_synthetic_audio_file (filepath ) -> dict :
1173- import torchaudio
1174+ import soundfile as sf
11741175
11751176 # load dummy audio as bytes
1176- data = torch .randn ((1 , 16000 ))
1177+ data = torch .randn ((1 , 16000 )). numpy (). squeeze ()
11771178
1178- # convert tensor to bytes
1179- with open (filepath , "wb" ) as f :
1180- torchaudio .save (f , data , 16000 , format = "wav" )
1179+ # convert array to bytes
1180+ sf .write (filepath , data , 16000 , format = "WAV" )
11811181
11821182 return filepath
11831183
11841184
1185- @pytest .mark .skipif (condition = not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE , reason = "Requires: ['torchaudio']" )
1185+ @pytest .mark .skipif (
1186+ condition = not _ZSTD_AVAILABLE or sys .platform == "win32" , reason = "Requires: ['zstd'] or Windows not supported"
1187+ )
11861188@pytest .mark .parametrize ("compression" , [None ])
1187- def test_load_torch_audio_from_wav_file (tmpdir , compression ):
1189+ def test_load_audio_file_optimize_and_stream (tmpdir , compression ):
11881190 seed_everything (42 )
11891191
1190- import torchaudio
1192+ import soundfile as sf
11911193
11921194 optimize (
11931195 fn = create_synthetic_audio_file ,
@@ -1200,9 +1202,10 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):
12001202
12011203 dataset = StreamingDataset (input_dir = str (tmpdir ))
12021204 sample = dataset [0 ]
1203- tensor = torchaudio .load (sample )
1204- assert tensor [0 ].shape == torch .Size ([1 , 16000 ])
1205- assert tensor [1 ] == 16000
1205+ data , sample_rate = sf .read (sample )
1206+ tensor = torch .from_numpy (data ).unsqueeze (0 )
1207+ assert tensor .shape == torch .Size ([1 , 16000 ])
1208+ assert sample_rate == 16000
12061209
12071210
12081211def test_is_path_valid_in_studio (monkeypatch , tmpdir ):
0 commit comments