Skip to content

Commit 52098c9

Browse files
author
danielflores3
committed May 22, 2025
Rename 'wf' to 'samples' in AudioEncoder
1 parent c45c9c6 commit 52098c9

File tree

3 files changed

+34
-32
lines changed

3 files changed

+34
-32
lines changed
 

‎src/torchcodec/_core/Encoder.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ namespace facebook::torchcodec {
88

99
namespace {
1010

11-
torch::Tensor validateWf(torch::Tensor wf) {
11+
torch::Tensor validateSamples(torch::Tensor samples) {
1212
TORCH_CHECK(
13-
wf.dtype() == torch::kFloat32,
14-
"waveform must have float32 dtype, got ",
15-
wf.dtype());
16-
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
13+
samples.dtype() == torch::kFloat32,
14+
"samples must have float32 dtype, got ",
15+
samples.dtype());
16+
TORCH_CHECK(samples.dim() == 2, "samples must have 2 dimensions, got ", samples.dim());
1717

1818
// We enforce this, but if we get user reports we should investigate whether
1919
// that's actually needed.
20-
int numChannels = static_cast<int>(wf.sizes()[0]);
20+
int numChannels = static_cast<int>(samples.sizes()[0]);
2121
TORCH_CHECK(
2222
numChannels <= AV_NUM_DATA_POINTERS,
2323
"Trying to encode ",
@@ -26,7 +26,7 @@ torch::Tensor validateWf(torch::Tensor wf) {
2626
AV_NUM_DATA_POINTERS,
2727
" channels per frame.");
2828

29-
return wf.contiguous();
29+
return samples.contiguous();
3030
}
3131

3232
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
@@ -71,7 +71,7 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
7171

7272
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
7373
// Find a sample format that the encoder supports. We prefer using FLT[P],
74-
// since this is the format of the input waveform. If FLTP isn't supported
74+
// since this is the format of the input samples. If FLTP isn't supported
7575
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
7676
// into the format with the highest resolution.
7777
if (avCodec.sample_fmts == nullptr) {
@@ -98,12 +98,12 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
9898
AudioEncoder::~AudioEncoder() {}
9999

100100
AudioEncoder::AudioEncoder(
101-
const torch::Tensor wf,
101+
const torch::Tensor samples,
102102
int sampleRate,
103103
std::string_view fileName,
104104
std::optional<int64_t> bitRate,
105105
std::optional<int64_t> numChannels)
106-
: wf_(validateWf(wf)) {
106+
: samples_(validateSamples(samples)) {
107107
setFFmpegLogLevel();
108108
AVFormatContext* avFormatContext = nullptr;
109109
int status = avformat_alloc_output_context2(
@@ -130,13 +130,13 @@ AudioEncoder::AudioEncoder(
130130
}
131131

132132
AudioEncoder::AudioEncoder(
133-
const torch::Tensor wf,
133+
const torch::Tensor samples,
134134
int sampleRate,
135135
std::string_view formatName,
136136
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
137137
std::optional<int64_t> bitRate,
138138
std::optional<int64_t> numChannels)
139-
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
139+
: samples_(validateSamples(samples)), avioContextHolder_(std::move(avioContextHolder)) {
140140
setFFmpegLogLevel();
141141
AVFormatContext* avFormatContext = nullptr;
142142
int status = avformat_alloc_output_context2(
@@ -177,7 +177,7 @@ void AudioEncoder::initializeEncoder(
177177
// well when "-b:a" isn't specified.
178178
avCodecContext_->bit_rate = bitRate.value_or(0);
179179

180-
desiredNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
180+
desiredNumChannels_ = static_cast<int>(numChannels.value_or(samples_.sizes()[0]));
181181
validateNumChannels(*avCodec, desiredNumChannels_);
182182
// The avCodecContext layout defines the layout of the encoded output, it's
183183
// not related to the input samples.
@@ -186,11 +186,13 @@ void AudioEncoder::initializeEncoder(
186186
validateSampleRate(*avCodec, sampleRate);
187187
avCodecContext_->sample_rate = sampleRate;
188188

189-
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
190-
// may need to convert the wf into a supported output sample format, which is
189+
// Input samples are expected to be FLTP. Not all encoders support FLTP, so we
190+
// may need to convert the samples into a supported output sample format, which is
191191
// what the `.sample_fmt` defines.
192192
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
193193

194+
setDefaultChannelLayout(avCodecContext_, static_cast<int>(samples_.sizes()[0]));
195+
194196
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
195197
TORCH_CHECK(
196198
status == AVSUCCESS,
@@ -237,7 +239,7 @@ void AudioEncoder::encode() {
237239
avFrame->pts = 0;
238240
// We set the channel layout of the frame to the default layout corresponding
239241
// to the input samples' number of channels
240-
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
242+
setDefaultChannelLayout(avFrame, static_cast<int>(samples_.sizes()[0]));
241243

242244
auto status = av_frame_get_buffer(avFrame.get(), 0);
243245
TORCH_CHECK(
@@ -247,10 +249,10 @@ void AudioEncoder::encode() {
247249

248250
AutoAVPacket autoAVPacket;
249251

250-
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
251-
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
252+
uint8_t* psamples = static_cast<uint8_t*>(samples_.data_ptr());
253+
int numSamples = static_cast<int>(samples_.sizes()[1]); // per channel
252254
int numEncodedSamples = 0; // per channel
253-
int numBytesPerSample = static_cast<int>(wf_.element_size());
255+
int numBytesPerSample = static_cast<int>(samples_.element_size());
254256
int numBytesPerChannel = numSamples * numBytesPerSample;
255257

256258
status = avformat_write_header(avFormatContext_.get(), nullptr);
@@ -270,11 +272,11 @@ void AudioEncoder::encode() {
270272
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
271273
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
272274

273-
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
275+
for (int ch = 0; ch < samples_.sizes()[0]; ch++) {
274276
std::memcpy(
275-
avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
277+
avFrame->data[ch], psamples + ch * numBytesPerChannel, numBytesToEncode);
276278
}
277-
pwf += numBytesToEncode;
279+
psamples += numBytesToEncode;
278280

279281
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
280282
// that the frame buffers are allocated to a big enough size. Here, we reset

‎src/torchcodec/_core/Encoder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@ class AudioEncoder {
1717
// TODO-ENCODING: bundle the optional params like bitRate, numChannels, etc.
1818
// into an AudioStreamOptions struct, or similar.
1919
AudioEncoder(
20-
const torch::Tensor wf,
20+
const torch::Tensor samples,
2121
// The *output* sample rate. We can't really decide for the user what it
22-
// should be. Particularly, the sample rate of the input waveform should
22+
// should be. Particularly, the sample rate of the input samples should
2323
// match this, and that's up to the user. If sample rates don't match,
2424
// encoding will still work but audio will be distorted.
2525
int sampleRate,
2626
std::string_view fileName,
2727
std::optional<int64_t> bitRate = std::nullopt,
2828
std::optional<int64_t> numChannels = std::nullopt);
2929
AudioEncoder(
30-
const torch::Tensor wf,
30+
const torch::Tensor samples,
3131
int sampleRate,
3232
std::string_view formatName,
3333
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
@@ -54,7 +54,7 @@ class AudioEncoder {
5454
// see other TODO above.
5555
int desiredNumChannels_ = -1;
5656

57-
const torch::Tensor wf_;
57+
const torch::Tensor samples_;
5858

5959
// Stores the AVIOContext for the output tensor buffer.
6060
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;

‎src/torchcodec/_core/custom_ops.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3131
m.def(
32-
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
32+
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
3333
m.def(
34-
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
34+
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
3535
m.def(
3636
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3737
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -388,25 +388,25 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
388388
}
389389

390390
void encode_audio_to_file(
391-
const at::Tensor wf,
391+
const at::Tensor samples,
392392
int64_t sample_rate,
393393
std::string_view file_name,
394394
std::optional<int64_t> bit_rate = std::nullopt,
395395
std::optional<int64_t> num_channels = std::nullopt) {
396396
AudioEncoder(
397-
wf, validateSampleRate(sample_rate), file_name, bit_rate, num_channels)
397+
samples, validateSampleRate(sample_rate), file_name, bit_rate, num_channels)
398398
.encode();
399399
}
400400

401401
at::Tensor encode_audio_to_tensor(
402-
const at::Tensor wf,
402+
const at::Tensor samples,
403403
int64_t sample_rate,
404404
std::string_view format,
405405
std::optional<int64_t> bit_rate = std::nullopt,
406406
std::optional<int64_t> num_channels = std::nullopt) {
407407
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
408408
return AudioEncoder(
409-
wf,
409+
samples,
410410
validateSampleRate(sample_rate),
411411
format,
412412
std::move(avioContextHolder),

0 commit comments

Comments
 (0)