Skip to content

Commit 2d76a7b

Browse files
committed
Add validation for num_channels
1 parent aad9c7d commit 2d76a7b

File tree

5 files changed

+72
-40
lines changed

5 files changed

+72
-40
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,6 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
5555
supportedRates.str());
5656
}
5757

58-
void print_supported_channel_layouts(const AVCodec *codec) {
59-
if (!codec->ch_layouts) {
60-
printf("No specific channel layouts supported by this encoder.\n");
61-
return;
62-
}
63-
const AVChannelLayout *layout = codec->ch_layouts;
64-
while (layout->order != AV_CHANNEL_ORDER_UNSPEC) {
65-
char layout_name[256];
66-
av_channel_layout_describe(layout, layout_name, sizeof(layout_name));
67-
printf("Supported channel layout: %s\n", layout_name);
68-
layout++;
69-
}
70-
}
71-
7258
static const std::vector<AVSampleFormat> preferredFormatsOrder = {
7359
AV_SAMPLE_FMT_FLTP,
7460
AV_SAMPLE_FMT_FLT,
@@ -173,13 +159,12 @@ AudioEncoder::AudioEncoder(
173159
void AudioEncoder::initializeEncoder(
174160
int sampleRate,
175161
std::optional<int64_t> bitRate,
176-
[[maybe_unused]] std::optional<int64_t> numChannels) {
162+
std::optional<int64_t> numChannels) {
177163
// We use the AVFormatContext's default codec for that
178164
// specific format/container.
179165
const AVCodec* avCodec =
180166
avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
181167
TORCH_CHECK(avCodec != nullptr, "Codec not found");
182-
print_supported_channel_layouts(avCodec);
183168

184169
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
185170
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
@@ -193,6 +178,7 @@ void AudioEncoder::initializeEncoder(
193178
avCodecContext_->bit_rate = bitRate.value_or(0);
194179

195180
desiredNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
181+
validateNumChannels(*avCodec, desiredNumChannels_);
196182

197183
setDefaultChannelLayout(avCodecContext_, desiredNumChannels_);
198184

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,56 @@ void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels) {
100100
#endif
101101
}
102102

103-
// void setChannelLayout(
104-
// UniqueAVFrame& dstAVFrame,
105-
// const UniqueAVCodecContext& avCodecContext) {
106-
// #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
107-
// auto status = av_channel_layout_copy(
108-
// &dstAVFrame->ch_layout, &avCodecContext->ch_layout);
109-
// TORCH_CHECK(
110-
// status == AVSUCCESS,
111-
// "Couldn't copy channel layout to avFrame: ",
112-
// getFFMPEGErrorStringFromErrorCode(status));
113-
// #else
114-
// dstAVFrame->channel_layout = avCodecContext->channel_layout;
115-
// dstAVFrame->channels = avCodecContext->channels;
116-
117-
// #endif
118-
// }
103+
void validateNumChannels(const AVCodec& avCodec, int numChannels) {
104+
#if LIBAVFILTER_VERSION_MAJOR > 8 // FFmpeg > 5
105+
if (avCodec.ch_layouts == nullptr) {
106+
// If we can't validate, we must assume it'll be fine. If not, FFmpeg will
107+
// eventually raise.
108+
return;
109+
}
110+
for (auto i = 0; avCodec.ch_layouts[i].order != AV_CHANNEL_ORDER_UNSPEC;
111+
++i) {
112+
if (numChannels == avCodec.ch_layouts[i].nb_channels) {
113+
return;
114+
}
115+
}
116+
std::stringstream supportedNumChannels;
117+
for (auto i = 0; avCodec.ch_layouts[i].order != AV_CHANNEL_ORDER_UNSPEC;
118+
++i) {
119+
if (i > 0) {
120+
supportedNumChannels << ", ";
121+
}
122+
supportedNumChannels << avCodec.ch_layouts[i].nb_channels;
123+
}
124+
#else
125+
if (avCodec.channel_layouts == nullptr) {
126+
// can't validate, same as above.
127+
return;
128+
}
129+
for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) {
130+
if (numChannels ==
131+
av_get_channel_layout_nb_channels(avCodec.channel_layouts[i])) {
132+
return;
133+
}
134+
}
135+
std::stringstream supportedNumChannels;
136+
for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) {
137+
if (i > 0) {
138+
supportedNumChannels << ", ";
139+
}
140+
supportedNumChannels << av_get_channel_layout_nb_channels(
141+
avCodec.channel_layouts[i]);
142+
}
143+
#endif
144+
TORCH_CHECK(
145+
false,
146+
"Desired number of channels (",
147+
numChannels,
148+
") is not supported by the ",
149+
"encoder. Supported number of channels are: ",
150+
supportedNumChannels.str(),
151+
".");
152+
}
119153

120154
namespace {
121155
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,7 @@ void setDefaultChannelLayout(
153153

154154
void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels);
155155

156-
// void setChannelLayout(
157-
// UniqueAVFrame& dstAVFrame,
158-
// const UniqueAVCodecContext& avCodecContext);
156+
void validateNumChannels(const AVCodec& avCodec, int numChannels);
159157

160158
void setChannelLayout(
161159
UniqueAVFrame& dstAVFrame,

src/torchcodec/encoders/_audio_encoder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,27 @@ def to_file(
3131
dest: Union[str, Path],
3232
*,
3333
bit_rate: Optional[int] = None,
34+
num_channels: Optional[int] = None,
3435
) -> None:
3536
_core.encode_audio_to_file(
3637
wf=self._samples,
3738
sample_rate=self._sample_rate,
3839
filename=dest,
3940
bit_rate=bit_rate,
41+
num_channels=num_channels,
4042
)
4143

4244
def to_tensor(
4345
self,
4446
format: str,
4547
*,
4648
bit_rate: Optional[int] = None,
49+
num_channels: Optional[int] = None,
4750
) -> Tensor:
4851
return _core.encode_audio_to_tensor(
4952
wf=self._samples,
5053
sample_rate=self._sample_rate,
5154
format=format,
5255
bit_rate=bit_rate,
56+
num_channels=num_channels,
5357
)

test/test_ops.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import io
88
import os
9+
import re
910
from functools import partial
1011

1112
os.environ["TORCH_LOGS"] = "output_code"
@@ -1158,10 +1159,19 @@ def test_bad_input(self, tmp_path):
11581159
wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter"
11591160
)
11601161

1161-
encode_audio_to_file(
1162-
wf=torch.rand(2, 10), sample_rate=16_000, filename="ok.mp3", num_channels=8
1163-
)
1164-
1162+
for num_channels in (0, 3):
1163+
with pytest.raises(
1164+
RuntimeError,
1165+
match=re.escape(
1166+
f"Desired number of channels ({num_channels}) is not supported"
1167+
),
1168+
):
1169+
encode_audio_to_file(
1170+
wf=torch.rand(2, 10),
1171+
sample_rate=16_000,
1172+
filename="ok.mp3",
1173+
num_channels=num_channels,
1174+
)
11651175

11661176
@pytest.mark.parametrize(
11671177
"encode_method", (encode_audio_to_file, encode_audio_to_tensor)
@@ -1335,7 +1345,7 @@ def test_contiguity(self):
13351345
def test_num_channels(
13361346
self, num_channels_input, num_channels_output, encode_method, tmp_path
13371347
):
1338-
# We just check that the num_channels parmameter is respected.
1348+
# We just check that the num_channels parameter is respected.
13391349
# Correctness is checked in other tests (like test_against_cli())
13401350

13411351
sample_rate = 16_000

0 commit comments

Comments
 (0)