@@ -8,16 +8,16 @@ namespace facebook::torchcodec {
8
8
9
9
namespace {
10
10
11
- torch::Tensor validateWf (torch::Tensor wf ) {
11
+ torch::Tensor validateSamples (torch::Tensor samples ) {
12
12
TORCH_CHECK (
13
- wf .dtype () == torch::kFloat32 ,
14
- " waveform must have float32 dtype, got " ,
15
- wf .dtype ());
16
- TORCH_CHECK (wf .dim () == 2 , " waveform must have 2 dimensions, got " , wf .dim ());
13
+ samples .dtype () == torch::kFloat32 ,
14
+ " samples must have float32 dtype, got " ,
15
+ samples .dtype ());
16
+ TORCH_CHECK (samples .dim () == 2 , " samples must have 2 dimensions, got " , samples .dim ());
17
17
18
18
// We enforce this, but if we get user reports we should investigate whether
19
19
// that's actually needed.
20
- int numChannels = static_cast <int >(wf .sizes ()[0 ]);
20
+ int numChannels = static_cast <int >(samples .sizes ()[0 ]);
21
21
TORCH_CHECK (
22
22
numChannels <= AV_NUM_DATA_POINTERS,
23
23
" Trying to encode " ,
@@ -26,7 +26,7 @@ torch::Tensor validateWf(torch::Tensor wf) {
26
26
AV_NUM_DATA_POINTERS,
27
27
" channels per frame." );
28
28
29
- return wf .contiguous ();
29
+ return samples .contiguous ();
30
30
}
31
31
32
32
void validateSampleRate (const AVCodec& avCodec, int sampleRate) {
@@ -71,7 +71,7 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
71
71
72
72
AVSampleFormat findBestOutputSampleFormat (const AVCodec& avCodec) {
73
73
// Find a sample format that the encoder supports. We prefer using FLT[P],
74
- // since this is the format of the input waveform . If FLTP isn't supported
74
+ // since this is the format of the input samples . If FLTP isn't supported
75
75
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
76
76
// into the format with the highest resolution.
77
77
if (avCodec.sample_fmts == nullptr ) {
@@ -98,12 +98,12 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
98
98
AudioEncoder::~AudioEncoder () {}
99
99
100
100
AudioEncoder::AudioEncoder (
101
- const torch::Tensor wf ,
101
+ const torch::Tensor samples ,
102
102
int sampleRate,
103
103
std::string_view fileName,
104
104
std::optional<int64_t > bitRate,
105
105
std::optional<int64_t > numChannels)
106
- : wf_(validateWf(wf )) {
106
+ : samples_(validateSamples(samples )) {
107
107
setFFmpegLogLevel ();
108
108
AVFormatContext* avFormatContext = nullptr ;
109
109
int status = avformat_alloc_output_context2 (
@@ -130,13 +130,13 @@ AudioEncoder::AudioEncoder(
130
130
}
131
131
132
132
AudioEncoder::AudioEncoder (
133
- const torch::Tensor wf ,
133
+ const torch::Tensor samples ,
134
134
int sampleRate,
135
135
std::string_view formatName,
136
136
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
137
137
std::optional<int64_t > bitRate,
138
138
std::optional<int64_t > numChannels)
139
- : wf_(validateWf(wf )), avioContextHolder_(std::move(avioContextHolder)) {
139
+ : samples_(validateSamples(samples )), avioContextHolder_(std::move(avioContextHolder)) {
140
140
setFFmpegLogLevel ();
141
141
AVFormatContext* avFormatContext = nullptr ;
142
142
int status = avformat_alloc_output_context2 (
@@ -177,7 +177,7 @@ void AudioEncoder::initializeEncoder(
177
177
// well when "-b:a" isn't specified.
178
178
avCodecContext_->bit_rate = bitRate.value_or (0 );
179
179
180
- desiredNumChannels_ = static_cast <int >(numChannels.value_or (wf_ .sizes ()[0 ]));
180
+ desiredNumChannels_ = static_cast <int >(numChannels.value_or (samples_ .sizes ()[0 ]));
181
181
validateNumChannels (*avCodec, desiredNumChannels_);
182
182
// The avCodecContext layout defines the layout of the encoded output, it's
183
183
// not related to the input sampes.
@@ -186,11 +186,13 @@ void AudioEncoder::initializeEncoder(
186
186
validateSampleRate (*avCodec, sampleRate);
187
187
avCodecContext_->sample_rate = sampleRate;
188
188
189
- // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
190
- // may need to convert the wf into a supported output sample format, which is
189
+ // Input samples are expected to be FLTP. Not all encoders support FLTP, so we
190
+ // may need to convert the samples into a supported output sample format, which is
191
191
// what the `.sample_fmt` defines.
192
192
avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
193
193
194
+ setDefaultChannelLayout (avCodecContext_, static_cast <int >(samples_.sizes ()[0 ]));
195
+
194
196
int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
195
197
TORCH_CHECK (
196
198
status == AVSUCCESS,
@@ -237,7 +239,7 @@ void AudioEncoder::encode() {
237
239
avFrame->pts = 0 ;
238
240
// We set the channel layout of the frame to the default layout corresponding
239
241
// to the input samples' number of channels
240
- setDefaultChannelLayout (avFrame, static_cast <int >(wf_ .sizes ()[0 ]));
242
+ setDefaultChannelLayout (avFrame, static_cast <int >(samples_ .sizes ()[0 ]));
241
243
242
244
auto status = av_frame_get_buffer (avFrame.get (), 0 );
243
245
TORCH_CHECK (
@@ -247,10 +249,10 @@ void AudioEncoder::encode() {
247
249
248
250
AutoAVPacket autoAVPacket;
249
251
250
- uint8_t * pwf = static_cast <uint8_t *>(wf_ .data_ptr ());
251
- int numSamples = static_cast <int >(wf_ .sizes ()[1 ]); // per channel
252
+ uint8_t * psamples = static_cast <uint8_t *>(samples_ .data_ptr ());
253
+ int numSamples = static_cast <int >(samples_ .sizes ()[1 ]); // per channel
252
254
int numEncodedSamples = 0 ; // per channel
253
- int numBytesPerSample = static_cast <int >(wf_ .element_size ());
255
+ int numBytesPerSample = static_cast <int >(samples_ .element_size ());
254
256
int numBytesPerChannel = numSamples * numBytesPerSample;
255
257
256
258
status = avformat_write_header (avFormatContext_.get (), nullptr );
@@ -270,11 +272,11 @@ void AudioEncoder::encode() {
270
272
std::min (numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
271
273
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
272
274
273
- for (int ch = 0 ; ch < wf_ .sizes ()[0 ]; ch++) {
275
+ for (int ch = 0 ; ch < samples_ .sizes ()[0 ]; ch++) {
274
276
std::memcpy (
275
- avFrame->data [ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
277
+ avFrame->data [ch], psamples + ch * numBytesPerChannel, numBytesToEncode);
276
278
}
277
- pwf += numBytesToEncode;
279
+ psamples += numBytesToEncode;
278
280
279
281
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
280
282
// that the frame buffers are allocated to a big enough size. Here, we reset
0 commit comments