Skip to content

Commit 00c56ee

Browse files
committed
fix(tools): load audio close file after reading
1 parent c3948c8 commit 00c56ee

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

tools/audio/av.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,36 +42,53 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
4242

4343
def load_audio(
4444
file: Union[str, BytesIO, Path],
45-
sr: Optional[int]=None,
46-
format: Optional[str]=None,
47-
mono=True
45+
sr: Optional[int] = None,
46+
format: Optional[str] = None,
47+
mono=True,
4848
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
4949
"""
5050
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39
5151
"""
52-
if (isinstance(file, str) and not Path(file).exists()) or (isinstance(file, Path) and not file.exists()):
52+
if (isinstance(file, str) and not Path(file).exists()) or (
53+
isinstance(file, Path) and not file.exists()
54+
):
5355
raise FileNotFoundError(f"File not found: {file}")
5456
rate = 0
5557

5658
container = av.open(file, format=format)
5759
audio_stream = next(s for s in container.streams if s.type == "audio")
5860
channels = 1 if audio_stream.layout == "mono" else 2
5961
container.seek(0)
60-
resampler = AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) if sr is not None else None
62+
resampler = (
63+
AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr)
64+
if sr is not None
65+
else None
66+
)
6167

6268
# Estimated maximum total number of samples to pre-allocate the array
6369
# AV stores length in microseconds by default
64-
estimated_total_samples = int(container.duration * sr // 1_000_000) if sr is not None else 48000
65-
decoded_audio = np.zeros(estimated_total_samples + 1 if channels == 1 else (channels, estimated_total_samples + 1), dtype=np.float32)
70+
estimated_total_samples = (
71+
int(container.duration * sr // 1_000_000) if sr is not None else 48000
72+
)
73+
decoded_audio = np.zeros(
74+
(
75+
estimated_total_samples + 1
76+
if channels == 1
77+
else (channels, estimated_total_samples + 1)
78+
),
79+
dtype=np.float32,
80+
)
6681

6782
offset = 0
6883

6984
def process_packet(packet: List[AudioFrame]):
7085
frames_data = []
7186
rate = 0
7287
for frame in packet:
73-
frame.pts = None # 清除时间戳,避免重新采样问题
74-
resampled_frames = resampler.resample(frame) if resampler is not None else [frame]
88+
# frame.pts = None # 清除时间戳,避免重新采样问题
89+
resampled_frames = (
90+
resampler.resample(frame) if resampler is not None else [frame]
91+
)
7592
for resampled_frame in resampled_frames:
7693
frame_data = resampled_frame.to_ndarray()
7794
rate = resampled_frame.rate
@@ -83,16 +100,21 @@ def frame_iter(container):
83100
yield p.decode()
84101

85102
for r, frames_data in map(process_packet, frame_iter(container)):
86-
if not rate: rate = r
103+
if not rate:
104+
rate = r
87105
for frame_data in frames_data:
88106
end_index = offset + len(frame_data[0])
89107

90108
# 检查 decoded_audio 是否有足够的空间,并在必要时调整大小
91109
if end_index > decoded_audio.shape[1]:
92-
decoded_audio = np.resize(decoded_audio, (decoded_audio.shape[0], end_index*4))
110+
decoded_audio = np.resize(
111+
decoded_audio, (decoded_audio.shape[0], end_index * 4)
112+
)
93113

94114
np.copyto(decoded_audio[..., offset:end_index], frame_data)
95115
offset += len(frame_data[0])
116+
117+
container.close()
96118

97119
# Truncate the array to the actual size
98120
decoded_audio = decoded_audio[..., :offset]

0 commit comments

Comments
 (0)