Skip to content

Commit 6d3ad1a

Browse files
Dan-Floresdanielflores3
and
danielflores3
authored
Rename 'wf' to 'samples' in AudioEncoder (#701)
Co-authored-by: danielflores3 <[email protected]>
1 parent ba44fdb commit 6d3ad1a

File tree

6 files changed

+50
-43
lines changed

6 files changed

+50
-43
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,19 @@ namespace facebook::torchcodec {
88

99
namespace {
1010

11-
torch::Tensor validateWf(torch::Tensor wf) {
11+
torch::Tensor validateSamples(torch::Tensor samples) {
1212
TORCH_CHECK(
13-
wf.dtype() == torch::kFloat32,
14-
"waveform must have float32 dtype, got ",
15-
wf.dtype());
16-
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
13+
samples.dtype() == torch::kFloat32,
14+
"samples must have float32 dtype, got ",
15+
samples.dtype());
16+
TORCH_CHECK(
17+
samples.dim() == 2,
18+
"samples must have 2 dimensions, got ",
19+
samples.dim());
1720

1821
// We enforce this, but if we get user reports we should investigate whether
1922
// that's actually needed.
20-
int numChannels = static_cast<int>(wf.sizes()[0]);
23+
int numChannels = static_cast<int>(samples.sizes()[0]);
2124
TORCH_CHECK(
2225
numChannels <= AV_NUM_DATA_POINTERS,
2326
"Trying to encode ",
@@ -26,7 +29,7 @@ torch::Tensor validateWf(torch::Tensor wf) {
2629
AV_NUM_DATA_POINTERS,
2730
" channels per frame.");
2831

29-
return wf.contiguous();
32+
return samples.contiguous();
3033
}
3134

3235
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
@@ -71,7 +74,7 @@ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
7174

7275
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
7376
// Find a sample format that the encoder supports. We prefer using FLT[P],
74-
// since this is the format of the input waveform. If FLTP isn't supported
77+
// since this is the format of the input samples. If FLTP isn't supported
7578
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
7679
// into the format with the highest resolution.
7780
if (avCodec.sample_fmts == nullptr) {
@@ -98,11 +101,11 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
98101
AudioEncoder::~AudioEncoder() {}
99102

100103
AudioEncoder::AudioEncoder(
101-
const torch::Tensor wf,
104+
const torch::Tensor samples,
102105
int sampleRate,
103106
std::string_view fileName,
104107
const AudioStreamOptions& audioStreamOptions)
105-
: wf_(validateWf(wf)) {
108+
: samples_(validateSamples(samples)) {
106109
setFFmpegLogLevel();
107110
AVFormatContext* avFormatContext = nullptr;
108111
int status = avformat_alloc_output_context2(
@@ -129,12 +132,13 @@ AudioEncoder::AudioEncoder(
129132
}
130133

131134
AudioEncoder::AudioEncoder(
132-
const torch::Tensor wf,
135+
const torch::Tensor samples,
133136
int sampleRate,
134137
std::string_view formatName,
135138
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
136139
const AudioStreamOptions& audioStreamOptions)
137-
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
140+
: samples_(validateSamples(samples)),
141+
avioContextHolder_(std::move(avioContextHolder)) {
138142
setFFmpegLogLevel();
139143
AVFormatContext* avFormatContext = nullptr;
140144
int status = avformat_alloc_output_context2(
@@ -176,8 +180,8 @@ void AudioEncoder::initializeEncoder(
176180
// well when "-b:a" isn't specified.
177181
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
178182

179-
outNumChannels_ =
180-
static_cast<int>(audioStreamOptions.numChannels.value_or(wf_.sizes()[0]));
183+
outNumChannels_ = static_cast<int>(
184+
audioStreamOptions.numChannels.value_or(samples_.sizes()[0]));
181185
validateNumChannels(*avCodec, outNumChannels_);
182186
// The avCodecContext layout defines the layout of the encoded output, it's
183187
// not related to the input sampes.
@@ -186,9 +190,9 @@ void AudioEncoder::initializeEncoder(
186190
validateSampleRate(*avCodec, sampleRate);
187191
avCodecContext_->sample_rate = sampleRate;
188192

189-
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
190-
// may need to convert the wf into a supported output sample format, which is
191-
// what the `.sample_fmt` defines.
193+
// Input samples are expected to be FLTP. Not all encoders support FLTP, so we
194+
// may need to convert the samples into a supported output sample format,
195+
// which is what the `.sample_fmt` defines.
192196
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
193197

194198
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
@@ -237,7 +241,7 @@ void AudioEncoder::encode() {
237241
avFrame->pts = 0;
238242
// We set the channel layout of the frame to the default layout corresponding
239243
// to the input samples' number of channels
240-
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
244+
setDefaultChannelLayout(avFrame, static_cast<int>(samples_.sizes()[0]));
241245

242246
auto status = av_frame_get_buffer(avFrame.get(), 0);
243247
TORCH_CHECK(
@@ -247,10 +251,10 @@ void AudioEncoder::encode() {
247251

248252
AutoAVPacket autoAVPacket;
249253

250-
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
251-
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
254+
uint8_t* psamples = static_cast<uint8_t*>(samples_.data_ptr());
255+
int numSamples = static_cast<int>(samples_.sizes()[1]); // per channel
252256
int numEncodedSamples = 0; // per channel
253-
int numBytesPerSample = static_cast<int>(wf_.element_size());
257+
int numBytesPerSample = static_cast<int>(samples_.element_size());
254258
int numBytesPerChannel = numSamples * numBytesPerSample;
255259

256260
status = avformat_write_header(avFormatContext_.get(), nullptr);
@@ -270,11 +274,13 @@ void AudioEncoder::encode() {
270274
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
271275
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
272276

273-
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
277+
for (int ch = 0; ch < samples_.sizes()[0]; ch++) {
274278
std::memcpy(
275-
avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
279+
avFrame->data[ch],
280+
psamples + ch * numBytesPerChannel,
281+
numBytesToEncode);
276282
}
277-
pwf += numBytesToEncode;
283+
psamples += numBytesToEncode;
278284

279285
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
280286
// that the frame buffers are allocated to a big enough size. Here, we reset

src/torchcodec/_core/Encoder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@ class AudioEncoder {
1515
// Passing 44_100 could result in output being 44000 if only 44000 is
1616
// supported.
1717
AudioEncoder(
18-
const torch::Tensor wf,
18+
const torch::Tensor samples,
1919
// TODO-ENCODING: update this comment when we support an output sample
2020
// rate. This will become the input sample rate.
2121
// The *output* sample rate. We can't really decide for the user what it
22-
// should be. Particularly, the sample rate of the input waveform should
22+
// should be. Particularly, the sample rate of the input samples should
2323
// match this, and that's up to the user. If sample rates don't match,
2424
// encoding will still work but audio will be distorted.
2525
int sampleRate,
2626
std::string_view fileName,
2727
const AudioStreamOptions& audioStreamOptions);
2828
AudioEncoder(
29-
const torch::Tensor wf,
29+
const torch::Tensor samples,
3030
int sampleRate,
3131
std::string_view formatName,
3232
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
@@ -51,7 +51,7 @@ class AudioEncoder {
5151

5252
int outNumChannels_ = -1;
5353

54-
const torch::Tensor wf_;
54+
const torch::Tensor samples_;
5555

5656
// Stores the AVIOContext for the output tensor buffer.
5757
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3131
m.def(
32-
"encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
32+
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
3333
m.def(
34-
"encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
34+
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
3535
m.def(
3636
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3737
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -388,7 +388,7 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
388388
}
389389

390390
void encode_audio_to_file(
391-
const at::Tensor wf,
391+
const at::Tensor samples,
392392
int64_t sample_rate,
393393
std::string_view file_name,
394394
std::optional<int64_t> bit_rate = std::nullopt,
@@ -399,12 +399,12 @@ void encode_audio_to_file(
399399
audioStreamOptions.bitRate = bit_rate;
400400
audioStreamOptions.numChannels = num_channels;
401401
AudioEncoder(
402-
wf, validateSampleRate(sample_rate), file_name, audioStreamOptions)
402+
samples, validateSampleRate(sample_rate), file_name, audioStreamOptions)
403403
.encode();
404404
}
405405

406406
at::Tensor encode_audio_to_tensor(
407-
const at::Tensor wf,
407+
const at::Tensor samples,
408408
int64_t sample_rate,
409409
std::string_view format,
410410
std::optional<int64_t> bit_rate = std::nullopt,
@@ -416,7 +416,7 @@ at::Tensor encode_audio_to_tensor(
416416
audioStreamOptions.bitRate = bit_rate;
417417
audioStreamOptions.numChannels = num_channels;
418418
return AudioEncoder(
419-
wf,
419+
samples,
420420
validateSampleRate(sample_rate),
421421
format,
422422
std::move(avioContextHolder),

src/torchcodec/_core/ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,9 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
161161
return torch.empty([], dtype=torch.long)
162162

163163

164-
# TODO-ENCODING: rename wf to samples
165164
@register_fake("torchcodec_ns::encode_audio_to_file")
166165
def encode_audio_to_file_abstract(
167-
wf: torch.Tensor,
166+
samples: torch.Tensor,
168167
sample_rate: int,
169168
filename: str,
170169
bit_rate: Optional[int] = None,
@@ -175,7 +174,7 @@ def encode_audio_to_file_abstract(
175174

176175
@register_fake("torchcodec_ns::encode_audio_to_tensor")
177176
def encode_audio_to_tensor_abstract(
178-
wf: torch.Tensor,
177+
samples: torch.Tensor,
179178
sample_rate: int,
180179
format: str,
181180
bit_rate: Optional[int] = None,

src/torchcodec/encoders/_audio_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def to_file(
3434
num_channels: Optional[int] = None,
3535
) -> None:
3636
_core.encode_audio_to_file(
37-
wf=self._samples,
37+
samples=self._samples,
3838
sample_rate=self._sample_rate,
3939
filename=dest,
4040
bit_rate=bit_rate,
@@ -49,7 +49,7 @@ def to_tensor(
4949
num_channels: Optional[int] = None,
5050
) -> Tensor:
5151
return _core.encode_audio_to_tensor(
52-
wf=self._samples,
52+
samples=self._samples,
5353
sample_rate=self._sample_rate,
5454
format=format,
5555
bit_rate=bit_rate,

test/test_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,22 +1101,24 @@ def test_bad_input(self, tmp_path):
11011101

11021102
with pytest.raises(RuntimeError, match="must have float32 dtype, got int"):
11031103
encode_audio_to_file(
1104-
wf=torch.arange(10, dtype=torch.int),
1104+
samples=torch.arange(10, dtype=torch.int),
11051105
sample_rate=10,
11061106
filename=valid_output_file,
11071107
)
11081108
with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"):
11091109
encode_audio_to_file(
1110-
wf=torch.rand(3), sample_rate=10, filename=valid_output_file
1110+
samples=torch.rand(3), sample_rate=10, filename=valid_output_file
11111111
)
11121112

11131113
with pytest.raises(RuntimeError, match="No such file or directory"):
11141114
encode_audio_to_file(
1115-
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
1115+
samples=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
11161116
)
11171117
with pytest.raises(RuntimeError, match="check the desired extension"):
11181118
encode_audio_to_file(
1119-
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
1119+
samples=torch.rand(2, 10),
1120+
sample_rate=10,
1121+
filename="./file.bad_extension",
11201122
)
11211123

11221124

0 commit comments

Comments
 (0)