Skip to content

Commit 3056f40

Browse files
authored
Refactor audio sample conversion in encoder (#704)
1 parent 6d3ad1a commit 3056f40

File tree

2 files changed

+40
-35
lines changed

2 files changed

+40
-35
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,13 @@ void AudioEncoder::encode() {
288288
// encoded frame would contain more samples than necessary and our results
289289
// wouldn't match the ffmpeg CLI.
290290
avFrame->nb_samples = numSamplesToEncode;
291-
encodeInnerLoop(autoAVPacket, avFrame);
292291

293-
avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
292+
UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame);
293+
encodeInnerLoop(autoAVPacket, convertedAVFrame);
294+
294295
numEncodedSamples += numSamplesToEncode;
296+
// TODO-ENCODING set frame pts correctly, and test against it.
297+
// avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
295298
}
296299
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
297300

@@ -304,42 +307,43 @@ void AudioEncoder::encode() {
304307
getFFMPEGErrorStringFromErrorCode(status));
305308
}
306309

307-
void AudioEncoder::encodeInnerLoop(
308-
AutoAVPacket& autoAVPacket,
309-
const UniqueAVFrame& srcAVFrame) {
310-
bool mustConvert =
311-
(srcAVFrame != nullptr &&
312-
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP ||
313-
getNumChannels(srcAVFrame) != outNumChannels_));
314-
315-
UniqueAVFrame convertedAVFrame;
316-
if (mustConvert) {
317-
if (!swrContext_) {
318-
swrContext_.reset(createSwrContext(
319-
AV_SAMPLE_FMT_FLTP,
320-
avCodecContext_->sample_fmt,
321-
srcAVFrame->sample_rate, // No sample rate conversion
322-
srcAVFrame->sample_rate,
323-
srcAVFrame,
324-
outNumChannels_));
325-
}
326-
convertedAVFrame = convertAudioAVFrameSamples(
327-
swrContext_,
328-
srcAVFrame,
310+
UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
311+
if (static_cast<AVSampleFormat>(avFrame->format) ==
312+
avCodecContext_->sample_fmt &&
313+
getNumChannels(avFrame) == outNumChannels_) {
314+
// Note: the clone references the same underlying data, it's a cheap copy.
315+
return UniqueAVFrame(av_frame_clone(avFrame.get()));
316+
}
317+
318+
if (!swrContext_) {
319+
swrContext_.reset(createSwrContext(
320+
static_cast<AVSampleFormat>(avFrame->format),
329321
avCodecContext_->sample_fmt,
330-
srcAVFrame->sample_rate, // No sample rate conversion
331-
outNumChannels_);
332-
TORCH_CHECK(
333-
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
334-
"convertedAVFrame->nb_samples=",
335-
convertedAVFrame->nb_samples,
336-
" differs from ",
337-
"srcAVFrame->nb_samples=",
338-
srcAVFrame->nb_samples,
339-
"This is unexpected, please report on the TorchCodec bug tracker.");
322+
avFrame->sample_rate, // No sample rate conversion
323+
avFrame->sample_rate,
324+
avFrame,
325+
outNumChannels_));
340326
}
341-
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
327+
UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples(
328+
swrContext_,
329+
avFrame,
330+
avCodecContext_->sample_fmt,
331+
avFrame->sample_rate, // No sample rate conversion
332+
outNumChannels_);
333+
TORCH_CHECK(
334+
convertedAVFrame->nb_samples == avFrame->nb_samples,
335+
"convertedAVFrame->nb_samples=",
336+
convertedAVFrame->nb_samples,
337+
" differs from ",
338+
"avFrame->nb_samples=",
339+
avFrame->nb_samples,
340+
"This is unexpected, please report on the TorchCodec bug tracker.");
341+
return convertedAVFrame;
342+
}
342343

344+
void AudioEncoder::encodeInnerLoop(
345+
AutoAVPacket& autoAVPacket,
346+
const UniqueAVFrame& avFrame) {
343347
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
344348
TORCH_CHECK(
345349
status == AVSUCCESS,

src/torchcodec/_core/Encoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class AudioEncoder {
3838
void initializeEncoder(
3939
int sampleRate,
4040
const AudioStreamOptions& audioStreamOptions);
41+
UniqueAVFrame maybeConvertAVFrame(const UniqueAVFrame& avFrame);
4142
void encodeInnerLoop(
4243
AutoAVPacket& autoAVPacket,
4344
const UniqueAVFrame& srcAVFrame);

0 commit comments

Comments
 (0)